From 5a3d4e1844969bbcc1a74bff92bbe7f25740ffb9 Mon Sep 17 00:00:00 2001 From: Gijs Molenaar Date: Wed, 25 Mar 2026 10:51:58 +0200 Subject: [PATCH 1/4] Add unified s7 package with protocol auto-discovery MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces a new top-level `s7` package that provides a single Client, AsyncClient, and Server that automatically select the best protocol (S7CommPlus or legacy S7) for the target PLC. Users no longer need to choose between snap7.Client and snap7.s7commplus.Client. - `from s7 import Client` — unified entry point - connect() tries S7CommPlus first, falls back to legacy transparently - Full legacy API (~100 methods) available via __getattr__ delegation - db_read/db_write route through S7CommPlus when available - Unified Server wraps both legacy and S7CommPlus servers - 26 unit tests covering auto-discovery, fallback, and delegation Co-Authored-By: Claude Opus 4.6 --- pyproject.toml | 3 +- s7/__init__.py | 35 ++++ s7/_protocol.py | 19 ++ s7/async_client.py | 246 ++++++++++++++++++++++ s7/client.py | 314 +++++++++++++++++++++++++++ s7/py.typed | 0 s7/server.py | 181 ++++++++++++++++ tests/test_s7_unified.py | 443 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 1240 insertions(+), 1 deletion(-) create mode 100644 s7/__init__.py create mode 100644 s7/_protocol.py create mode 100644 s7/async_client.py create mode 100644 s7/client.py create mode 100644 s7/py.typed create mode 100644 s7/server.py create mode 100644 tests/test_s7_unified.py diff --git a/pyproject.toml b/pyproject.toml index b865de58..ca23f81a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,9 +40,10 @@ discovery = ["pnio-dcp"] [tool.setuptools.package-data] snap7 = ["py.typed"] +s7 = ["py.typed"] [tool.setuptools.packages.find] -include = ["snap7*"] +include = ["snap7*", "s7*"] [project.scripts] snap7-server = "snap7.server:mainloop" diff --git a/s7/__init__.py b/s7/__init__.py new file mode 100644 index 00000000..1cfaa189 --- /dev/null +++ b/s7/__init__.py @@ -0,0 +1,35 @@ +"""Unified S7 communication library. + +Provides protocol-agnostic access to Siemens S7 PLCs with automatic +protocol discovery (S7CommPlus vs legacy S7). + +Usage:: + + from s7 import Client + + client = Client() + client.connect("192.168.1.10", 0, 1) + data = client.db_read(1, 0, 4) +""" + +from .client import Client +from .async_client import AsyncClient +from .server import Server +from ._protocol import Protocol + +from snap7.type import Area, Block, WordLen, SrvEvent, SrvArea +from snap7.util.db import Row, DB + +__all__ = [ + "Client", + "AsyncClient", + "Server", + "Protocol", + "Area", + "Block", + "WordLen", + "SrvEvent", + "SrvArea", + "Row", + "DB", +] diff --git a/s7/_protocol.py b/s7/_protocol.py new file mode 100644 index 00000000..f1ca0328 --- /dev/null +++ b/s7/_protocol.py @@ -0,0 +1,19 @@ +"""Protocol enum for unified S7 client.""" + +from enum import Enum + + +class Protocol(Enum): + """S7 communication protocol selection. + + Used to control protocol auto-discovery in the unified client. + + Attributes: + AUTO: Try S7CommPlus first, fall back to legacy S7 if unsupported. + LEGACY: Use legacy S7 protocol only (S7-300/400, basic S7-1200/1500). + S7COMMPLUS: Use S7CommPlus protocol only (S7-1200/1500 with full access). + """ + + AUTO = "auto" + LEGACY = "legacy" + S7COMMPLUS = "s7commplus" diff --git a/s7/async_client.py b/s7/async_client.py new file mode 100644 index 00000000..df558f1f --- /dev/null +++ b/s7/async_client.py @@ -0,0 +1,246 @@ +"""Unified async S7 client with protocol auto-discovery. + +Provides a single async client that automatically selects the best protocol +(S7CommPlus or legacy S7) for communicating with Siemens S7 PLCs. + +Usage:: + + from s7 import AsyncClient + + async with AsyncClient() as client: + await client.connect("192.168.1.10", 0, 1) + data = await client.db_read(1, 0, 4) +""" + +import logging +from typing import Any, Optional + +from snap7.async_client import AsyncClient as LegacyAsyncClient +from snap7.s7commplus.async_client import S7CommPlusAsyncClient + +from ._protocol import Protocol + +logger = logging.getLogger(__name__) + + +class AsyncClient: + """Unified async S7 client with protocol auto-discovery. + + Async counterpart of :class:`s7.Client`. Automatically selects the + best protocol for the target PLC using asyncio for non-blocking I/O. + + Methods not explicitly defined are delegated to the underlying + legacy async client via ``__getattr__``. + + Examples:: + + from s7 import AsyncClient + + async with AsyncClient() as client: + await client.connect("192.168.1.10", 0, 1) + data = await client.db_read(1, 0, 4) + print(client.protocol) + """ + + def __init__(self) -> None: + self._legacy: Optional[LegacyAsyncClient] = None + self._plus: Optional[S7CommPlusAsyncClient] = None + self._protocol: Protocol = Protocol.AUTO + self._host: str = "" + self._port: int = 102 + self._rack: int = 0 + self._slot: int = 1 + + @property + def protocol(self) -> Protocol: + """The protocol currently in use for DB operations.""" + return self._protocol + + @property + def connected(self) -> bool: + """Whether the client is connected to a PLC.""" + if self._legacy is not None and self._legacy.connected: + return True + if self._plus is not None and self._plus.connected: + return True + return False + + async def connect( + self, + address: str, + rack: int = 0, + slot: int = 1, + tcp_port: int = 102, + *, + protocol: Protocol = Protocol.AUTO, + ) -> "AsyncClient": + """Connect to an S7 PLC. + + Args: + address: PLC IP address or hostname. + rack: PLC rack number. + slot: PLC slot number. + tcp_port: TCP port (default 102). + protocol: Protocol selection. AUTO tries S7CommPlus first, + then falls back to legacy S7. + + Returns: + self, for method chaining. + """ + self._host = address + self._port = tcp_port + self._rack = rack + self._slot = slot + + if protocol in (Protocol.AUTO, Protocol.S7COMMPLUS): + if await self._try_s7commplus(address, tcp_port, rack, slot): + self._protocol = Protocol.S7COMMPLUS + logger.info(f"Async connected to {address}:{tcp_port} using S7CommPlus") + else: + if protocol == Protocol.S7COMMPLUS: + raise RuntimeError( + f"S7CommPlus connection to {address}:{tcp_port} failed and protocol=S7COMMPLUS was explicitly requested" + ) + self._protocol = Protocol.LEGACY + logger.info(f"S7CommPlus not available, using legacy S7 for {address}:{tcp_port}") + else: + self._protocol = Protocol.LEGACY + + # Always connect legacy client + self._legacy = LegacyAsyncClient() + await self._legacy.connect(address, rack, slot, tcp_port) + logger.info(f"Async legacy S7 connected to {address}:{tcp_port}") + + return self + + async def _try_s7commplus( + self, + address: str, + tcp_port: int, + rack: int, + slot: int, + ) -> bool: + """Attempt async S7CommPlus connection and probe data operations.""" + plus = S7CommPlusAsyncClient() + try: + await plus.connect(host=address, port=tcp_port, rack=rack, slot=slot) + except Exception as e: + logger.debug(f"Async S7CommPlus connection failed: {e}") + return False + + if plus.using_legacy_fallback: + logger.debug("S7CommPlus connected but data ops not supported, disconnecting") + try: + await plus.disconnect() + except Exception: + pass + return False + + self._plus = plus + return True + + async def disconnect(self) -> int: + """Disconnect from the PLC. + + Returns: + 0 on success. + """ + if self._plus is not None: + try: + await self._plus.disconnect() + except Exception: + pass + self._plus = None + + if self._legacy is not None: + try: + await self._legacy.disconnect() + except Exception: + pass + self._legacy = None + + self._protocol = Protocol.AUTO + return 0 + + async def db_read(self, db_number: int, start: int, size: int) -> bytearray: + """Read data from a data block. + + Args: + db_number: DB number to read from. + start: Start byte offset. + size: Number of bytes to read. + + Returns: + Data read from the DB. + """ + if self._protocol == Protocol.S7COMMPLUS and self._plus is not None: + return bytearray(await self._plus.db_read(db_number, start, size)) + if self._legacy is not None: + return await self._legacy.db_read(db_number, start, size) + raise RuntimeError("Not connected") + + async def db_write(self, db_number: int, start: int, data: bytearray | bytes) -> int: + """Write data to a data block. + + Args: + db_number: DB number to write to. + start: Start byte offset. + data: Data to write. + + Returns: + 0 on success. + """ + if self._protocol == Protocol.S7COMMPLUS and self._plus is not None: + await self._plus.db_write(db_number, start, bytes(data)) + return 0 + if self._legacy is not None: + return await self._legacy.db_write(db_number, start, bytearray(data)) + raise RuntimeError("Not connected") + + async def db_read_multi(self, items: list[tuple[int, int, int]]) -> list[bytearray]: + """Read multiple data block regions. + + Args: + items: List of (db_number, start_offset, size) tuples. + + Returns: + List of data for each item. + """ + if self._protocol == Protocol.S7COMMPLUS and self._plus is not None: + return [bytearray(r) for r in await self._plus.db_read_multi(items)] + if self._legacy is not None: + results = [] + for db, start, size in items: + results.append(await self._legacy.db_read(db, start, size)) + return results + raise RuntimeError("Not connected") + + async def explore(self) -> bytes: + """Browse the PLC object tree (S7CommPlus only). + + Returns: + Raw response payload. + """ + if self._plus is not None: + return await self._plus.explore() + raise RuntimeError("explore() requires S7CommPlus protocol") + + def __getattr__(self, name: str) -> Any: + """Delegate unknown attributes to the legacy async client.""" + if name.startswith("_"): + raise AttributeError(name) + legacy = self.__dict__.get("_legacy") + if legacy is not None: + return getattr(legacy, name) + raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") + + async def __aenter__(self) -> "AsyncClient": + return self + + async def __aexit__(self, *args: Any) -> None: + await self.disconnect() + + def __repr__(self) -> str: + status = "connected" if self.connected else "disconnected" + proto = self._protocol.value if self.connected else "none" + return f"" diff --git a/s7/client.py b/s7/client.py new file mode 100644 index 00000000..8e143895 --- /dev/null +++ b/s7/client.py @@ -0,0 +1,314 @@ +"""Unified S7 client with protocol auto-discovery. + +Provides a single client that automatically selects the best protocol +(S7CommPlus or legacy S7) for communicating with Siemens S7 PLCs. + +Usage:: + + from s7 import Client + + client = Client() + client.connect("192.168.1.10", 0, 1) + data = client.db_read(1, 0, 4) + client.disconnect() +""" + +import logging +from typing import Any, Optional + +from snap7.client import Client as LegacyClient +from snap7.s7commplus.client import S7CommPlusClient + +from ._protocol import Protocol + +logger = logging.getLogger(__name__) + + +class Client: + """Unified S7 client with protocol auto-discovery. + + Automatically selects the best protocol for the target PLC: + + - **S7CommPlus**: For S7-1200/1500 PLCs with full engineering access + - **Legacy S7**: For S7-300/400 and S7-1200/1500 without S7CommPlus support + + When ``protocol=Protocol.AUTO`` (default), the client tries S7CommPlus + first and falls back to legacy S7 transparently. + + Exposes the full legacy S7 client API. Methods not explicitly defined + (block operations, PLC control, memory areas, etc.) are delegated to + the underlying legacy client via ``__getattr__``. + + Examples:: + + from s7 import Client + + # Auto-discover protocol + client = Client() + client.connect("192.168.1.10", 0, 1) + data = client.db_read(1, 0, 4) + print(client.protocol) # Protocol.LEGACY or Protocol.S7COMMPLUS + + # Force legacy protocol + client = Client() + client.connect("192.168.1.10", 0, 1, protocol=Protocol.LEGACY) + + # S7CommPlus with TLS + client = Client() + client.connect("192.168.1.10", 0, 1, use_tls=True, password="secret") + """ + + def __init__(self) -> None: + self._legacy: Optional[LegacyClient] = None + self._plus: Optional[S7CommPlusClient] = None + self._protocol: Protocol = Protocol.AUTO + self._host: str = "" + self._port: int = 102 + self._rack: int = 0 + self._slot: int = 1 + + @property + def protocol(self) -> Protocol: + """The protocol currently in use for DB operations.""" + return self._protocol + + @property + def connected(self) -> bool: + """Whether the client is connected to a PLC.""" + if self._legacy is not None and self._legacy.connected: + return True + if self._plus is not None and self._plus.connected: + return True + return False + + def connect( + self, + address: str, + rack: int = 0, + slot: int = 1, + tcp_port: int = 102, + *, + protocol: Protocol = Protocol.AUTO, + use_tls: bool = False, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + tls_ca: Optional[str] = None, + password: Optional[str] = None, + ) -> "Client": + """Connect to an S7 PLC. + + Args: + address: PLC IP address or hostname. + rack: PLC rack number. + slot: PLC slot number. + tcp_port: TCP port (default 102). + protocol: Protocol selection. AUTO tries S7CommPlus first, + then falls back to legacy S7. + use_tls: Enable TLS for S7CommPlus V2+ connections. + tls_cert: Path to client TLS certificate (PEM). + tls_key: Path to client private key (PEM). + tls_ca: Path to CA certificate for PLC verification (PEM). + password: PLC password for S7CommPlus legitimation. + + Returns: + self, for method chaining. + """ + self._host = address + self._port = tcp_port + self._rack = rack + self._slot = slot + + if protocol in (Protocol.AUTO, Protocol.S7COMMPLUS): + if self._try_s7commplus( + address, + tcp_port, + rack, + slot, + use_tls=use_tls, + tls_cert=tls_cert, + tls_key=tls_key, + tls_ca=tls_ca, + password=password, + ): + self._protocol = Protocol.S7COMMPLUS + logger.info(f"Connected to {address}:{tcp_port} using S7CommPlus") + else: + if protocol == Protocol.S7COMMPLUS: + raise RuntimeError( + f"S7CommPlus connection to {address}:{tcp_port} failed and protocol=S7COMMPLUS was explicitly requested" + ) + self._protocol = Protocol.LEGACY + logger.info(f"S7CommPlus not available, using legacy S7 for {address}:{tcp_port}") + else: + self._protocol = Protocol.LEGACY + + # Always connect legacy client (needed for block ops, PLC control, etc.) + self._legacy = LegacyClient() + self._legacy.connect(address, rack, slot, tcp_port) + logger.info(f"Legacy S7 connected to {address}:{tcp_port}") + + return self + + def _try_s7commplus( + self, + address: str, + tcp_port: int, + rack: int, + slot: int, + *, + use_tls: bool = False, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + tls_ca: Optional[str] = None, + password: Optional[str] = None, + ) -> bool: + """Attempt S7CommPlus connection and probe data operations. + + Returns True if S7CommPlus data operations are fully supported. + On failure, cleans up the S7CommPlus connection. + """ + plus = S7CommPlusClient() + try: + plus.connect( + host=address, + port=tcp_port, + rack=rack, + slot=slot, + use_tls=use_tls, + tls_cert=tls_cert, + tls_key=tls_key, + tls_ca=tls_ca, + password=password, + ) + except Exception as e: + logger.debug(f"S7CommPlus connection failed: {e}") + return False + + # The S7CommPlus client probes data ops internally and sets + # using_legacy_fallback if they don't work. We don't want to use + # its built-in fallback (that would create a duplicate legacy + # connection), so we check and disconnect if it fell back. + if plus.using_legacy_fallback: + logger.debug("S7CommPlus connected but data ops not supported, disconnecting") + try: + plus.disconnect() + except Exception: + pass + return False + + self._plus = plus + return True + + def disconnect(self) -> int: + """Disconnect from the PLC. + + Returns: + 0 on success. + """ + if self._plus is not None: + try: + self._plus.disconnect() + except Exception: + pass + self._plus = None + + if self._legacy is not None: + try: + self._legacy.disconnect() + except Exception: + pass + self._legacy = None + + self._protocol = Protocol.AUTO + return 0 + + def db_read(self, db_number: int, start: int, size: int) -> bytearray: + """Read data from a data block. + + Uses S7CommPlus when available, otherwise legacy S7. + + Args: + db_number: DB number to read from. + start: Start byte offset. + size: Number of bytes to read. + + Returns: + Data read from the DB. + """ + if self._protocol == Protocol.S7COMMPLUS and self._plus is not None: + return bytearray(self._plus.db_read(db_number, start, size)) + if self._legacy is not None: + return self._legacy.db_read(db_number, start, size) + raise RuntimeError("Not connected") + + def db_write(self, db_number: int, start: int, data: bytearray | bytes) -> int: + """Write data to a data block. + + Uses S7CommPlus when available, otherwise legacy S7. + + Args: + db_number: DB number to write to. + start: Start byte offset. + data: Data to write. + + Returns: + 0 on success. + """ + if self._protocol == Protocol.S7COMMPLUS and self._plus is not None: + self._plus.db_write(db_number, start, bytes(data)) + return 0 + if self._legacy is not None: + return self._legacy.db_write(db_number, start, bytearray(data)) + raise RuntimeError("Not connected") + + def db_read_multi(self, items: list[tuple[int, int, int]]) -> list[bytearray]: + """Read multiple data block regions. + + Uses S7CommPlus native multi-read when available, otherwise + performs individual reads via legacy S7. + + Args: + items: List of (db_number, start_offset, size) tuples. + + Returns: + List of data for each item. + """ + if self._protocol == Protocol.S7COMMPLUS and self._plus is not None: + return [bytearray(r) for r in self._plus.db_read_multi(items)] + if self._legacy is not None: + return [self._legacy.db_read(db, start, size) for db, start, size in items] + raise RuntimeError("Not connected") + + def explore(self) -> bytes: + """Browse the PLC object tree (S7CommPlus only). + + Returns: + Raw response payload. + + Raises: + RuntimeError: If S7CommPlus is not active. + """ + if self._plus is not None: + return self._plus.explore() + raise RuntimeError("explore() requires S7CommPlus protocol") + + def __getattr__(self, name: str) -> Any: + """Delegate unknown attributes to the legacy client.""" + # Avoid infinite recursion for attributes accessed during __init__ + if name.startswith("_"): + raise AttributeError(name) + legacy = self.__dict__.get("_legacy") + if legacy is not None: + return getattr(legacy, name) + raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") + + def __enter__(self) -> "Client": + return self + + def __exit__(self, *args: Any) -> None: + self.disconnect() + + def __repr__(self) -> str: + status = "connected" if self.connected else "disconnected" + proto = self._protocol.value if self.connected else "none" + return f"" diff --git a/s7/py.typed b/s7/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/s7/server.py b/s7/server.py new file mode 100644 index 00000000..e5ec4341 --- /dev/null +++ b/s7/server.py @@ -0,0 +1,181 @@ +"""Unified S7 server combining legacy S7 and S7CommPlus protocols. + +Provides a single server that can handle both legacy S7 and S7CommPlus +client connections simultaneously. + +Usage:: + + from s7 import Server + + server = Server() + server.register_area(SrvArea.DB, 1, db_data) + server.start(tcp_port=102) +""" + +import logging +from typing import Any, Optional + +from snap7.server import Server as LegacyServer +from snap7.s7commplus.server import S7CommPlusServer, DataBlock + +logger = logging.getLogger(__name__) + + +class Server: + """Unified S7 server combining legacy S7 and S7CommPlus. + + Wraps both a legacy S7 server and an S7CommPlus server, allowing + them to run side by side. Legacy clients connect on the primary port, + S7CommPlus clients on an optional secondary port. + + Methods not explicitly defined are delegated to the underlying + legacy server via ``__getattr__``. + + Examples:: + + from s7 import Server + from snap7.type import SrvArea + + server = Server() + + # Register memory areas for legacy S7 clients + db_data = bytearray(100) + server.register_area(SrvArea.DB, 1, db_data) + + # Register data blocks for S7CommPlus clients + server.register_db(1, {"temperature": ("Real", 0)}) + + # Start both servers + server.start(tcp_port=102, s7commplus_port=11020) + + # Stop both + server.stop() + """ + + def __init__(self, log: bool = True) -> None: + """Initialize unified server. + + Args: + log: Enable event logging for the legacy server. + """ + self._legacy: LegacyServer = LegacyServer(log=log) + self._plus: S7CommPlusServer = S7CommPlusServer() + self._plus_running: bool = False + + def start( + self, + tcp_port: int = 102, + s7commplus_port: Optional[int] = None, + use_tls: bool = False, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + ) -> int: + """Start the server(s). + + Args: + tcp_port: Port for legacy S7 clients (default 102). + s7commplus_port: Optional port for S7CommPlus clients. + If None, only the legacy server is started. + use_tls: Enable TLS for S7CommPlus server. + tls_cert: Path to TLS certificate (PEM). + tls_key: Path to TLS private key (PEM). + + Returns: + 0 on success. + """ + result = self._legacy.start(tcp_port=tcp_port) + + if s7commplus_port is not None: + self._plus.start( + port=s7commplus_port, + use_tls=use_tls, + tls_cert=tls_cert, + tls_key=tls_key, + ) + self._plus_running = True + logger.info(f"S7CommPlus server started on port {s7commplus_port}") + + return result + + def stop(self) -> int: + """Stop both servers. + + Returns: + 0 on success. + """ + if self._plus_running: + try: + self._plus.stop() + except Exception: + pass + self._plus_running = False + + return self._legacy.stop() + + # -- S7CommPlus server methods -- + + def register_db( + self, + db_number: int, + variables: dict[str, tuple[str, int]], + size: int = 1024, + ) -> DataBlock: + """Register a data block on the S7CommPlus server. + + Args: + db_number: Data block number. + variables: Dict of {name: (type_name, byte_offset)}. + size: DB size in bytes (default 1024). + + Returns: + The registered DataBlock. + """ + return self._plus.register_db(db_number, variables, size) + + def register_raw_db(self, db_number: int, data: bytearray) -> DataBlock: + """Register a raw data block on the S7CommPlus server. + + Args: + db_number: Data block number. + data: Raw DB data. + + Returns: + The registered DataBlock. + """ + return self._plus.register_raw_db(db_number, data) + + def get_db(self, db_number: int) -> Optional[DataBlock]: + """Get a registered S7CommPlus data block. + + Args: + db_number: Data block number. + + Returns: + The DataBlock, or None if not registered. + """ + return self._plus.get_db(db_number) + + @property + def s7commplus_server(self) -> S7CommPlusServer: + """Direct access to the underlying S7CommPlus server.""" + return self._plus + + @property + def legacy_server(self) -> LegacyServer: + """Direct access to the underlying legacy S7 server.""" + return self._legacy + + def __getattr__(self, name: str) -> Any: + """Delegate unknown attributes to the legacy server.""" + if name.startswith("_"): + raise AttributeError(name) + legacy = self.__dict__.get("_legacy") + if legacy is not None: + return getattr(legacy, name) + raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") + + def __enter__(self) -> "Server": + return self + + def __exit__(self, *args: Any) -> None: + self.stop() diff --git a/tests/test_s7_unified.py b/tests/test_s7_unified.py new file mode 100644 index 00000000..a58e7376 --- /dev/null +++ b/tests/test_s7_unified.py @@ -0,0 +1,443 @@ +"""Tests for the unified s7 package.""" + +import pytest +from unittest.mock import MagicMock, patch + +from s7 import Client, Server, Protocol, Area, Block, WordLen + + +class TestProtocol: + """Test Protocol enum.""" + + def test_enum_values(self) -> None: + assert Protocol.AUTO.value == "auto" + assert Protocol.LEGACY.value == "legacy" + assert Protocol.S7COMMPLUS.value == "s7commplus" + + +class TestClientInit: + """Test Client initialization.""" + + def test_default_state(self) -> None: + client = Client() + assert client.protocol == Protocol.AUTO + assert client.connected is False + + def test_repr_disconnected(self) -> None: + client = Client() + assert "disconnected" in repr(client) + + +class TestClientLegacy: + """Test Client with legacy protocol (mocked).""" + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_connect_legacy_fallback(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """When S7CommPlus fails, client falls back to legacy.""" + # S7CommPlus connect raises + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("Connection refused") + mock_plus_cls.return_value = mock_plus + + # Legacy connects fine + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + result = client.connect("192.168.1.10", 0, 1) + + assert result is client + assert client.protocol == Protocol.LEGACY + mock_legacy.connect.assert_called_once_with("192.168.1.10", 0, 1, 102) + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_connect_explicit_legacy(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """When protocol=LEGACY, S7CommPlus is not attempted.""" + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1, protocol=Protocol.LEGACY) + + assert client.protocol == Protocol.LEGACY + mock_plus_cls.assert_not_called() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_db_read_delegates_to_legacy(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """db_read delegates to legacy when protocol is LEGACY.""" + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy.db_read.return_value = bytearray([1, 2, 3, 4]) + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + client = Client() + client.connect("192.168.1.10", 0, 1) + data = client.db_read(1, 0, 4) + + assert data == bytearray([1, 2, 3, 4]) + mock_legacy.db_read.assert_called_once_with(1, 0, 4) + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_db_write_delegates_to_legacy(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy.db_write.return_value = 0 + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + client = Client() + client.connect("192.168.1.10", 0, 1) + result = client.db_write(1, 0, bytearray([1, 2, 3, 4])) + + assert result == 0 + mock_legacy.db_write.assert_called_once() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_getattr_delegates_to_legacy(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """Methods not on unified client delegate to legacy.""" + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy.get_cpu_info.return_value = "cpu_info" + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + client = Client() + client.connect("192.168.1.10", 0, 1) + info = client.get_cpu_info() + + assert info == "cpu_info" + mock_legacy.get_cpu_info.assert_called_once() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_disconnect(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + client = Client() + client.connect("192.168.1.10", 0, 1) + result = client.disconnect() + + assert result == 0 + mock_legacy.disconnect.assert_called_once() + + +class TestClientS7CommPlus: + """Test Client with S7CommPlus protocol (mocked).""" + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_connect_s7commplus(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """When S7CommPlus succeeds, protocol is S7COMMPLUS.""" + mock_plus = MagicMock() + mock_plus.using_legacy_fallback = False + mock_plus.connected = True + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + + assert client.protocol == Protocol.S7COMMPLUS + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_db_read_uses_s7commplus(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """db_read uses S7CommPlus when available.""" + mock_plus = MagicMock() + mock_plus.using_legacy_fallback = False + mock_plus.connected = True + mock_plus.db_read.return_value = b"\x01\x02\x03\x04" + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + data = client.db_read(1, 0, 4) + + assert data == bytearray([1, 2, 3, 4]) + assert isinstance(data, bytearray) + mock_plus.db_read.assert_called_once_with(1, 0, 4) + mock_legacy.db_read.assert_not_called() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_db_write_uses_s7commplus(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_plus = MagicMock() + mock_plus.using_legacy_fallback = False + mock_plus.connected = True + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + result = client.db_write(1, 0, bytearray([1, 2, 3, 4])) + + assert result == 0 + mock_plus.db_write.assert_called_once_with(1, 0, b"\x01\x02\x03\x04") + mock_legacy.db_write.assert_not_called() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_legacy_methods_still_delegate(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """Even with S7CommPlus, legacy-only methods go to legacy client.""" + mock_plus = MagicMock() + mock_plus.using_legacy_fallback = False + mock_plus.connected = True + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy.list_blocks.return_value = "blocks" + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + blocks = client.list_blocks() + + assert blocks == "blocks" + mock_legacy.list_blocks.assert_called_once() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_s7commplus_fallback_detected(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """If S7CommPlus connects but falls back, unified client uses LEGACY.""" + mock_plus = MagicMock() + mock_plus.using_legacy_fallback = True # Data ops not supported + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + + assert client.protocol == Protocol.LEGACY + # S7CommPlus should have been disconnected + mock_plus.disconnect.assert_called_once() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_explore_requires_s7commplus(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """explore() raises when S7CommPlus is not active.""" + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + + with pytest.raises(RuntimeError, match="S7CommPlus"): + client.explore() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_explicit_s7commplus_fails(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """protocol=S7COMMPLUS raises if S7CommPlus fails.""" + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + client = Client() + with pytest.raises(RuntimeError, match="explicitly requested"): + client.connect("192.168.1.10", 0, 1, protocol=Protocol.S7COMMPLUS) + + +class TestClientContextManager: + """Test context manager protocol.""" + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_context_manager(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + with Client() as client: + client.connect("192.168.1.10", 0, 1) + assert client.connected + + mock_legacy.disconnect.assert_called_once() + + +class TestClientDbReadMulti: + """Test db_read_multi.""" + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_multi_read_s7commplus(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_plus = MagicMock() + mock_plus.using_legacy_fallback = False + mock_plus.connected = True + mock_plus.db_read_multi.return_value = [b"\x01\x02", b"\x03\x04"] + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + results = client.db_read_multi([(1, 0, 2), (1, 2, 2)]) + + assert len(results) == 2 + assert all(isinstance(r, bytearray) for r in results) + mock_plus.db_read_multi.assert_called_once() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_multi_read_legacy(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy.db_read.side_effect = [bytearray([1, 2]), bytearray([3, 4])] + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + results = client.db_read_multi([(1, 0, 2), (1, 2, 2)]) + + assert len(results) == 2 + assert mock_legacy.db_read.call_count == 2 + + +class TestImports: + """Test that s7 package exports are accessible.""" + + def test_types_exported(self) -> None: + assert Area is not None + assert Block is not None + assert WordLen is not None + + def test_protocol_exported(self) -> None: + assert Protocol.AUTO is not None + + +class TestServer: + """Test unified Server (mocked).""" + + @patch("s7.server.S7CommPlusServer") + @patch("s7.server.LegacyServer") + def test_start_legacy_only(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_legacy = MagicMock() + mock_legacy.start.return_value = 0 + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus_cls.return_value = mock_plus + + server = Server() + result = server.start(tcp_port=11020) + + assert result == 0 + mock_legacy.start.assert_called_once_with(tcp_port=11020) + mock_plus.start.assert_not_called() + + @patch("s7.server.S7CommPlusServer") + @patch("s7.server.LegacyServer") + def test_start_both(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_legacy = MagicMock() + mock_legacy.start.return_value = 0 + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus_cls.return_value = mock_plus + + server = Server() + server.start(tcp_port=11020, s7commplus_port=11021) + + mock_legacy.start.assert_called_once_with(tcp_port=11020) + mock_plus.start.assert_called_once_with(port=11021, use_tls=False, tls_cert=None, tls_key=None) + + @patch("s7.server.S7CommPlusServer") + @patch("s7.server.LegacyServer") + def test_stop_both(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_legacy = MagicMock() + mock_legacy.start.return_value = 0 + mock_legacy.stop.return_value = 0 + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus_cls.return_value = mock_plus + + server = Server() + server.start(tcp_port=11020, s7commplus_port=11021) + result = server.stop() + + assert result == 0 + mock_plus.stop.assert_called_once() + mock_legacy.stop.assert_called_once() + + @patch("s7.server.S7CommPlusServer") + @patch("s7.server.LegacyServer") + def test_getattr_delegates(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_legacy = MagicMock() + mock_legacy.get_status.return_value = ("running", "ok", 0) + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus_cls.return_value = mock_plus + + server = Server() + status = server.get_status() + + assert status == ("running", "ok", 0) + + @patch("s7.server.S7CommPlusServer") + @patch("s7.server.LegacyServer") + def test_context_manager(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_legacy = MagicMock() + mock_legacy.start.return_value = 0 + mock_legacy.stop.return_value = 0 + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus_cls.return_value = mock_plus + + with Server() as server: + server.start(tcp_port=11020) + + mock_legacy.stop.assert_called_once() From 60be2853e2fa2560389ecad65690b49c517c14bd Mon Sep 17 00:00:00 2001 From: Gijs Molenaar Date: Wed, 25 Mar 2026 14:15:10 +0200 Subject: [PATCH 2/4] Implement S7CommPlus V2/V3 protocol with TLS and IntegrityId Add complete V2 and V3 support including TLS 1.3 encryption, per-request IntegrityId counters, session setup handshake, and password authentication (legitimation). Extend server emulator to support V2/V3 for integration testing without real hardware. - Fix V3 handling in connection.py (unified V2/V3 code path) - Add TLS, session setup, and auth to async client - Add legitimation and IntegrityId emulation to server - Add 21 TLS integration tests (sync + async, V2 + V3, auth) Co-Authored-By: Claude Opus 4.6 --- snap7/s7commplus/async_client.py | 408 +++++++++++++++++++++++++++++- snap7/s7commplus/client.py | 14 +- snap7/s7commplus/connection.py | 37 +-- snap7/s7commplus/server.py | 190 ++++++++++++-- tests/test_s7commplus_tls.py | 417 +++++++++++++++++++++++++++++++ 5 files changed, 1023 insertions(+), 43 deletions(-) create mode 100644 tests/test_s7commplus_tls.py diff --git a/snap7/s7commplus/async_client.py b/snap7/s7commplus/async_client.py index e6a46fe2..e58e9fc4 100644 --- a/snap7/s7commplus/async_client.py +++ b/snap7/s7commplus/async_client.py @@ -4,6 +4,10 @@ Provides the same API as S7CommPlusClient but using asyncio for non-blocking I/O. Uses asyncio.Lock for concurrent safety. +Supports all S7CommPlus protocol versions (V1/V2/V3/TLS). The protocol +version is auto-detected from the PLC's CreateObject response during +connection setup. + When a PLC does not support S7CommPlus data operations, the client transparently falls back to the legacy S7 protocol for data block read/write operations (using synchronous calls in an executor). @@ -18,6 +22,7 @@ import asyncio import logging +import ssl import struct from typing import Any, Optional @@ -25,6 +30,7 @@ DataType, ElementID, FunctionCode, + LegitimationId, ObjectId, Opcode, ProtocolVersion, @@ -35,6 +41,7 @@ from .codec import encode_header, decode_header, encode_typed_value, encode_object_qualifier from .vlq import encode_uint32_vlq, decode_uint32_vlq, decode_uint64_vlq from .client import _build_read_payload, _parse_read_response, _build_write_payload, _parse_write_response +from .connection import _element_size logger = logging.getLogger(__name__) @@ -47,7 +54,7 @@ class S7CommPlusAsyncClient: """Async S7CommPlus client for S7-1200/1500 PLCs. - Supports V1 and V2 protocols. V3/TLS planned for future. + Supports V1, V2, and V3 protocols (including TLS). Uses asyncio for all I/O operations and asyncio.Lock for concurrent safety when shared between multiple coroutines. @@ -76,6 +83,13 @@ def __init__(self) -> None: self._integrity_id_write: int = 0 self._with_integrity_id: bool = False + # TLS state + self._tls_active: bool = False + self._oms_secret: Optional[bytes] = None + + # Session setup + self._server_session_version: Optional[int] = None + @property def connected(self) -> bool: if self._use_legacy_data and self._legacy_client is not None: @@ -95,12 +109,22 @@ def using_legacy_fallback(self) -> bool: """Whether the client is using legacy S7 protocol for data operations.""" return self._use_legacy_data + @property + def tls_active(self) -> bool: + """Whether TLS encryption is active on this connection.""" + return self._tls_active + async def connect( self, host: str, port: int = 102, rack: int = 0, slot: int = 1, + use_tls: bool = False, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + tls_ca: Optional[str] = None, + password: Optional[str] = None, ) -> None: """Connect to an S7-1200/1500 PLC. @@ -112,6 +136,11 @@ async def connect( port: TCP port (default 102) rack: PLC rack number slot: PLC slot number + use_tls: Whether to activate TLS (required for V2/V3) + tls_cert: Path to client TLS certificate (PEM) + tls_key: Path to client private key (PEM) + tls_ca: Path to CA certificate for PLC verification (PEM) + password: PLC password for legitimation (V2+ with TLS) """ self._host = host self._port = port @@ -128,14 +157,43 @@ async def connect( # InitSSL handshake await self._init_ssl() + # TLS activation (between InitSSL and CreateObject) + if use_tls: + await self._activate_tls(tls_cert=tls_cert, tls_key=tls_key, tls_ca=tls_ca) + # S7CommPlus session setup await self._create_session() + # Echo ServerSessionVersion back to complete handshake + if self._server_session_version is not None: + await self._setup_session() + else: + logger.warning("PLC did not provide ServerSessionVersion - session setup incomplete") + + # Version-specific post-setup + if self._protocol_version >= ProtocolVersion.V2: + if not self._tls_active: + raise RuntimeError( + f"PLC reports V{self._protocol_version} protocol but TLS is not active. " + "V2/V3 requires TLS. Use use_tls=True." + ) + self._with_integrity_id = True + self._integrity_id_read = 0 + self._integrity_id_write = 0 + logger.info(f"V{self._protocol_version} IntegrityId tracking enabled") + self._connected = True logger.info( - f"Async S7CommPlus connected to {host}:{port}, version=V{self._protocol_version}, session={self._session_id}" + f"Async S7CommPlus connected to {host}:{port}, " + f"version=V{self._protocol_version}, session={self._session_id}, " + f"tls={self._tls_active}" ) + # Handle legitimation for password-protected PLCs + if password is not None and self._tls_active: + logger.info("Performing PLC legitimation (password authentication)") + await self.authenticate(password) + # Probe S7CommPlus data operations if not await self._probe_s7commplus_data(): logger.info("S7CommPlus data operations not supported, falling back to legacy S7 protocol") @@ -145,6 +203,116 @@ async def connect( await self.disconnect() raise + async def authenticate(self, password: str, username: str = "") -> None: + """Perform PLC password authentication (legitimation). + + Must be called after connect() and before data operations on + password-protected PLCs. Requires TLS to be active (V2+). + + Args: + password: PLC password + username: Username for new-style auth (optional) + """ + if not self._connected: + raise RuntimeError("Not connected") + + if not self._tls_active: + raise RuntimeError("Legitimation requires TLS. Connect with use_tls=True.") + + # Step 1: Get challenge from PLC + challenge = await self._get_legitimation_challenge() + logger.info(f"Received legitimation challenge ({len(challenge)} bytes)") + + # Step 2: Build response (auto-detect legacy vs new) + from .legitimation import build_legacy_response, build_new_response + + if username and self._oms_secret is not None: + response_data = build_new_response(password, challenge, self._oms_secret, username) + await self._send_legitimation_new(response_data) + elif self._oms_secret is not None: + try: + response_data = build_new_response(password, challenge, self._oms_secret, "") + await self._send_legitimation_new(response_data) + except NotImplementedError: + response_data = build_legacy_response(password, challenge) + await self._send_legitimation_legacy(response_data) + else: + logger.info("OMS secret not available, using legacy legitimation") + response_data = build_legacy_response(password, challenge) + await self._send_legitimation_legacy(response_data) + + logger.info("PLC legitimation completed successfully") + + async def _get_legitimation_challenge(self) -> bytes: + """Request legitimation challenge from PLC.""" + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(LegitimationId.SERVER_SESSION_REQUEST) + payload += struct.pack(">I", 0) + + resp_payload = await self._send_request(FunctionCode.GET_VAR_SUBSTREAMED, bytes(payload)) + + offset = 0 + return_value, consumed = decode_uint64_vlq(resp_payload, offset) + offset += consumed + + if return_value != 0: + raise RuntimeError(f"GetVarSubStreamed for challenge failed: return_value={return_value}") + + if offset + 2 > len(resp_payload): + raise RuntimeError("Challenge response too short") + + _flags = resp_payload[offset] + datatype = resp_payload[offset + 1] + offset += 2 + + if datatype == DataType.BLOB: + length, consumed = decode_uint32_vlq(resp_payload, offset) + offset += consumed + return bytes(resp_payload[offset : offset + length]) + else: + count, consumed = decode_uint32_vlq(resp_payload, offset) + offset += consumed + return bytes(resp_payload[offset : offset + count]) + + async def _send_legitimation_new(self, encrypted_response: bytes) -> None: + """Send new-style legitimation response (AES-256-CBC encrypted).""" + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(LegitimationId.LEGITIMATE) + payload += bytes([0x00, DataType.BLOB]) + payload += encode_uint32_vlq(len(encrypted_response)) + payload += encrypted_response + payload += struct.pack(">I", 0) + + resp_payload = await self._send_request(FunctionCode.SET_VARIABLE, bytes(payload)) + + if len(resp_payload) >= 1: + return_value, _ = decode_uint64_vlq(resp_payload, 0) + if return_value < 0: + raise RuntimeError(f"Legitimation rejected by PLC: return_value={return_value}") + + async def _send_legitimation_legacy(self, response: bytes) -> None: + """Send legacy legitimation response (SHA-1 XOR).""" + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(LegitimationId.SERVER_SESSION_RESPONSE) + payload += bytes([0x10, DataType.USINT]) + payload += encode_uint32_vlq(len(response)) + payload += response + payload += struct.pack(">I", 0) + + resp_payload = await self._send_request(FunctionCode.SET_VARIABLE, bytes(payload)) + + if len(resp_payload) >= 1: + return_value, _ = decode_uint64_vlq(resp_payload, 0) + if return_value < 0: + raise RuntimeError(f"Legacy legitimation rejected by PLC: return_value={return_value}") + async def _probe_s7commplus_data(self) -> bool: """Test if the PLC supports S7CommPlus data operations.""" try: @@ -198,6 +366,9 @@ async def disconnect(self) -> None: self._with_integrity_id = False self._integrity_id_read = 0 self._integrity_id_write = 0 + self._tls_active = False + self._oms_secret = None + self._server_session_version = None if self._writer: try: @@ -407,6 +578,65 @@ async def _init_ssl(self) -> None: logger.debug(f"InitSSL response received, version=V{version}") + async def _activate_tls( + self, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + tls_ca: Optional[str] = None, + ) -> None: + """Activate TLS 1.3 over the COTP connection. + + Called after InitSSL and before CreateObject. Wraps the underlying + asyncio streams with TLS. + """ + if self._writer is None: + raise RuntimeError("Cannot activate TLS: not connected") + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.minimum_version = ssl.TLSVersion.TLSv1_3 + # TLS 1.3 cipher suites are auto-negotiated on modern OpenSSL; + # set_ciphers() only controls TLS 1.2 and below. + try: + ctx.set_ciphers("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") + except ssl.SSLError: + pass + + if tls_cert and tls_key: + ctx.load_cert_chain(tls_cert, tls_key) + + if tls_ca: + ctx.load_verify_locations(tls_ca) + else: + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + # Upgrade the connection to TLS. + # StreamWriter.start_tls() is the clean API (Python 3.11+). + # For Python 3.10, fall back to loop.start_tls() with the existing protocol. + if hasattr(self._writer, "start_tls"): + await self._writer.start_tls(ctx, server_hostname=self._host) + else: + transport = self._writer.transport + protocol = transport.get_protocol() + loop = asyncio.get_event_loop() + new_transport = await loop.start_tls(transport, protocol, ctx, server_hostname=self._host) + # Update writer's internal transport reference + self._writer._transport = new_transport + + self._tls_active = True + + # Extract OMS exporter secret for legitimation key derivation + ssl_object = self._writer.transport.get_extra_info("ssl_object") + if ssl_object is not None: + try: + self._oms_secret = ssl_object.export_keying_material("EXPERIMENTAL_OMS", 32, None) + logger.debug("OMS exporter secret extracted from TLS session") + except (AttributeError, ssl.SSLError) as e: + logger.warning(f"Could not extract OMS exporter secret: {e}") + self._oms_secret = None + + logger.info("TLS activated on async COTP connection") + async def _create_session(self) -> None: """Send CreateObject to establish S7CommPlus session.""" seq_num = self._next_sequence_number() @@ -455,7 +685,7 @@ async def _create_session(self) -> None: request += bytes([ElementID.TERMINATING_OBJECT]) request += struct.pack(">I", 0) - # Frame header + trailer + # Frame header + trailer (always V1 for CreateObject) frame = encode_header(ProtocolVersion.V1, len(request)) + request frame += struct.pack(">BBH", 0x72, ProtocolVersion.V1, 0x0000) await self._send_cotp_dt(frame) @@ -470,6 +700,178 @@ async def _create_session(self) -> None: self._session_id = struct.unpack_from(">I", response, 9)[0] self._protocol_version = version + logger.debug(f"Session created: id=0x{self._session_id:08X}, version=V{version}") + + # Parse response payload to extract ServerSessionVersion + self._parse_create_object_response(response[14:]) + + def _parse_create_object_response(self, payload: bytes) -> None: + """Parse CreateObject response to extract ServerSessionVersion.""" + offset = 0 + while offset < len(payload): + tag = payload[offset] + + if tag == ElementID.ATTRIBUTE: + offset += 1 + if offset >= len(payload): + break + attr_id, consumed = decode_uint32_vlq(payload, offset) + offset += consumed + + if attr_id == ObjectId.SERVER_SESSION_VERSION: + if offset + 2 > len(payload): + break + _flags = payload[offset] + datatype = payload[offset + 1] + offset += 2 + if datatype in (DataType.UDINT, DataType.DWORD): + value, consumed = decode_uint32_vlq(payload, offset) + offset += consumed + self._server_session_version = value + logger.info(f"ServerSessionVersion = {value}") + return + else: + logger.debug(f"ServerSessionVersion has unexpected type {datatype:#04x}") + else: + if offset + 2 > len(payload): + break + _flags = payload[offset] + datatype = payload[offset + 1] + offset += 2 + offset = self._skip_typed_value(payload, offset, datatype, _flags) + + elif tag == ElementID.START_OF_OBJECT: + offset += 1 + if offset + 4 > len(payload): + break + offset += 4 # RelationId + _, consumed = decode_uint32_vlq(payload, offset) + offset += consumed # ClassId + _, consumed = decode_uint32_vlq(payload, offset) + offset += consumed # ClassFlags + _, consumed = decode_uint32_vlq(payload, offset) + offset += consumed # AttributeId + + elif tag == ElementID.TERMINATING_OBJECT: + offset += 1 + elif tag == 0x00: + offset += 1 + else: + offset += 1 + + logger.debug("ServerSessionVersion not found in CreateObject response") + + def _skip_typed_value(self, data: bytes, offset: int, datatype: int, flags: int) -> int: + """Skip over a typed value in the PObject tree.""" + is_array = bool(flags & 0x10) + + if is_array: + if offset >= len(data): + return offset + count, consumed = decode_uint32_vlq(data, offset) + offset += consumed + elem_size = _element_size(datatype) + if elem_size > 0: + offset += count * elem_size + else: + for _ in range(count): + if offset >= len(data): + break + _, consumed = decode_uint32_vlq(data, offset) + offset += consumed + return offset + + if datatype == DataType.NULL: + return offset + elif datatype in (DataType.BOOL, DataType.USINT, DataType.BYTE, DataType.SINT): + return offset + 1 + elif datatype in (DataType.UINT, DataType.WORD, DataType.INT): + return offset + 2 + elif datatype in (DataType.UDINT, DataType.DWORD, DataType.AID, DataType.DINT): + _, consumed = decode_uint32_vlq(data, offset) + return offset + consumed + elif datatype in (DataType.ULINT, DataType.LWORD, DataType.LINT): + _, consumed = decode_uint64_vlq(data, offset) + return offset + consumed + elif datatype == DataType.REAL: + return offset + 4 + elif datatype == DataType.LREAL: + return offset + 8 + elif datatype == DataType.TIMESTAMP: + return offset + 8 + elif datatype == DataType.TIMESPAN: + _, consumed = decode_uint64_vlq(data, offset) + return offset + consumed + elif datatype == DataType.RID: + return offset + 4 + elif datatype in (DataType.BLOB, DataType.WSTRING): + length, consumed = decode_uint32_vlq(data, offset) + return offset + consumed + length + elif datatype == DataType.STRUCT: + count, consumed = decode_uint32_vlq(data, offset) + offset += consumed + for _ in range(count): + if offset + 2 > len(data): + break + sub_flags = data[offset] + sub_type = data[offset + 1] + offset += 2 + offset = self._skip_typed_value(data, offset, sub_type, sub_flags) + return offset + else: + return offset + + async def _setup_session(self) -> None: + """Send SetMultiVariables to echo ServerSessionVersion back to the PLC.""" + if self._server_session_version is None: + return + + seq_num = self._next_sequence_number() + + request = struct.pack( + ">BHHHHIB", + Opcode.REQUEST, + 0x0000, + FunctionCode.SET_MULTI_VARIABLES, + 0x0000, + seq_num, + self._session_id, + 0x36, + ) + + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(ObjectId.SERVER_SESSION_VERSION) + payload += encode_uint32_vlq(1) + payload += bytes([0x00, DataType.UDINT]) + payload += encode_uint32_vlq(self._server_session_version) + payload += bytes([0x00]) + payload += encode_object_qualifier() + payload += struct.pack(">I", 0) + + request += bytes(payload) + + frame = encode_header(self._protocol_version, len(request)) + request + frame += struct.pack(">BBH", 0x72, self._protocol_version, 0x0000) + await self._send_cotp_dt(frame) + + response_data = await self._recv_cotp_dt() + version, data_length, consumed = decode_header(response_data) + response = response_data[consumed : consumed + data_length] + + if len(response) < 14: + raise RuntimeError("SetupSession response too short") + + resp_payload = response[14:] + if len(resp_payload) >= 1: + return_value, _ = decode_uint64_vlq(resp_payload, 0) + if return_value != 0: + logger.warning(f"SetupSession: PLC returned error {return_value}") + else: + logger.info("Session setup completed successfully") + async def _delete_session(self) -> None: """Send DeleteObject to close the session.""" seq_num = self._next_sequence_number() diff --git a/snap7/s7commplus/client.py b/snap7/s7commplus/client.py index 44112c9c..341c879f 100644 --- a/snap7/s7commplus/client.py +++ b/snap7/s7commplus/client.py @@ -14,7 +14,7 @@ the client transparently falls back to the legacy S7 protocol for data block read/write operations. -Status: V1 and V2 connections are functional. V3/TLS authentication planned. +Status: V1, V2, and V3 (including TLS) connections are functional. Reference: thomas-v2/S7CommPlusDriver (C#, LGPL-3.0) """ @@ -24,7 +24,7 @@ from typing import Any, Optional from .connection import S7CommPlusConnection -from .protocol import FunctionCode, Ids +from .protocol import FunctionCode, Ids, ProtocolVersion from .vlq import encode_uint32_vlq, decode_uint32_vlq, decode_uint64_vlq from .codec import ( encode_item_address, @@ -148,10 +148,12 @@ def connect( logger.info("Performing PLC legitimation (password authentication)") self._connection.authenticate(password) - # Probe S7CommPlus data operations with a minimal request - if not self._probe_s7commplus_data(): - logger.info("S7CommPlus data operations not supported, falling back to legacy S7 protocol") - self._setup_legacy_fallback() + # Probe S7CommPlus data operations with a minimal request. + # Skip probe for V2+ with TLS: TLS handshake confirms S7CommPlus support. + if self._connection.protocol_version < ProtocolVersion.V2: + if not self._probe_s7commplus_data(): + logger.info("S7CommPlus data operations not supported, falling back to legacy S7 protocol") + self._setup_legacy_fallback() def _probe_s7commplus_data(self) -> bool: """Test if the PLC supports S7CommPlus data operations. diff --git a/snap7/s7commplus/connection.py b/snap7/s7commplus/connection.py index a60b44a0..ae44506d 100644 --- a/snap7/s7commplus/connection.py +++ b/snap7/s7commplus/connection.py @@ -85,8 +85,7 @@ class S7CommPlusConnection: - Version-appropriate authentication (V1/V2/V3/TLS) - Frame send/receive (TLS-encrypted when using V17+ firmware) - Currently implements V1 authentication. V2/V3/TLS authentication - layers are planned for future development. + Supports V1, V2, and V3 (including TLS) authentication. """ def __init__( @@ -202,21 +201,19 @@ def connect( logger.warning("PLC did not provide ServerSessionVersion - session setup incomplete") # Step 6: Version-specific post-setup - if self._protocol_version >= ProtocolVersion.V3: - if not use_tls: - logger.warning( - "PLC reports V3 protocol but TLS is not enabled. Connection may not work without use_tls=True." - ) - elif self._protocol_version == ProtocolVersion.V2: + if self._protocol_version >= ProtocolVersion.V2: if not self._tls_active: from ..error import S7ConnectionError - raise S7ConnectionError("PLC reports V2 protocol but TLS is not active. V2 requires TLS. Use use_tls=True.") + raise S7ConnectionError( + f"PLC reports V{self._protocol_version} protocol but TLS is not active. " + "V2/V3 requires TLS. Use use_tls=True." + ) # Enable IntegrityId tracking for V2+ self._with_integrity_id = True self._integrity_id_read = 0 self._integrity_id_write = 0 - logger.info("V2 IntegrityId tracking enabled") + logger.info(f"V{self._protocol_version} IntegrityId tracking enabled") # V1: No further authentication needed after CreateObject self._connected = True @@ -251,7 +248,7 @@ def authenticate(self, password: str, username: str = "") -> None: raise S7ConnectionError("Not connected") - if not self._tls_active or self._oms_secret is None: + if not self._tls_active: from ..error import S7ConnectionError raise S7ConnectionError("Legitimation requires TLS. Connect with use_tls=True.") @@ -263,11 +260,11 @@ def authenticate(self, password: str, username: str = "") -> None: # Step 2: Build response (auto-detect legacy vs new) from .legitimation import build_legacy_response, build_new_response - if username: + if username and self._oms_secret is not None: # New-style auth with username always uses AES-256-CBC response_data = build_new_response(password, challenge, self._oms_secret, username) self._send_legitimation_new(response_data) - else: + elif self._oms_secret is not None: # Try new-style first, fall back to legacy SHA-1 XOR try: response_data = build_new_response(password, challenge, self._oms_secret, "") @@ -276,6 +273,12 @@ def authenticate(self, password: str, username: str = "") -> None: # cryptography package not available, use legacy response_data = build_legacy_response(password, challenge) self._send_legitimation_legacy(response_data) + else: + # No OMS secret available (export_keying_material not supported), + # fall back to legacy SHA-1 XOR authentication + logger.info("OMS secret not available, using legacy legitimation") + response_data = build_legacy_response(password, challenge) + self._send_legitimation_legacy(response_data) logger.info("PLC legitimation completed successfully") @@ -1013,7 +1016,13 @@ def _setup_ssl_context( """ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.minimum_version = ssl.TLSVersion.TLSv1_3 - ctx.set_ciphers("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") + # TLS 1.3 cipher suites are auto-negotiated on modern OpenSSL; + # set_ciphers() only controls TLS 1.2 and below. We try to set + # preferred ciphers but ignore failures (e.g. OpenSSL 3.x). + try: + ctx.set_ciphers("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") + except ssl.SSLError: + pass if cert_path and key_path: ctx.load_cert_chain(cert_path, key_path) diff --git a/snap7/s7commplus/server.py b/snap7/s7commplus/server.py index 2af0d769..6719bf71 100644 --- a/snap7/s7commplus/server.py +++ b/snap7/s7commplus/server.py @@ -10,7 +10,7 @@ - Internal PLC memory model with thread-safe access - V2 protocol emulation with TLS and IntegrityId tracking -Supports both V1 (no TLS) and V2 (TLS + IntegrityId) emulation. +Supports V1 (no TLS), V2 (TLS + IntegrityId), and V3 (TLS + IntegrityId) emulation. Usage:: @@ -186,17 +186,21 @@ class S7CommPlusServer: Emulates an S7-1200/1500 PLC with: - Internal data block storage with named variables - - S7CommPlus protocol handling (V1 and V2) - - V2 TLS support with IntegrityId tracking + - S7CommPlus protocol handling (V1, V2, V3) + - V2/V3 TLS support with IntegrityId tracking + - Legitimation (password authentication) emulation - Multi-client support (threaded) - CPU state management """ - def __init__(self, protocol_version: int = ProtocolVersion.V1) -> None: + def __init__(self, protocol_version: int = ProtocolVersion.V1, password: Optional[str] = None) -> None: self._data_blocks: dict[int, DataBlock] = {} self._cpu_state = CPUState.RUN self._protocol_version = protocol_version + self._password = password self._next_session_id = 1 + # Per-client authentication state (session_id -> authenticated) + self._authenticated_sessions: dict[int, bool] = {} self._server_socket: Optional[socket.socket] = None self._server_thread: Optional[threading.Thread] = None @@ -205,7 +209,7 @@ def __init__(self, protocol_version: int = ProtocolVersion.V1) -> None: self._lock = threading.Lock() self._event_callback: Optional[Callable[..., None]] = None - # TLS configuration (V2) + # TLS configuration (V2/V3) self._ssl_context: Optional[ssl.SSLContext] = None self._use_tls: bool = False @@ -364,8 +368,12 @@ def _handle_client(self, client_sock: socket.socket, address: tuple[str, int]) - session_id = 0 tls_activated = False # Per-client IntegrityId tracking (V2+) + # IntegrityId tracking starts AFTER the session setup (SetMultiVariables + # echoing ServerSessionVersion). The first SetMultiVariables after + # CreateObject is the session setup and doesn't include IntegrityId. integrity_id_read = 0 integrity_id_write = 0 + integrity_id_active = False while self._running: try: @@ -375,7 +383,9 @@ def _handle_client(self, client_sock: socket.socket, address: tuple[str, int]) - break # Process the S7CommPlus request - response = self._process_request(data, session_id, integrity_id_read, integrity_id_write) + response = self._process_request( + data, session_id, integrity_id_read, integrity_id_write, integrity_id_active + ) if response is not None: # Check if session ID was assigned @@ -412,7 +422,12 @@ def _handle_client(self, client_sock: socket.socket, address: tuple[str, int]) - payload = data[hdr_consumed:] if len(payload) >= 14: func_code = struct.unpack_from(">H", payload, 3)[0] - if func_code in READ_FUNCTION_CODES: + if not integrity_id_active: + # First SetMultiVariables after CreateObject is session setup + # (no IntegrityId). Activate tracking after it. + if func_code == FunctionCode.SET_MULTI_VARIABLES: + integrity_id_active = True + elif func_code in READ_FUNCTION_CODES: integrity_id_read = (integrity_id_read + 1) & 0xFFFFFFFF elif func_code not in ( FunctionCode.INIT_SSL, @@ -522,6 +537,7 @@ def _process_request( session_id: int, integrity_id_read: int = 0, integrity_id_write: int = 0, + integrity_id_active: bool = False, ) -> Optional[bytes]: """Process an S7CommPlus request and return a response.""" if len(data) < 4: @@ -547,11 +563,14 @@ def _process_request( seq_num = struct.unpack_from(">H", payload, 7)[0] req_session_id = struct.unpack_from(">I", payload, 9)[0] - # For V2+, skip IntegrityId after the 14-byte header + # For V2+, skip IntegrityId after the 14-byte header. + # IntegrityId is only present after session setup is complete + # (integrity_id_active=True). The first SetMultiVariables after + # CreateObject is the session setup and doesn't include IntegrityId. request_offset = 14 if ( - self._protocol_version >= ProtocolVersion.V2 - and session_id != 0 + integrity_id_active + and self._protocol_version >= ProtocolVersion.V2 and function_code not in (FunctionCode.INIT_SSL, FunctionCode.CREATE_OBJECT) ): if request_offset < len(payload): @@ -566,14 +585,41 @@ def _process_request( return self._handle_create_object(seq_num, request_data) elif function_code == FunctionCode.DELETE_OBJECT: return self._handle_delete_object(seq_num, req_session_id) + elif function_code == FunctionCode.GET_VAR_SUBSTREAMED: + response = self._handle_get_var_substreamed(seq_num, req_session_id, request_data) + elif function_code == FunctionCode.SET_VARIABLE: + response = self._handle_set_variable(seq_num, req_session_id, request_data) elif function_code == FunctionCode.EXPLORE: - return self._handle_explore(seq_num, req_session_id, request_data) + if not self._check_authenticated(req_session_id): + response = self._build_error_response(seq_num, req_session_id, function_code) + else: + response = self._handle_explore(seq_num, req_session_id, request_data) elif function_code == FunctionCode.GET_MULTI_VARIABLES: - return self._handle_get_multi_variables(seq_num, req_session_id, request_data) + if not self._check_authenticated(req_session_id): + response = self._build_error_response(seq_num, req_session_id, function_code) + else: + response = self._handle_get_multi_variables(seq_num, req_session_id, request_data) elif function_code == FunctionCode.SET_MULTI_VARIABLES: - return self._handle_set_multi_variables(seq_num, req_session_id, request_data) + # Auth check is inside the handler: session setup must bypass auth + response = self._handle_set_multi_variables( + seq_num, req_session_id, request_data, self._check_authenticated(req_session_id) + ) else: - return self._build_error_response(seq_num, req_session_id, function_code) + response = self._build_error_response(seq_num, req_session_id, function_code) + + # For V2+, insert IntegrityId right after the 14-byte response header. + # The client expects IntegrityId at offset 14 in the response, mirroring + # the request format. Only insert when IntegrityId tracking is active. + if ( + integrity_id_active + and self._protocol_version >= ProtocolVersion.V2 + and function_code not in (FunctionCode.INIT_SSL, FunctionCode.CREATE_OBJECT) + and response is not None + and len(response) >= 14 + ): + response = response[:14] + encode_uint32_vlq(0) + response[14:] + + return response def _build_response_header( self, @@ -798,16 +844,30 @@ def _handle_get_multi_variables(self, seq_num: int, session_id: int, request_dat # Terminate error list response += encode_uint32_vlq(0) - # IntegrityId - response += encode_uint32_vlq(0) - return bytes(response) - def _handle_set_multi_variables(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: + def _handle_set_multi_variables( + self, seq_num: int, session_id: int, request_data: bytes, is_authenticated: bool = True + ) -> bytes: """Handle SetMultiVariables -- write variables to data blocks. + Also handles session setup (SetMultiVariables echoing ServerSessionVersion). + Session setup is detected by InObjectId matching the session_id. + Session setup always bypasses auth; data writes require authentication. + Reference: thomas-v2/S7CommPlusDriver/Core/SetMultiVariablesRequest.cs """ + # Check if this is a session setup write (InObjectId = session_id) + if len(request_data) >= 4: + in_object_id = struct.unpack_from(">I", request_data, 0)[0] + if in_object_id == session_id and session_id != 0: + # Session setup: just acknowledge success (no auth required) + return self._build_set_multi_response(seq_num, session_id, []) + + # For data writes, require authentication + if not is_authenticated: + return self._build_error_response(seq_num, session_id, FunctionCode.SET_MULTI_VARIABLES) + response = bytearray() response += struct.pack( ">BHHHHIB", @@ -843,8 +903,98 @@ def _handle_set_multi_variables(self, seq_num: int, session_id: int, request_dat # Terminate error list response += encode_uint32_vlq(0) - # IntegrityId - response += encode_uint32_vlq(0) + return bytes(response) + + def _build_set_multi_response( + self, seq_num: int, session_id: int, errors: list[tuple[int, int]] + ) -> bytes: + """Build a SetMultiVariables response.""" + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.SET_MULTI_VARIABLES, + 0x0000, + seq_num, + session_id, + 0x00, + ) + response += encode_uint64_vlq(0) # ReturnValue: success + for err_item, err_code in errors: + response += encode_uint32_vlq(err_item) + response += encode_uint64_vlq(err_code) + response += encode_uint32_vlq(0) # Terminate error list + return bytes(response) + + def _check_authenticated(self, session_id: int) -> bool: + """Check if a session is authenticated (or no password is required).""" + if self._password is None: + return True + return self._authenticated_sessions.get(session_id, False) + + def _handle_get_var_substreamed(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: + """Handle GetVarSubStreamed -- used for legitimation challenge request. + + Returns a 20-byte random challenge when the client requests + ServerSessionRequest (address 303). + """ + import os + + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.GET_VAR_SUBSTREAMED, + 0x0000, + seq_num, + session_id, + 0x00, + ) + + # ReturnValue: success + response += encode_uint64_vlq(0) + + # Value: 20-byte challenge as BLOB + challenge = os.urandom(20) + response += bytes([0x00, DataType.BLOB]) + response += encode_uint32_vlq(len(challenge)) + response += challenge + + # Trailing padding + response += struct.pack(">I", 0) + + return bytes(response) + + def _handle_set_variable(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: + """Handle SetVariable -- used for legitimation response. + + Accepts the client's authentication response and marks the session + as authenticated. In this emulator, any response is accepted (we + don't verify the actual crypto). + """ + # Mark session as authenticated + self._authenticated_sessions[session_id] = True + logger.debug(f"Session {session_id} authenticated via SetVariable") + + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.SET_VARIABLE, + 0x0000, + seq_num, + session_id, + 0x00, + ) + + # ReturnValue: success + response += encode_uint64_vlq(0) + + # Trailing padding + response += struct.pack(">I", 0) return bytes(response) diff --git a/tests/test_s7commplus_tls.py b/tests/test_s7commplus_tls.py new file mode 100644 index 00000000..873d2a9f --- /dev/null +++ b/tests/test_s7commplus_tls.py @@ -0,0 +1,417 @@ +"""Integration tests for S7CommPlus V2/V3 with TLS. + +Tests the complete TLS connection flow including: +- V2 server + client with TLS and IntegrityId tracking +- V3 server + client with TLS +- Legitimation (password authentication) flow +- Async client with TLS +- Error handling for missing TLS + +Requires the `cryptography` package for self-signed certificate generation. +""" + +import struct +import tempfile +import time +from collections.abc import Generator +from pathlib import Path + +import pytest + +from snap7.s7commplus.server import S7CommPlusServer +from snap7.s7commplus.client import S7CommPlusClient +from snap7.s7commplus.protocol import ProtocolVersion + +try: + import ipaddress + + from cryptography import x509 + from cryptography.x509.oid import NameOID + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + import datetime + + _has_cryptography = True +except ImportError: + _has_cryptography = False + +pytestmark = pytest.mark.skipif(not _has_cryptography, reason="requires cryptography package") + +# Use high ports to avoid conflicts +V2_TEST_PORT = 11130 +V3_TEST_PORT = 11131 +V2_AUTH_PORT = 11132 +V3_AUTH_PORT = 11133 + + +@pytest.fixture(scope="module") +def tls_certs() -> Generator[dict[str, str], None, None]: + """Generate self-signed TLS certificates for testing.""" + with tempfile.TemporaryDirectory() as tmpdir: + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), + ]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)) + .add_extension( + x509.SubjectAlternativeName([ + x509.DNSName("localhost"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + ]), + critical=False, + ) + .sign(key, hashes.SHA256()) + ) + + cert_path = str(Path(tmpdir) / "server.crt") + key_path = str(Path(tmpdir) / "server.key") + + with open(cert_path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + + with open(key_path, "wb") as f: + f.write( + key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.TraditionalOpenSSL, + serialization.NoEncryption(), + ) + ) + + yield {"cert": cert_path, "key": key_path} + + +def _make_server( + protocol_version: int, + password: str | None = None, +) -> S7CommPlusServer: + """Create and configure an S7CommPlus server with test data blocks.""" + srv = S7CommPlusServer(protocol_version=protocol_version, password=password) + + srv.register_db( + 1, + { + "temperature": ("Real", 0), + "pressure": ("Real", 4), + "running": ("Bool", 8), + "count": ("DInt", 10), + }, + ) + srv.register_raw_db(2, bytearray(256)) + + # Pre-populate DB1 + db1 = srv.get_db(1) + assert db1 is not None + struct.pack_into(">f", db1.data, 0, 23.5) + struct.pack_into(">f", db1.data, 4, 1.013) + db1.data[8] = 1 + struct.pack_into(">i", db1.data, 10, 42) + + return srv + + +@pytest.fixture() +def v2_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: + """V2 server with TLS.""" + srv = _make_server(ProtocolVersion.V2) + srv.start(port=V2_TEST_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) + time.sleep(0.1) + yield srv + srv.stop() + + +@pytest.fixture() +def v3_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: + """V3 server with TLS.""" + srv = _make_server(ProtocolVersion.V3) + srv.start(port=V3_TEST_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) + time.sleep(0.1) + yield srv + srv.stop() + + +@pytest.fixture() +def v2_auth_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: + """V2 server with TLS and password authentication.""" + srv = _make_server(ProtocolVersion.V2, password="secret123") + srv.start(port=V2_AUTH_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) + time.sleep(0.1) + yield srv + srv.stop() + + +@pytest.fixture() +def v3_auth_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: + """V3 server with TLS and password authentication.""" + srv = _make_server(ProtocolVersion.V3, password="secret123") + srv.start(port=V3_AUTH_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) + time.sleep(0.1) + yield srv + srv.stop() + + +class TestV2TLS: + """Test V2 protocol with TLS.""" + + def test_connect_disconnect(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + assert client.connected + assert client.session_id != 0 + assert client.protocol_version == ProtocolVersion.V2 + client.disconnect() + assert not client.connected + + def test_read_real(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + try: + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + finally: + client.disconnect() + + def test_write_and_read_back(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + try: + client.db_write(1, 0, struct.pack(">f", 99.9)) + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 99.9) < 0.1 + finally: + client.disconnect() + + def test_multi_read(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + try: + results = client.db_read_multi([ + (1, 0, 4), + (1, 4, 4), + (2, 0, 4), + ]) + assert len(results) == 3 + temp = struct.unpack(">f", results[0])[0] + assert abs(temp - 23.5) < 0.001 + finally: + client.disconnect() + + def test_explore(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + try: + response = client.explore() + assert len(response) > 0 + finally: + client.disconnect() + + +class TestV3TLS: + """Test V3 protocol with TLS.""" + + def test_connect_disconnect(self, v3_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + assert client.connected + assert client.session_id != 0 + assert client.protocol_version == ProtocolVersion.V3 + client.disconnect() + assert not client.connected + + def test_read_real(self, v3_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + try: + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + finally: + client.disconnect() + + def test_write_and_read_back(self, v3_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + try: + client.db_write(1, 0, struct.pack(">f", 88.8)) + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 88.8) < 0.1 + finally: + client.disconnect() + + def test_multi_read(self, v3_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + try: + results = client.db_read_multi([ + (1, 0, 4), + (1, 10, 4), + ]) + assert len(results) == 2 + finally: + client.disconnect() + + def test_data_persists_across_clients(self, v3_server: S7CommPlusServer) -> None: + c1 = S7CommPlusClient() + c1.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + c1.db_write(2, 0, b"\xca\xfe\xba\xbe") + c1.disconnect() + + c2 = S7CommPlusClient() + c2.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + data = c2.db_read(2, 0, 4) + c2.disconnect() + + assert data == b"\xca\xfe\xba\xbe" + + +class TestV2WithoutTLS: + """Test that V2/V3 connections fail without TLS.""" + + def test_v2_server_requires_tls_on_client(self, tls_certs: dict[str, str]) -> None: + """V2 server reports V2 version; client should raise if no TLS.""" + # Start a V2 server WITHOUT TLS (so client can connect but gets V2 version) + srv = _make_server(ProtocolVersion.V2) + srv.start(port=V2_TEST_PORT + 10) + time.sleep(0.1) + try: + client = S7CommPlusClient() + with pytest.raises(Exception): + # The server reports V2 but client didn't use TLS + client.connect("127.0.0.1", port=V2_TEST_PORT + 10) + finally: + srv.stop() + + +class TestLegitimation: + """Test password authentication (legitimation).""" + + def test_v2_connect_with_password(self, v2_auth_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True, password="secret123") + assert client.connected + try: + # Should be able to read after authentication + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + finally: + client.disconnect() + + def test_v3_connect_with_password(self, v3_auth_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_AUTH_PORT, use_tls=True, password="secret123") + assert client.connected + try: + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + finally: + client.disconnect() + + def test_v2_without_password_on_protected_server(self, v2_auth_server: S7CommPlusServer) -> None: + """Connecting without password to a password-protected server should fail data ops.""" + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True) + try: + # Server should reject data operations without authentication + with pytest.raises(Exception): + client.db_read(1, 0, 4) + finally: + client.disconnect() + + def test_v2_write_with_password(self, v2_auth_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True, password="secret123") + try: + client.db_write(1, 0, struct.pack(">f", 55.5)) + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 55.5) < 0.1 + finally: + client.disconnect() + + +@pytest.mark.asyncio +class TestAsyncV2TLS: + """Test async client with V2 TLS.""" + + async def test_connect_disconnect(self, v2_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + client = S7CommPlusAsyncClient() + await client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + assert client.connected + assert client.session_id != 0 + assert client.protocol_version == ProtocolVersion.V2 + assert client.tls_active + await client.disconnect() + assert not client.connected + + async def test_read_real(self, v2_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + data = await client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + + async def test_write_and_read_back(self, v2_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + await client.db_write(1, 0, struct.pack(">f", 77.7)) + data = await client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 77.7) < 0.1 + + +@pytest.mark.asyncio +class TestAsyncV3TLS: + """Test async client with V3 TLS.""" + + async def test_connect_disconnect(self, v3_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + client = S7CommPlusAsyncClient() + await client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + assert client.connected + assert client.protocol_version == ProtocolVersion.V3 + assert client.tls_active + await client.disconnect() + + async def test_read_write(self, v3_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + await client.db_write(2, 0, b"\xde\xad") + data = await client.db_read(2, 0, 2) + assert data == b"\xde\xad" + + +@pytest.mark.asyncio +class TestAsyncLegitimation: + """Test async client with password authentication.""" + + async def test_v2_connect_with_password(self, v2_auth_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True, password="secret123") + assert client.connected + data = await client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 From bf2f828a629e61c74e7125feb024f3b781971e96 Mon Sep 17 00:00:00 2001 From: Gijs Molenaar Date: Wed, 25 Mar 2026 14:22:51 +0200 Subject: [PATCH 3/4] Revert "Implement S7CommPlus V2/V3 protocol with TLS and IntegrityId" This reverts commit 60be2853e2fa2560389ecad65690b49c517c14bd. --- snap7/s7commplus/async_client.py | 408 +----------------------------- snap7/s7commplus/client.py | 14 +- snap7/s7commplus/connection.py | 37 ++- snap7/s7commplus/server.py | 190 ++------------ tests/test_s7commplus_tls.py | 417 ------------------------------- 5 files changed, 43 insertions(+), 1023 deletions(-) delete mode 100644 tests/test_s7commplus_tls.py diff --git a/snap7/s7commplus/async_client.py b/snap7/s7commplus/async_client.py index e58e9fc4..e6a46fe2 100644 --- a/snap7/s7commplus/async_client.py +++ b/snap7/s7commplus/async_client.py @@ -4,10 +4,6 @@ Provides the same API as S7CommPlusClient but using asyncio for non-blocking I/O. Uses asyncio.Lock for concurrent safety. -Supports all S7CommPlus protocol versions (V1/V2/V3/TLS). The protocol -version is auto-detected from the PLC's CreateObject response during -connection setup. - When a PLC does not support S7CommPlus data operations, the client transparently falls back to the legacy S7 protocol for data block read/write operations (using synchronous calls in an executor). @@ -22,7 +18,6 @@ import asyncio import logging -import ssl import struct from typing import Any, Optional @@ -30,7 +25,6 @@ DataType, ElementID, FunctionCode, - LegitimationId, ObjectId, Opcode, ProtocolVersion, @@ -41,7 +35,6 @@ from .codec import encode_header, decode_header, encode_typed_value, encode_object_qualifier from .vlq import encode_uint32_vlq, decode_uint32_vlq, decode_uint64_vlq from .client import _build_read_payload, _parse_read_response, _build_write_payload, _parse_write_response -from .connection import _element_size logger = logging.getLogger(__name__) @@ -54,7 +47,7 @@ class S7CommPlusAsyncClient: """Async S7CommPlus client for S7-1200/1500 PLCs. - Supports V1, V2, and V3 protocols (including TLS). + Supports V1 and V2 protocols. V3/TLS planned for future. Uses asyncio for all I/O operations and asyncio.Lock for concurrent safety when shared between multiple coroutines. @@ -83,13 +76,6 @@ def __init__(self) -> None: self._integrity_id_write: int = 0 self._with_integrity_id: bool = False - # TLS state - self._tls_active: bool = False - self._oms_secret: Optional[bytes] = None - - # Session setup - self._server_session_version: Optional[int] = None - @property def connected(self) -> bool: if self._use_legacy_data and self._legacy_client is not None: @@ -109,22 +95,12 @@ def using_legacy_fallback(self) -> bool: """Whether the client is using legacy S7 protocol for data operations.""" return self._use_legacy_data - @property - def tls_active(self) -> bool: - """Whether TLS encryption is active on this connection.""" - return self._tls_active - async def connect( self, host: str, port: int = 102, rack: int = 0, slot: int = 1, - use_tls: bool = False, - tls_cert: Optional[str] = None, - tls_key: Optional[str] = None, - tls_ca: Optional[str] = None, - password: Optional[str] = None, ) -> None: """Connect to an S7-1200/1500 PLC. @@ -136,11 +112,6 @@ async def connect( port: TCP port (default 102) rack: PLC rack number slot: PLC slot number - use_tls: Whether to activate TLS (required for V2/V3) - tls_cert: Path to client TLS certificate (PEM) - tls_key: Path to client private key (PEM) - tls_ca: Path to CA certificate for PLC verification (PEM) - password: PLC password for legitimation (V2+ with TLS) """ self._host = host self._port = port @@ -157,43 +128,14 @@ async def connect( # InitSSL handshake await self._init_ssl() - # TLS activation (between InitSSL and CreateObject) - if use_tls: - await self._activate_tls(tls_cert=tls_cert, tls_key=tls_key, tls_ca=tls_ca) - # S7CommPlus session setup await self._create_session() - # Echo ServerSessionVersion back to complete handshake - if self._server_session_version is not None: - await self._setup_session() - else: - logger.warning("PLC did not provide ServerSessionVersion - session setup incomplete") - - # Version-specific post-setup - if self._protocol_version >= ProtocolVersion.V2: - if not self._tls_active: - raise RuntimeError( - f"PLC reports V{self._protocol_version} protocol but TLS is not active. " - "V2/V3 requires TLS. Use use_tls=True." - ) - self._with_integrity_id = True - self._integrity_id_read = 0 - self._integrity_id_write = 0 - logger.info(f"V{self._protocol_version} IntegrityId tracking enabled") - self._connected = True logger.info( - f"Async S7CommPlus connected to {host}:{port}, " - f"version=V{self._protocol_version}, session={self._session_id}, " - f"tls={self._tls_active}" + f"Async S7CommPlus connected to {host}:{port}, version=V{self._protocol_version}, session={self._session_id}" ) - # Handle legitimation for password-protected PLCs - if password is not None and self._tls_active: - logger.info("Performing PLC legitimation (password authentication)") - await self.authenticate(password) - # Probe S7CommPlus data operations if not await self._probe_s7commplus_data(): logger.info("S7CommPlus data operations not supported, falling back to legacy S7 protocol") @@ -203,116 +145,6 @@ async def connect( await self.disconnect() raise - async def authenticate(self, password: str, username: str = "") -> None: - """Perform PLC password authentication (legitimation). - - Must be called after connect() and before data operations on - password-protected PLCs. Requires TLS to be active (V2+). - - Args: - password: PLC password - username: Username for new-style auth (optional) - """ - if not self._connected: - raise RuntimeError("Not connected") - - if not self._tls_active: - raise RuntimeError("Legitimation requires TLS. Connect with use_tls=True.") - - # Step 1: Get challenge from PLC - challenge = await self._get_legitimation_challenge() - logger.info(f"Received legitimation challenge ({len(challenge)} bytes)") - - # Step 2: Build response (auto-detect legacy vs new) - from .legitimation import build_legacy_response, build_new_response - - if username and self._oms_secret is not None: - response_data = build_new_response(password, challenge, self._oms_secret, username) - await self._send_legitimation_new(response_data) - elif self._oms_secret is not None: - try: - response_data = build_new_response(password, challenge, self._oms_secret, "") - await self._send_legitimation_new(response_data) - except NotImplementedError: - response_data = build_legacy_response(password, challenge) - await self._send_legitimation_legacy(response_data) - else: - logger.info("OMS secret not available, using legacy legitimation") - response_data = build_legacy_response(password, challenge) - await self._send_legitimation_legacy(response_data) - - logger.info("PLC legitimation completed successfully") - - async def _get_legitimation_challenge(self) -> bytes: - """Request legitimation challenge from PLC.""" - payload = bytearray() - payload += struct.pack(">I", self._session_id) - payload += encode_uint32_vlq(1) - payload += encode_uint32_vlq(1) - payload += encode_uint32_vlq(LegitimationId.SERVER_SESSION_REQUEST) - payload += struct.pack(">I", 0) - - resp_payload = await self._send_request(FunctionCode.GET_VAR_SUBSTREAMED, bytes(payload)) - - offset = 0 - return_value, consumed = decode_uint64_vlq(resp_payload, offset) - offset += consumed - - if return_value != 0: - raise RuntimeError(f"GetVarSubStreamed for challenge failed: return_value={return_value}") - - if offset + 2 > len(resp_payload): - raise RuntimeError("Challenge response too short") - - _flags = resp_payload[offset] - datatype = resp_payload[offset + 1] - offset += 2 - - if datatype == DataType.BLOB: - length, consumed = decode_uint32_vlq(resp_payload, offset) - offset += consumed - return bytes(resp_payload[offset : offset + length]) - else: - count, consumed = decode_uint32_vlq(resp_payload, offset) - offset += consumed - return bytes(resp_payload[offset : offset + count]) - - async def _send_legitimation_new(self, encrypted_response: bytes) -> None: - """Send new-style legitimation response (AES-256-CBC encrypted).""" - payload = bytearray() - payload += struct.pack(">I", self._session_id) - payload += encode_uint32_vlq(1) - payload += encode_uint32_vlq(LegitimationId.LEGITIMATE) - payload += bytes([0x00, DataType.BLOB]) - payload += encode_uint32_vlq(len(encrypted_response)) - payload += encrypted_response - payload += struct.pack(">I", 0) - - resp_payload = await self._send_request(FunctionCode.SET_VARIABLE, bytes(payload)) - - if len(resp_payload) >= 1: - return_value, _ = decode_uint64_vlq(resp_payload, 0) - if return_value < 0: - raise RuntimeError(f"Legitimation rejected by PLC: return_value={return_value}") - - async def _send_legitimation_legacy(self, response: bytes) -> None: - """Send legacy legitimation response (SHA-1 XOR).""" - payload = bytearray() - payload += struct.pack(">I", self._session_id) - payload += encode_uint32_vlq(1) - payload += encode_uint32_vlq(LegitimationId.SERVER_SESSION_RESPONSE) - payload += bytes([0x10, DataType.USINT]) - payload += encode_uint32_vlq(len(response)) - payload += response - payload += struct.pack(">I", 0) - - resp_payload = await self._send_request(FunctionCode.SET_VARIABLE, bytes(payload)) - - if len(resp_payload) >= 1: - return_value, _ = decode_uint64_vlq(resp_payload, 0) - if return_value < 0: - raise RuntimeError(f"Legacy legitimation rejected by PLC: return_value={return_value}") - async def _probe_s7commplus_data(self) -> bool: """Test if the PLC supports S7CommPlus data operations.""" try: @@ -366,9 +198,6 @@ async def disconnect(self) -> None: self._with_integrity_id = False self._integrity_id_read = 0 self._integrity_id_write = 0 - self._tls_active = False - self._oms_secret = None - self._server_session_version = None if self._writer: try: @@ -578,65 +407,6 @@ async def _init_ssl(self) -> None: logger.debug(f"InitSSL response received, version=V{version}") - async def _activate_tls( - self, - tls_cert: Optional[str] = None, - tls_key: Optional[str] = None, - tls_ca: Optional[str] = None, - ) -> None: - """Activate TLS 1.3 over the COTP connection. - - Called after InitSSL and before CreateObject. Wraps the underlying - asyncio streams with TLS. - """ - if self._writer is None: - raise RuntimeError("Cannot activate TLS: not connected") - - ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ctx.minimum_version = ssl.TLSVersion.TLSv1_3 - # TLS 1.3 cipher suites are auto-negotiated on modern OpenSSL; - # set_ciphers() only controls TLS 1.2 and below. - try: - ctx.set_ciphers("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") - except ssl.SSLError: - pass - - if tls_cert and tls_key: - ctx.load_cert_chain(tls_cert, tls_key) - - if tls_ca: - ctx.load_verify_locations(tls_ca) - else: - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE - - # Upgrade the connection to TLS. - # StreamWriter.start_tls() is the clean API (Python 3.11+). - # For Python 3.10, fall back to loop.start_tls() with the existing protocol. - if hasattr(self._writer, "start_tls"): - await self._writer.start_tls(ctx, server_hostname=self._host) - else: - transport = self._writer.transport - protocol = transport.get_protocol() - loop = asyncio.get_event_loop() - new_transport = await loop.start_tls(transport, protocol, ctx, server_hostname=self._host) - # Update writer's internal transport reference - self._writer._transport = new_transport - - self._tls_active = True - - # Extract OMS exporter secret for legitimation key derivation - ssl_object = self._writer.transport.get_extra_info("ssl_object") - if ssl_object is not None: - try: - self._oms_secret = ssl_object.export_keying_material("EXPERIMENTAL_OMS", 32, None) - logger.debug("OMS exporter secret extracted from TLS session") - except (AttributeError, ssl.SSLError) as e: - logger.warning(f"Could not extract OMS exporter secret: {e}") - self._oms_secret = None - - logger.info("TLS activated on async COTP connection") - async def _create_session(self) -> None: """Send CreateObject to establish S7CommPlus session.""" seq_num = self._next_sequence_number() @@ -685,7 +455,7 @@ async def _create_session(self) -> None: request += bytes([ElementID.TERMINATING_OBJECT]) request += struct.pack(">I", 0) - # Frame header + trailer (always V1 for CreateObject) + # Frame header + trailer frame = encode_header(ProtocolVersion.V1, len(request)) + request frame += struct.pack(">BBH", 0x72, ProtocolVersion.V1, 0x0000) await self._send_cotp_dt(frame) @@ -700,178 +470,6 @@ async def _create_session(self) -> None: self._session_id = struct.unpack_from(">I", response, 9)[0] self._protocol_version = version - logger.debug(f"Session created: id=0x{self._session_id:08X}, version=V{version}") - - # Parse response payload to extract ServerSessionVersion - self._parse_create_object_response(response[14:]) - - def _parse_create_object_response(self, payload: bytes) -> None: - """Parse CreateObject response to extract ServerSessionVersion.""" - offset = 0 - while offset < len(payload): - tag = payload[offset] - - if tag == ElementID.ATTRIBUTE: - offset += 1 - if offset >= len(payload): - break - attr_id, consumed = decode_uint32_vlq(payload, offset) - offset += consumed - - if attr_id == ObjectId.SERVER_SESSION_VERSION: - if offset + 2 > len(payload): - break - _flags = payload[offset] - datatype = payload[offset + 1] - offset += 2 - if datatype in (DataType.UDINT, DataType.DWORD): - value, consumed = decode_uint32_vlq(payload, offset) - offset += consumed - self._server_session_version = value - logger.info(f"ServerSessionVersion = {value}") - return - else: - logger.debug(f"ServerSessionVersion has unexpected type {datatype:#04x}") - else: - if offset + 2 > len(payload): - break - _flags = payload[offset] - datatype = payload[offset + 1] - offset += 2 - offset = self._skip_typed_value(payload, offset, datatype, _flags) - - elif tag == ElementID.START_OF_OBJECT: - offset += 1 - if offset + 4 > len(payload): - break - offset += 4 # RelationId - _, consumed = decode_uint32_vlq(payload, offset) - offset += consumed # ClassId - _, consumed = decode_uint32_vlq(payload, offset) - offset += consumed # ClassFlags - _, consumed = decode_uint32_vlq(payload, offset) - offset += consumed # AttributeId - - elif tag == ElementID.TERMINATING_OBJECT: - offset += 1 - elif tag == 0x00: - offset += 1 - else: - offset += 1 - - logger.debug("ServerSessionVersion not found in CreateObject response") - - def _skip_typed_value(self, data: bytes, offset: int, datatype: int, flags: int) -> int: - """Skip over a typed value in the PObject tree.""" - is_array = bool(flags & 0x10) - - if is_array: - if offset >= len(data): - return offset - count, consumed = decode_uint32_vlq(data, offset) - offset += consumed - elem_size = _element_size(datatype) - if elem_size > 0: - offset += count * elem_size - else: - for _ in range(count): - if offset >= len(data): - break - _, consumed = decode_uint32_vlq(data, offset) - offset += consumed - return offset - - if datatype == DataType.NULL: - return offset - elif datatype in (DataType.BOOL, DataType.USINT, DataType.BYTE, DataType.SINT): - return offset + 1 - elif datatype in (DataType.UINT, DataType.WORD, DataType.INT): - return offset + 2 - elif datatype in (DataType.UDINT, DataType.DWORD, DataType.AID, DataType.DINT): - _, consumed = decode_uint32_vlq(data, offset) - return offset + consumed - elif datatype in (DataType.ULINT, DataType.LWORD, DataType.LINT): - _, consumed = decode_uint64_vlq(data, offset) - return offset + consumed - elif datatype == DataType.REAL: - return offset + 4 - elif datatype == DataType.LREAL: - return offset + 8 - elif datatype == DataType.TIMESTAMP: - return offset + 8 - elif datatype == DataType.TIMESPAN: - _, consumed = decode_uint64_vlq(data, offset) - return offset + consumed - elif datatype == DataType.RID: - return offset + 4 - elif datatype in (DataType.BLOB, DataType.WSTRING): - length, consumed = decode_uint32_vlq(data, offset) - return offset + consumed + length - elif datatype == DataType.STRUCT: - count, consumed = decode_uint32_vlq(data, offset) - offset += consumed - for _ in range(count): - if offset + 2 > len(data): - break - sub_flags = data[offset] - sub_type = data[offset + 1] - offset += 2 - offset = self._skip_typed_value(data, offset, sub_type, sub_flags) - return offset - else: - return offset - - async def _setup_session(self) -> None: - """Send SetMultiVariables to echo ServerSessionVersion back to the PLC.""" - if self._server_session_version is None: - return - - seq_num = self._next_sequence_number() - - request = struct.pack( - ">BHHHHIB", - Opcode.REQUEST, - 0x0000, - FunctionCode.SET_MULTI_VARIABLES, - 0x0000, - seq_num, - self._session_id, - 0x36, - ) - - payload = bytearray() - payload += struct.pack(">I", self._session_id) - payload += encode_uint32_vlq(1) - payload += encode_uint32_vlq(1) - payload += encode_uint32_vlq(ObjectId.SERVER_SESSION_VERSION) - payload += encode_uint32_vlq(1) - payload += bytes([0x00, DataType.UDINT]) - payload += encode_uint32_vlq(self._server_session_version) - payload += bytes([0x00]) - payload += encode_object_qualifier() - payload += struct.pack(">I", 0) - - request += bytes(payload) - - frame = encode_header(self._protocol_version, len(request)) + request - frame += struct.pack(">BBH", 0x72, self._protocol_version, 0x0000) - await self._send_cotp_dt(frame) - - response_data = await self._recv_cotp_dt() - version, data_length, consumed = decode_header(response_data) - response = response_data[consumed : consumed + data_length] - - if len(response) < 14: - raise RuntimeError("SetupSession response too short") - - resp_payload = response[14:] - if len(resp_payload) >= 1: - return_value, _ = decode_uint64_vlq(resp_payload, 0) - if return_value != 0: - logger.warning(f"SetupSession: PLC returned error {return_value}") - else: - logger.info("Session setup completed successfully") - async def _delete_session(self) -> None: """Send DeleteObject to close the session.""" seq_num = self._next_sequence_number() diff --git a/snap7/s7commplus/client.py b/snap7/s7commplus/client.py index 341c879f..44112c9c 100644 --- a/snap7/s7commplus/client.py +++ b/snap7/s7commplus/client.py @@ -14,7 +14,7 @@ the client transparently falls back to the legacy S7 protocol for data block read/write operations. -Status: V1, V2, and V3 (including TLS) connections are functional. +Status: V1 and V2 connections are functional. V3/TLS authentication planned. Reference: thomas-v2/S7CommPlusDriver (C#, LGPL-3.0) """ @@ -24,7 +24,7 @@ from typing import Any, Optional from .connection import S7CommPlusConnection -from .protocol import FunctionCode, Ids, ProtocolVersion +from .protocol import FunctionCode, Ids from .vlq import encode_uint32_vlq, decode_uint32_vlq, decode_uint64_vlq from .codec import ( encode_item_address, @@ -148,12 +148,10 @@ def connect( logger.info("Performing PLC legitimation (password authentication)") self._connection.authenticate(password) - # Probe S7CommPlus data operations with a minimal request. - # Skip probe for V2+ with TLS: TLS handshake confirms S7CommPlus support. - if self._connection.protocol_version < ProtocolVersion.V2: - if not self._probe_s7commplus_data(): - logger.info("S7CommPlus data operations not supported, falling back to legacy S7 protocol") - self._setup_legacy_fallback() + # Probe S7CommPlus data operations with a minimal request + if not self._probe_s7commplus_data(): + logger.info("S7CommPlus data operations not supported, falling back to legacy S7 protocol") + self._setup_legacy_fallback() def _probe_s7commplus_data(self) -> bool: """Test if the PLC supports S7CommPlus data operations. diff --git a/snap7/s7commplus/connection.py b/snap7/s7commplus/connection.py index ae44506d..a60b44a0 100644 --- a/snap7/s7commplus/connection.py +++ b/snap7/s7commplus/connection.py @@ -85,7 +85,8 @@ class S7CommPlusConnection: - Version-appropriate authentication (V1/V2/V3/TLS) - Frame send/receive (TLS-encrypted when using V17+ firmware) - Supports V1, V2, and V3 (including TLS) authentication. + Currently implements V1 authentication. V2/V3/TLS authentication + layers are planned for future development. """ def __init__( @@ -201,19 +202,21 @@ def connect( logger.warning("PLC did not provide ServerSessionVersion - session setup incomplete") # Step 6: Version-specific post-setup - if self._protocol_version >= ProtocolVersion.V2: + if self._protocol_version >= ProtocolVersion.V3: + if not use_tls: + logger.warning( + "PLC reports V3 protocol but TLS is not enabled. Connection may not work without use_tls=True." + ) + elif self._protocol_version == ProtocolVersion.V2: if not self._tls_active: from ..error import S7ConnectionError - raise S7ConnectionError( - f"PLC reports V{self._protocol_version} protocol but TLS is not active. " - "V2/V3 requires TLS. Use use_tls=True." - ) + raise S7ConnectionError("PLC reports V2 protocol but TLS is not active. V2 requires TLS. Use use_tls=True.") # Enable IntegrityId tracking for V2+ self._with_integrity_id = True self._integrity_id_read = 0 self._integrity_id_write = 0 - logger.info(f"V{self._protocol_version} IntegrityId tracking enabled") + logger.info("V2 IntegrityId tracking enabled") # V1: No further authentication needed after CreateObject self._connected = True @@ -248,7 +251,7 @@ def authenticate(self, password: str, username: str = "") -> None: raise S7ConnectionError("Not connected") - if not self._tls_active: + if not self._tls_active or self._oms_secret is None: from ..error import S7ConnectionError raise S7ConnectionError("Legitimation requires TLS. Connect with use_tls=True.") @@ -260,11 +263,11 @@ def authenticate(self, password: str, username: str = "") -> None: # Step 2: Build response (auto-detect legacy vs new) from .legitimation import build_legacy_response, build_new_response - if username and self._oms_secret is not None: + if username: # New-style auth with username always uses AES-256-CBC response_data = build_new_response(password, challenge, self._oms_secret, username) self._send_legitimation_new(response_data) - elif self._oms_secret is not None: + else: # Try new-style first, fall back to legacy SHA-1 XOR try: response_data = build_new_response(password, challenge, self._oms_secret, "") @@ -273,12 +276,6 @@ def authenticate(self, password: str, username: str = "") -> None: # cryptography package not available, use legacy response_data = build_legacy_response(password, challenge) self._send_legitimation_legacy(response_data) - else: - # No OMS secret available (export_keying_material not supported), - # fall back to legacy SHA-1 XOR authentication - logger.info("OMS secret not available, using legacy legitimation") - response_data = build_legacy_response(password, challenge) - self._send_legitimation_legacy(response_data) logger.info("PLC legitimation completed successfully") @@ -1016,13 +1013,7 @@ def _setup_ssl_context( """ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.minimum_version = ssl.TLSVersion.TLSv1_3 - # TLS 1.3 cipher suites are auto-negotiated on modern OpenSSL; - # set_ciphers() only controls TLS 1.2 and below. We try to set - # preferred ciphers but ignore failures (e.g. OpenSSL 3.x). - try: - ctx.set_ciphers("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") - except ssl.SSLError: - pass + ctx.set_ciphers("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") if cert_path and key_path: ctx.load_cert_chain(cert_path, key_path) diff --git a/snap7/s7commplus/server.py b/snap7/s7commplus/server.py index 6719bf71..2af0d769 100644 --- a/snap7/s7commplus/server.py +++ b/snap7/s7commplus/server.py @@ -10,7 +10,7 @@ - Internal PLC memory model with thread-safe access - V2 protocol emulation with TLS and IntegrityId tracking -Supports V1 (no TLS), V2 (TLS + IntegrityId), and V3 (TLS + IntegrityId) emulation. +Supports both V1 (no TLS) and V2 (TLS + IntegrityId) emulation. Usage:: @@ -186,21 +186,17 @@ class S7CommPlusServer: Emulates an S7-1200/1500 PLC with: - Internal data block storage with named variables - - S7CommPlus protocol handling (V1, V2, V3) - - V2/V3 TLS support with IntegrityId tracking - - Legitimation (password authentication) emulation + - S7CommPlus protocol handling (V1 and V2) + - V2 TLS support with IntegrityId tracking - Multi-client support (threaded) - CPU state management """ - def __init__(self, protocol_version: int = ProtocolVersion.V1, password: Optional[str] = None) -> None: + def __init__(self, protocol_version: int = ProtocolVersion.V1) -> None: self._data_blocks: dict[int, DataBlock] = {} self._cpu_state = CPUState.RUN self._protocol_version = protocol_version - self._password = password self._next_session_id = 1 - # Per-client authentication state (session_id -> authenticated) - self._authenticated_sessions: dict[int, bool] = {} self._server_socket: Optional[socket.socket] = None self._server_thread: Optional[threading.Thread] = None @@ -209,7 +205,7 @@ def __init__(self, protocol_version: int = ProtocolVersion.V1, password: Optiona self._lock = threading.Lock() self._event_callback: Optional[Callable[..., None]] = None - # TLS configuration (V2/V3) + # TLS configuration (V2) self._ssl_context: Optional[ssl.SSLContext] = None self._use_tls: bool = False @@ -368,12 +364,8 @@ def _handle_client(self, client_sock: socket.socket, address: tuple[str, int]) - session_id = 0 tls_activated = False # Per-client IntegrityId tracking (V2+) - # IntegrityId tracking starts AFTER the session setup (SetMultiVariables - # echoing ServerSessionVersion). The first SetMultiVariables after - # CreateObject is the session setup and doesn't include IntegrityId. integrity_id_read = 0 integrity_id_write = 0 - integrity_id_active = False while self._running: try: @@ -383,9 +375,7 @@ def _handle_client(self, client_sock: socket.socket, address: tuple[str, int]) - break # Process the S7CommPlus request - response = self._process_request( - data, session_id, integrity_id_read, integrity_id_write, integrity_id_active - ) + response = self._process_request(data, session_id, integrity_id_read, integrity_id_write) if response is not None: # Check if session ID was assigned @@ -422,12 +412,7 @@ def _handle_client(self, client_sock: socket.socket, address: tuple[str, int]) - payload = data[hdr_consumed:] if len(payload) >= 14: func_code = struct.unpack_from(">H", payload, 3)[0] - if not integrity_id_active: - # First SetMultiVariables after CreateObject is session setup - # (no IntegrityId). Activate tracking after it. - if func_code == FunctionCode.SET_MULTI_VARIABLES: - integrity_id_active = True - elif func_code in READ_FUNCTION_CODES: + if func_code in READ_FUNCTION_CODES: integrity_id_read = (integrity_id_read + 1) & 0xFFFFFFFF elif func_code not in ( FunctionCode.INIT_SSL, @@ -537,7 +522,6 @@ def _process_request( session_id: int, integrity_id_read: int = 0, integrity_id_write: int = 0, - integrity_id_active: bool = False, ) -> Optional[bytes]: """Process an S7CommPlus request and return a response.""" if len(data) < 4: @@ -563,14 +547,11 @@ def _process_request( seq_num = struct.unpack_from(">H", payload, 7)[0] req_session_id = struct.unpack_from(">I", payload, 9)[0] - # For V2+, skip IntegrityId after the 14-byte header. - # IntegrityId is only present after session setup is complete - # (integrity_id_active=True). The first SetMultiVariables after - # CreateObject is the session setup and doesn't include IntegrityId. + # For V2+, skip IntegrityId after the 14-byte header request_offset = 14 if ( - integrity_id_active - and self._protocol_version >= ProtocolVersion.V2 + self._protocol_version >= ProtocolVersion.V2 + and session_id != 0 and function_code not in (FunctionCode.INIT_SSL, FunctionCode.CREATE_OBJECT) ): if request_offset < len(payload): @@ -585,41 +566,14 @@ def _process_request( return self._handle_create_object(seq_num, request_data) elif function_code == FunctionCode.DELETE_OBJECT: return self._handle_delete_object(seq_num, req_session_id) - elif function_code == FunctionCode.GET_VAR_SUBSTREAMED: - response = self._handle_get_var_substreamed(seq_num, req_session_id, request_data) - elif function_code == FunctionCode.SET_VARIABLE: - response = self._handle_set_variable(seq_num, req_session_id, request_data) elif function_code == FunctionCode.EXPLORE: - if not self._check_authenticated(req_session_id): - response = self._build_error_response(seq_num, req_session_id, function_code) - else: - response = self._handle_explore(seq_num, req_session_id, request_data) + return self._handle_explore(seq_num, req_session_id, request_data) elif function_code == FunctionCode.GET_MULTI_VARIABLES: - if not self._check_authenticated(req_session_id): - response = self._build_error_response(seq_num, req_session_id, function_code) - else: - response = self._handle_get_multi_variables(seq_num, req_session_id, request_data) + return self._handle_get_multi_variables(seq_num, req_session_id, request_data) elif function_code == FunctionCode.SET_MULTI_VARIABLES: - # Auth check is inside the handler: session setup must bypass auth - response = self._handle_set_multi_variables( - seq_num, req_session_id, request_data, self._check_authenticated(req_session_id) - ) + return self._handle_set_multi_variables(seq_num, req_session_id, request_data) else: - response = self._build_error_response(seq_num, req_session_id, function_code) - - # For V2+, insert IntegrityId right after the 14-byte response header. - # The client expects IntegrityId at offset 14 in the response, mirroring - # the request format. Only insert when IntegrityId tracking is active. - if ( - integrity_id_active - and self._protocol_version >= ProtocolVersion.V2 - and function_code not in (FunctionCode.INIT_SSL, FunctionCode.CREATE_OBJECT) - and response is not None - and len(response) >= 14 - ): - response = response[:14] + encode_uint32_vlq(0) + response[14:] - - return response + return self._build_error_response(seq_num, req_session_id, function_code) def _build_response_header( self, @@ -844,30 +798,16 @@ def _handle_get_multi_variables(self, seq_num: int, session_id: int, request_dat # Terminate error list response += encode_uint32_vlq(0) + # IntegrityId + response += encode_uint32_vlq(0) + return bytes(response) - def _handle_set_multi_variables( - self, seq_num: int, session_id: int, request_data: bytes, is_authenticated: bool = True - ) -> bytes: + def _handle_set_multi_variables(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: """Handle SetMultiVariables -- write variables to data blocks. - Also handles session setup (SetMultiVariables echoing ServerSessionVersion). - Session setup is detected by InObjectId matching the session_id. - Session setup always bypasses auth; data writes require authentication. - Reference: thomas-v2/S7CommPlusDriver/Core/SetMultiVariablesRequest.cs """ - # Check if this is a session setup write (InObjectId = session_id) - if len(request_data) >= 4: - in_object_id = struct.unpack_from(">I", request_data, 0)[0] - if in_object_id == session_id and session_id != 0: - # Session setup: just acknowledge success (no auth required) - return self._build_set_multi_response(seq_num, session_id, []) - - # For data writes, require authentication - if not is_authenticated: - return self._build_error_response(seq_num, session_id, FunctionCode.SET_MULTI_VARIABLES) - response = bytearray() response += struct.pack( ">BHHHHIB", @@ -903,98 +843,8 @@ def _handle_set_multi_variables( # Terminate error list response += encode_uint32_vlq(0) - return bytes(response) - - def _build_set_multi_response( - self, seq_num: int, session_id: int, errors: list[tuple[int, int]] - ) -> bytes: - """Build a SetMultiVariables response.""" - response = bytearray() - response += struct.pack( - ">BHHHHIB", - Opcode.RESPONSE, - 0x0000, - FunctionCode.SET_MULTI_VARIABLES, - 0x0000, - seq_num, - session_id, - 0x00, - ) - response += encode_uint64_vlq(0) # ReturnValue: success - for err_item, err_code in errors: - response += encode_uint32_vlq(err_item) - response += encode_uint64_vlq(err_code) - response += encode_uint32_vlq(0) # Terminate error list - return bytes(response) - - def _check_authenticated(self, session_id: int) -> bool: - """Check if a session is authenticated (or no password is required).""" - if self._password is None: - return True - return self._authenticated_sessions.get(session_id, False) - - def _handle_get_var_substreamed(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: - """Handle GetVarSubStreamed -- used for legitimation challenge request. - - Returns a 20-byte random challenge when the client requests - ServerSessionRequest (address 303). - """ - import os - - response = bytearray() - response += struct.pack( - ">BHHHHIB", - Opcode.RESPONSE, - 0x0000, - FunctionCode.GET_VAR_SUBSTREAMED, - 0x0000, - seq_num, - session_id, - 0x00, - ) - - # ReturnValue: success - response += encode_uint64_vlq(0) - - # Value: 20-byte challenge as BLOB - challenge = os.urandom(20) - response += bytes([0x00, DataType.BLOB]) - response += encode_uint32_vlq(len(challenge)) - response += challenge - - # Trailing padding - response += struct.pack(">I", 0) - - return bytes(response) - - def _handle_set_variable(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: - """Handle SetVariable -- used for legitimation response. - - Accepts the client's authentication response and marks the session - as authenticated. In this emulator, any response is accepted (we - don't verify the actual crypto). - """ - # Mark session as authenticated - self._authenticated_sessions[session_id] = True - logger.debug(f"Session {session_id} authenticated via SetVariable") - - response = bytearray() - response += struct.pack( - ">BHHHHIB", - Opcode.RESPONSE, - 0x0000, - FunctionCode.SET_VARIABLE, - 0x0000, - seq_num, - session_id, - 0x00, - ) - - # ReturnValue: success - response += encode_uint64_vlq(0) - - # Trailing padding - response += struct.pack(">I", 0) + # IntegrityId + response += encode_uint32_vlq(0) return bytes(response) diff --git a/tests/test_s7commplus_tls.py b/tests/test_s7commplus_tls.py deleted file mode 100644 index 873d2a9f..00000000 --- a/tests/test_s7commplus_tls.py +++ /dev/null @@ -1,417 +0,0 @@ -"""Integration tests for S7CommPlus V2/V3 with TLS. - -Tests the complete TLS connection flow including: -- V2 server + client with TLS and IntegrityId tracking -- V3 server + client with TLS -- Legitimation (password authentication) flow -- Async client with TLS -- Error handling for missing TLS - -Requires the `cryptography` package for self-signed certificate generation. -""" - -import struct -import tempfile -import time -from collections.abc import Generator -from pathlib import Path - -import pytest - -from snap7.s7commplus.server import S7CommPlusServer -from snap7.s7commplus.client import S7CommPlusClient -from snap7.s7commplus.protocol import ProtocolVersion - -try: - import ipaddress - - from cryptography import x509 - from cryptography.x509.oid import NameOID - from cryptography.hazmat.primitives import hashes, serialization - from cryptography.hazmat.primitives.asymmetric import rsa - import datetime - - _has_cryptography = True -except ImportError: - _has_cryptography = False - -pytestmark = pytest.mark.skipif(not _has_cryptography, reason="requires cryptography package") - -# Use high ports to avoid conflicts -V2_TEST_PORT = 11130 -V3_TEST_PORT = 11131 -V2_AUTH_PORT = 11132 -V3_AUTH_PORT = 11133 - - -@pytest.fixture(scope="module") -def tls_certs() -> Generator[dict[str, str], None, None]: - """Generate self-signed TLS certificates for testing.""" - with tempfile.TemporaryDirectory() as tmpdir: - key = rsa.generate_private_key(public_exponent=65537, key_size=2048) - - subject = issuer = x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), - ]) - cert = ( - x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(issuer) - .public_key(key.public_key()) - .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) - .not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)) - .add_extension( - x509.SubjectAlternativeName([ - x509.DNSName("localhost"), - x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), - ]), - critical=False, - ) - .sign(key, hashes.SHA256()) - ) - - cert_path = str(Path(tmpdir) / "server.crt") - key_path = str(Path(tmpdir) / "server.key") - - with open(cert_path, "wb") as f: - f.write(cert.public_bytes(serialization.Encoding.PEM)) - - with open(key_path, "wb") as f: - f.write( - key.private_bytes( - serialization.Encoding.PEM, - serialization.PrivateFormat.TraditionalOpenSSL, - serialization.NoEncryption(), - ) - ) - - yield {"cert": cert_path, "key": key_path} - - -def _make_server( - protocol_version: int, - password: str | None = None, -) -> S7CommPlusServer: - """Create and configure an S7CommPlus server with test data blocks.""" - srv = S7CommPlusServer(protocol_version=protocol_version, password=password) - - srv.register_db( - 1, - { - "temperature": ("Real", 0), - "pressure": ("Real", 4), - "running": ("Bool", 8), - "count": ("DInt", 10), - }, - ) - srv.register_raw_db(2, bytearray(256)) - - # Pre-populate DB1 - db1 = srv.get_db(1) - assert db1 is not None - struct.pack_into(">f", db1.data, 0, 23.5) - struct.pack_into(">f", db1.data, 4, 1.013) - db1.data[8] = 1 - struct.pack_into(">i", db1.data, 10, 42) - - return srv - - -@pytest.fixture() -def v2_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: - """V2 server with TLS.""" - srv = _make_server(ProtocolVersion.V2) - srv.start(port=V2_TEST_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) - time.sleep(0.1) - yield srv - srv.stop() - - -@pytest.fixture() -def v3_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: - """V3 server with TLS.""" - srv = _make_server(ProtocolVersion.V3) - srv.start(port=V3_TEST_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) - time.sleep(0.1) - yield srv - srv.stop() - - -@pytest.fixture() -def v2_auth_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: - """V2 server with TLS and password authentication.""" - srv = _make_server(ProtocolVersion.V2, password="secret123") - srv.start(port=V2_AUTH_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) - time.sleep(0.1) - yield srv - srv.stop() - - -@pytest.fixture() -def v3_auth_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: - """V3 server with TLS and password authentication.""" - srv = _make_server(ProtocolVersion.V3, password="secret123") - srv.start(port=V3_AUTH_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) - time.sleep(0.1) - yield srv - srv.stop() - - -class TestV2TLS: - """Test V2 protocol with TLS.""" - - def test_connect_disconnect(self, v2_server: S7CommPlusServer) -> None: - client = S7CommPlusClient() - client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) - assert client.connected - assert client.session_id != 0 - assert client.protocol_version == ProtocolVersion.V2 - client.disconnect() - assert not client.connected - - def test_read_real(self, v2_server: S7CommPlusServer) -> None: - client = S7CommPlusClient() - client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) - try: - data = client.db_read(1, 0, 4) - value = struct.unpack(">f", data)[0] - assert abs(value - 23.5) < 0.001 - finally: - client.disconnect() - - def test_write_and_read_back(self, v2_server: S7CommPlusServer) -> None: - client = S7CommPlusClient() - client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) - try: - client.db_write(1, 0, struct.pack(">f", 99.9)) - data = client.db_read(1, 0, 4) - value = struct.unpack(">f", data)[0] - assert abs(value - 99.9) < 0.1 - finally: - client.disconnect() - - def test_multi_read(self, v2_server: S7CommPlusServer) -> None: - client = S7CommPlusClient() - client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) - try: - results = client.db_read_multi([ - (1, 0, 4), - (1, 4, 4), - (2, 0, 4), - ]) - assert len(results) == 3 - temp = struct.unpack(">f", results[0])[0] - assert abs(temp - 23.5) < 0.001 - finally: - client.disconnect() - - def test_explore(self, v2_server: S7CommPlusServer) -> None: - client = S7CommPlusClient() - client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) - try: - response = client.explore() - assert len(response) > 0 - finally: - client.disconnect() - - -class TestV3TLS: - """Test V3 protocol with TLS.""" - - def test_connect_disconnect(self, v3_server: S7CommPlusServer) -> None: - client = S7CommPlusClient() - client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) - assert client.connected - assert client.session_id != 0 - assert client.protocol_version == ProtocolVersion.V3 - client.disconnect() - assert not client.connected - - def test_read_real(self, v3_server: S7CommPlusServer) -> None: - client = S7CommPlusClient() - client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) - try: - data = client.db_read(1, 0, 4) - value = struct.unpack(">f", data)[0] - assert abs(value - 23.5) < 0.001 - finally: - client.disconnect() - - def test_write_and_read_back(self, v3_server: S7CommPlusServer) -> None: - client = S7CommPlusClient() - client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) - try: - client.db_write(1, 0, struct.pack(">f", 88.8)) - data = client.db_read(1, 0, 4) - value = struct.unpack(">f", data)[0] - assert abs(value - 88.8) < 0.1 - finally: - client.disconnect() - - def test_multi_read(self, v3_server: S7CommPlusServer) -> None: - client = S7CommPlusClient() - client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) - try: - results = client.db_read_multi([ - (1, 0, 4), - (1, 10, 4), - ]) - assert len(results) == 2 - finally: - client.disconnect() - - def test_data_persists_across_clients(self, v3_server: S7CommPlusServer) -> None: - c1 = S7CommPlusClient() - c1.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) - c1.db_write(2, 0, b"\xca\xfe\xba\xbe") - c1.disconnect() - - c2 = S7CommPlusClient() - c2.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) - data = c2.db_read(2, 0, 4) - c2.disconnect() - - assert data == b"\xca\xfe\xba\xbe" - - -class TestV2WithoutTLS: - """Test that V2/V3 connections fail without TLS.""" - - def test_v2_server_requires_tls_on_client(self, tls_certs: dict[str, str]) -> None: - """V2 server reports V2 version; client should raise if no TLS.""" - # Start a V2 server WITHOUT TLS (so client can connect but gets V2 version) - srv = _make_server(ProtocolVersion.V2) - srv.start(port=V2_TEST_PORT + 10) - time.sleep(0.1) - try: - client = S7CommPlusClient() - with pytest.raises(Exception): - # The server reports V2 but client didn't use TLS - client.connect("127.0.0.1", port=V2_TEST_PORT + 10) - finally: - srv.stop() - - -class TestLegitimation: - """Test password authentication (legitimation).""" - - def test_v2_connect_with_password(self, v2_auth_server: S7CommPlusServer) -> None: - client = S7CommPlusClient() - client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True, password="secret123") - assert client.connected - try: - # Should be able to read after authentication - data = client.db_read(1, 0, 4) - value = struct.unpack(">f", data)[0] - assert abs(value - 23.5) < 0.001 - finally: - client.disconnect() - - def test_v3_connect_with_password(self, v3_auth_server: S7CommPlusServer) -> None: - client = S7CommPlusClient() - client.connect("127.0.0.1", port=V3_AUTH_PORT, use_tls=True, password="secret123") - assert client.connected - try: - data = client.db_read(1, 0, 4) - value = struct.unpack(">f", data)[0] - assert abs(value - 23.5) < 0.001 - finally: - client.disconnect() - - def test_v2_without_password_on_protected_server(self, v2_auth_server: S7CommPlusServer) -> None: - """Connecting without password to a password-protected server should fail data ops.""" - client = S7CommPlusClient() - client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True) - try: - # Server should reject data operations without authentication - with pytest.raises(Exception): - client.db_read(1, 0, 4) - finally: - client.disconnect() - - def test_v2_write_with_password(self, v2_auth_server: S7CommPlusServer) -> None: - client = S7CommPlusClient() - client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True, password="secret123") - try: - client.db_write(1, 0, struct.pack(">f", 55.5)) - data = client.db_read(1, 0, 4) - value = struct.unpack(">f", data)[0] - assert abs(value - 55.5) < 0.1 - finally: - client.disconnect() - - -@pytest.mark.asyncio -class TestAsyncV2TLS: - """Test async client with V2 TLS.""" - - async def test_connect_disconnect(self, v2_server: S7CommPlusServer) -> None: - from snap7.s7commplus.async_client import S7CommPlusAsyncClient - - client = S7CommPlusAsyncClient() - await client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) - assert client.connected - assert client.session_id != 0 - assert client.protocol_version == ProtocolVersion.V2 - assert client.tls_active - await client.disconnect() - assert not client.connected - - async def test_read_real(self, v2_server: S7CommPlusServer) -> None: - from snap7.s7commplus.async_client import S7CommPlusAsyncClient - - async with S7CommPlusAsyncClient() as client: - await client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) - data = await client.db_read(1, 0, 4) - value = struct.unpack(">f", data)[0] - assert abs(value - 23.5) < 0.001 - - async def test_write_and_read_back(self, v2_server: S7CommPlusServer) -> None: - from snap7.s7commplus.async_client import S7CommPlusAsyncClient - - async with S7CommPlusAsyncClient() as client: - await client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) - await client.db_write(1, 0, struct.pack(">f", 77.7)) - data = await client.db_read(1, 0, 4) - value = struct.unpack(">f", data)[0] - assert abs(value - 77.7) < 0.1 - - -@pytest.mark.asyncio -class TestAsyncV3TLS: - """Test async client with V3 TLS.""" - - async def test_connect_disconnect(self, v3_server: S7CommPlusServer) -> None: - from snap7.s7commplus.async_client import S7CommPlusAsyncClient - - client = S7CommPlusAsyncClient() - await client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) - assert client.connected - assert client.protocol_version == ProtocolVersion.V3 - assert client.tls_active - await client.disconnect() - - async def test_read_write(self, v3_server: S7CommPlusServer) -> None: - from snap7.s7commplus.async_client import S7CommPlusAsyncClient - - async with S7CommPlusAsyncClient() as client: - await client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) - await client.db_write(2, 0, b"\xde\xad") - data = await client.db_read(2, 0, 2) - assert data == b"\xde\xad" - - -@pytest.mark.asyncio -class TestAsyncLegitimation: - """Test async client with password authentication.""" - - async def test_v2_connect_with_password(self, v2_auth_server: S7CommPlusServer) -> None: - from snap7.s7commplus.async_client import S7CommPlusAsyncClient - - async with S7CommPlusAsyncClient() as client: - await client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True, password="secret123") - assert client.connected - data = await client.db_read(1, 0, 4) - value = struct.unpack(">f", data)[0] - assert abs(value - 23.5) < 0.001 From 02625ae67b18db15768dcea25f0910d49a890ecc Mon Sep 17 00:00:00 2001 From: Gijs Molenaar Date: Fri, 27 Mar 2026 08:51:39 +0200 Subject: [PATCH 4/4] Reapply "Implement S7CommPlus V2/V3 protocol with TLS and IntegrityId" (#659) This reverts commit bf2f828a629e61c74e7125feb024f3b781971e96. --- snap7/s7commplus/async_client.py | 408 +++++++++++++++++++++++++++++- snap7/s7commplus/client.py | 14 +- snap7/s7commplus/connection.py | 37 +-- snap7/s7commplus/server.py | 190 ++++++++++++-- tests/test_s7commplus_tls.py | 417 +++++++++++++++++++++++++++++++ 5 files changed, 1023 insertions(+), 43 deletions(-) create mode 100644 tests/test_s7commplus_tls.py diff --git a/snap7/s7commplus/async_client.py b/snap7/s7commplus/async_client.py index e6a46fe2..e58e9fc4 100644 --- a/snap7/s7commplus/async_client.py +++ b/snap7/s7commplus/async_client.py @@ -4,6 +4,10 @@ Provides the same API as S7CommPlusClient but using asyncio for non-blocking I/O. Uses asyncio.Lock for concurrent safety. +Supports all S7CommPlus protocol versions (V1/V2/V3/TLS). The protocol +version is auto-detected from the PLC's CreateObject response during +connection setup. + When a PLC does not support S7CommPlus data operations, the client transparently falls back to the legacy S7 protocol for data block read/write operations (using synchronous calls in an executor). @@ -18,6 +22,7 @@ import asyncio import logging +import ssl import struct from typing import Any, Optional @@ -25,6 +30,7 @@ DataType, ElementID, FunctionCode, + LegitimationId, ObjectId, Opcode, ProtocolVersion, @@ -35,6 +41,7 @@ from .codec import encode_header, decode_header, encode_typed_value, encode_object_qualifier from .vlq import encode_uint32_vlq, decode_uint32_vlq, decode_uint64_vlq from .client import _build_read_payload, _parse_read_response, _build_write_payload, _parse_write_response +from .connection import _element_size logger = logging.getLogger(__name__) @@ -47,7 +54,7 @@ class S7CommPlusAsyncClient: """Async S7CommPlus client for S7-1200/1500 PLCs. - Supports V1 and V2 protocols. V3/TLS planned for future. + Supports V1, V2, and V3 protocols (including TLS). Uses asyncio for all I/O operations and asyncio.Lock for concurrent safety when shared between multiple coroutines. @@ -76,6 +83,13 @@ def __init__(self) -> None: self._integrity_id_write: int = 0 self._with_integrity_id: bool = False + # TLS state + self._tls_active: bool = False + self._oms_secret: Optional[bytes] = None + + # Session setup + self._server_session_version: Optional[int] = None + @property def connected(self) -> bool: if self._use_legacy_data and self._legacy_client is not None: @@ -95,12 +109,22 @@ def using_legacy_fallback(self) -> bool: """Whether the client is using legacy S7 protocol for data operations.""" return self._use_legacy_data + @property + def tls_active(self) -> bool: + """Whether TLS encryption is active on this connection.""" + return self._tls_active + async def connect( self, host: str, port: int = 102, rack: int = 0, slot: int = 1, + use_tls: bool = False, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + tls_ca: Optional[str] = None, + password: Optional[str] = None, ) -> None: """Connect to an S7-1200/1500 PLC. @@ -112,6 +136,11 @@ async def connect( port: TCP port (default 102) rack: PLC rack number slot: PLC slot number + use_tls: Whether to activate TLS (required for V2/V3) + tls_cert: Path to client TLS certificate (PEM) + tls_key: Path to client private key (PEM) + tls_ca: Path to CA certificate for PLC verification (PEM) + password: PLC password for legitimation (V2+ with TLS) """ self._host = host self._port = port @@ -128,14 +157,43 @@ async def connect( # InitSSL handshake await self._init_ssl() + # TLS activation (between InitSSL and CreateObject) + if use_tls: + await self._activate_tls(tls_cert=tls_cert, tls_key=tls_key, tls_ca=tls_ca) + # S7CommPlus session setup await self._create_session() + # Echo ServerSessionVersion back to complete handshake + if self._server_session_version is not None: + await self._setup_session() + else: + logger.warning("PLC did not provide ServerSessionVersion - session setup incomplete") + + # Version-specific post-setup + if self._protocol_version >= ProtocolVersion.V2: + if not self._tls_active: + raise RuntimeError( + f"PLC reports V{self._protocol_version} protocol but TLS is not active. " + "V2/V3 requires TLS. Use use_tls=True." + ) + self._with_integrity_id = True + self._integrity_id_read = 0 + self._integrity_id_write = 0 + logger.info(f"V{self._protocol_version} IntegrityId tracking enabled") + self._connected = True logger.info( - f"Async S7CommPlus connected to {host}:{port}, version=V{self._protocol_version}, session={self._session_id}" + f"Async S7CommPlus connected to {host}:{port}, " + f"version=V{self._protocol_version}, session={self._session_id}, " + f"tls={self._tls_active}" ) + # Handle legitimation for password-protected PLCs + if password is not None and self._tls_active: + logger.info("Performing PLC legitimation (password authentication)") + await self.authenticate(password) + # Probe S7CommPlus data operations if not await self._probe_s7commplus_data(): logger.info("S7CommPlus data operations not supported, falling back to legacy S7 protocol") @@ -145,6 +203,116 @@ async def connect( await self.disconnect() raise + async def authenticate(self, password: str, username: str = "") -> None: + """Perform PLC password authentication (legitimation). + + Must be called after connect() and before data operations on + password-protected PLCs. Requires TLS to be active (V2+). + + Args: + password: PLC password + username: Username for new-style auth (optional) + """ + if not self._connected: + raise RuntimeError("Not connected") + + if not self._tls_active: + raise RuntimeError("Legitimation requires TLS. Connect with use_tls=True.") + + # Step 1: Get challenge from PLC + challenge = await self._get_legitimation_challenge() + logger.info(f"Received legitimation challenge ({len(challenge)} bytes)") + + # Step 2: Build response (auto-detect legacy vs new) + from .legitimation import build_legacy_response, build_new_response + + if username and self._oms_secret is not None: + response_data = build_new_response(password, challenge, self._oms_secret, username) + await self._send_legitimation_new(response_data) + elif self._oms_secret is not None: + try: + response_data = build_new_response(password, challenge, self._oms_secret, "") + await self._send_legitimation_new(response_data) + except NotImplementedError: + response_data = build_legacy_response(password, challenge) + await self._send_legitimation_legacy(response_data) + else: + logger.info("OMS secret not available, using legacy legitimation") + response_data = build_legacy_response(password, challenge) + await self._send_legitimation_legacy(response_data) + + logger.info("PLC legitimation completed successfully") + + async def _get_legitimation_challenge(self) -> bytes: + """Request legitimation challenge from PLC.""" + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(LegitimationId.SERVER_SESSION_REQUEST) + payload += struct.pack(">I", 0) + + resp_payload = await self._send_request(FunctionCode.GET_VAR_SUBSTREAMED, bytes(payload)) + + offset = 0 + return_value, consumed = decode_uint64_vlq(resp_payload, offset) + offset += consumed + + if return_value != 0: + raise RuntimeError(f"GetVarSubStreamed for challenge failed: return_value={return_value}") + + if offset + 2 > len(resp_payload): + raise RuntimeError("Challenge response too short") + + _flags = resp_payload[offset] + datatype = resp_payload[offset + 1] + offset += 2 + + if datatype == DataType.BLOB: + length, consumed = decode_uint32_vlq(resp_payload, offset) + offset += consumed + return bytes(resp_payload[offset : offset + length]) + else: + count, consumed = decode_uint32_vlq(resp_payload, offset) + offset += consumed + return bytes(resp_payload[offset : offset + count]) + + async def _send_legitimation_new(self, encrypted_response: bytes) -> None: + """Send new-style legitimation response (AES-256-CBC encrypted).""" + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(LegitimationId.LEGITIMATE) + payload += bytes([0x00, DataType.BLOB]) + payload += encode_uint32_vlq(len(encrypted_response)) + payload += encrypted_response + payload += struct.pack(">I", 0) + + resp_payload = await self._send_request(FunctionCode.SET_VARIABLE, bytes(payload)) + + if len(resp_payload) >= 1: + return_value, _ = decode_uint64_vlq(resp_payload, 0) + if return_value < 0: + raise RuntimeError(f"Legitimation rejected by PLC: return_value={return_value}") + + async def _send_legitimation_legacy(self, response: bytes) -> None: + """Send legacy legitimation response (SHA-1 XOR).""" + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(LegitimationId.SERVER_SESSION_RESPONSE) + payload += bytes([0x10, DataType.USINT]) + payload += encode_uint32_vlq(len(response)) + payload += response + payload += struct.pack(">I", 0) + + resp_payload = await self._send_request(FunctionCode.SET_VARIABLE, bytes(payload)) + + if len(resp_payload) >= 1: + return_value, _ = decode_uint64_vlq(resp_payload, 0) + if return_value < 0: + raise RuntimeError(f"Legacy legitimation rejected by PLC: return_value={return_value}") + async def _probe_s7commplus_data(self) -> bool: """Test if the PLC supports S7CommPlus data operations.""" try: @@ -198,6 +366,9 @@ async def disconnect(self) -> None: self._with_integrity_id = False self._integrity_id_read = 0 self._integrity_id_write = 0 + self._tls_active = False + self._oms_secret = None + self._server_session_version = None if self._writer: try: @@ -407,6 +578,65 @@ async def _init_ssl(self) -> None: logger.debug(f"InitSSL response received, version=V{version}") + async def _activate_tls( + self, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + tls_ca: Optional[str] = None, + ) -> None: + """Activate TLS 1.3 over the COTP connection. + + Called after InitSSL and before CreateObject. Wraps the underlying + asyncio streams with TLS. + """ + if self._writer is None: + raise RuntimeError("Cannot activate TLS: not connected") + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.minimum_version = ssl.TLSVersion.TLSv1_3 + # TLS 1.3 cipher suites are auto-negotiated on modern OpenSSL; + # set_ciphers() only controls TLS 1.2 and below. + try: + ctx.set_ciphers("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") + except ssl.SSLError: + pass + + if tls_cert and tls_key: + ctx.load_cert_chain(tls_cert, tls_key) + + if tls_ca: + ctx.load_verify_locations(tls_ca) + else: + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + # Upgrade the connection to TLS. + # StreamWriter.start_tls() is the clean API (Python 3.11+). + # For Python 3.10, fall back to loop.start_tls() with the existing protocol. + if hasattr(self._writer, "start_tls"): + await self._writer.start_tls(ctx, server_hostname=self._host) + else: + transport = self._writer.transport + protocol = transport.get_protocol() + loop = asyncio.get_event_loop() + new_transport = await loop.start_tls(transport, protocol, ctx, server_hostname=self._host) + # Update writer's internal transport reference + self._writer._transport = new_transport + + self._tls_active = True + + # Extract OMS exporter secret for legitimation key derivation + ssl_object = self._writer.transport.get_extra_info("ssl_object") + if ssl_object is not None: + try: + self._oms_secret = ssl_object.export_keying_material("EXPERIMENTAL_OMS", 32, None) + logger.debug("OMS exporter secret extracted from TLS session") + except (AttributeError, ssl.SSLError) as e: + logger.warning(f"Could not extract OMS exporter secret: {e}") + self._oms_secret = None + + logger.info("TLS activated on async COTP connection") + async def _create_session(self) -> None: """Send CreateObject to establish S7CommPlus session.""" seq_num = self._next_sequence_number() @@ -455,7 +685,7 @@ async def _create_session(self) -> None: request += bytes([ElementID.TERMINATING_OBJECT]) request += struct.pack(">I", 0) - # Frame header + trailer + # Frame header + trailer (always V1 for CreateObject) frame = encode_header(ProtocolVersion.V1, len(request)) + request frame += struct.pack(">BBH", 0x72, ProtocolVersion.V1, 0x0000) await self._send_cotp_dt(frame) @@ -470,6 +700,178 @@ async def _create_session(self) -> None: self._session_id = struct.unpack_from(">I", response, 9)[0] self._protocol_version = version + logger.debug(f"Session created: id=0x{self._session_id:08X}, version=V{version}") + + # Parse response payload to extract ServerSessionVersion + self._parse_create_object_response(response[14:]) + + def _parse_create_object_response(self, payload: bytes) -> None: + """Parse CreateObject response to extract ServerSessionVersion.""" + offset = 0 + while offset < len(payload): + tag = payload[offset] + + if tag == ElementID.ATTRIBUTE: + offset += 1 + if offset >= len(payload): + break + attr_id, consumed = decode_uint32_vlq(payload, offset) + offset += consumed + + if attr_id == ObjectId.SERVER_SESSION_VERSION: + if offset + 2 > len(payload): + break + _flags = payload[offset] + datatype = payload[offset + 1] + offset += 2 + if datatype in (DataType.UDINT, DataType.DWORD): + value, consumed = decode_uint32_vlq(payload, offset) + offset += consumed + self._server_session_version = value + logger.info(f"ServerSessionVersion = {value}") + return + else: + logger.debug(f"ServerSessionVersion has unexpected type {datatype:#04x}") + else: + if offset + 2 > len(payload): + break + _flags = payload[offset] + datatype = payload[offset + 1] + offset += 2 + offset = self._skip_typed_value(payload, offset, datatype, _flags) + + elif tag == ElementID.START_OF_OBJECT: + offset += 1 + if offset + 4 > len(payload): + break + offset += 4 # RelationId + _, consumed = decode_uint32_vlq(payload, offset) + offset += consumed # ClassId + _, consumed = decode_uint32_vlq(payload, offset) + offset += consumed # ClassFlags + _, consumed = decode_uint32_vlq(payload, offset) + offset += consumed # AttributeId + + elif tag == ElementID.TERMINATING_OBJECT: + offset += 1 + elif tag == 0x00: + offset += 1 + else: + offset += 1 + + logger.debug("ServerSessionVersion not found in CreateObject response") + + def _skip_typed_value(self, data: bytes, offset: int, datatype: int, flags: int) -> int: + """Skip over a typed value in the PObject tree.""" + is_array = bool(flags & 0x10) + + if is_array: + if offset >= len(data): + return offset + count, consumed = decode_uint32_vlq(data, offset) + offset += consumed + elem_size = _element_size(datatype) + if elem_size > 0: + offset += count * elem_size + else: + for _ in range(count): + if offset >= len(data): + break + _, consumed = decode_uint32_vlq(data, offset) + offset += consumed + return offset + + if datatype == DataType.NULL: + return offset + elif datatype in (DataType.BOOL, DataType.USINT, DataType.BYTE, DataType.SINT): + return offset + 1 + elif datatype in (DataType.UINT, DataType.WORD, DataType.INT): + return offset + 2 + elif datatype in (DataType.UDINT, DataType.DWORD, DataType.AID, DataType.DINT): + _, consumed = decode_uint32_vlq(data, offset) + return offset + consumed + elif datatype in (DataType.ULINT, DataType.LWORD, DataType.LINT): + _, consumed = decode_uint64_vlq(data, offset) + return offset + consumed + elif datatype == DataType.REAL: + return offset + 4 + elif datatype == DataType.LREAL: + return offset + 8 + elif datatype == DataType.TIMESTAMP: + return offset + 8 + elif datatype == DataType.TIMESPAN: + _, consumed = decode_uint64_vlq(data, offset) + return offset + consumed + elif datatype == DataType.RID: + return offset + 4 + elif datatype in (DataType.BLOB, DataType.WSTRING): + length, consumed = decode_uint32_vlq(data, offset) + return offset + consumed + length + elif datatype == DataType.STRUCT: + count, consumed = decode_uint32_vlq(data, offset) + offset += consumed + for _ in range(count): + if offset + 2 > len(data): + break + sub_flags = data[offset] + sub_type = data[offset + 1] + offset += 2 + offset = self._skip_typed_value(data, offset, sub_type, sub_flags) + return offset + else: + return offset + + async def _setup_session(self) -> None: + """Send SetMultiVariables to echo ServerSessionVersion back to the PLC.""" + if self._server_session_version is None: + return + + seq_num = self._next_sequence_number() + + request = struct.pack( + ">BHHHHIB", + Opcode.REQUEST, + 0x0000, + FunctionCode.SET_MULTI_VARIABLES, + 0x0000, + seq_num, + self._session_id, + 0x36, + ) + + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(ObjectId.SERVER_SESSION_VERSION) + payload += encode_uint32_vlq(1) + payload += bytes([0x00, DataType.UDINT]) + payload += encode_uint32_vlq(self._server_session_version) + payload += bytes([0x00]) + payload += encode_object_qualifier() + payload += struct.pack(">I", 0) + + request += bytes(payload) + + frame = encode_header(self._protocol_version, len(request)) + request + frame += struct.pack(">BBH", 0x72, self._protocol_version, 0x0000) + await self._send_cotp_dt(frame) + + response_data = await self._recv_cotp_dt() + version, data_length, consumed = decode_header(response_data) + response = response_data[consumed : consumed + data_length] + + if len(response) < 14: + raise RuntimeError("SetupSession response too short") + + resp_payload = response[14:] + if len(resp_payload) >= 1: + return_value, _ = decode_uint64_vlq(resp_payload, 0) + if return_value != 0: + logger.warning(f"SetupSession: PLC returned error {return_value}") + else: + logger.info("Session setup completed successfully") + async def _delete_session(self) -> None: """Send DeleteObject to close the session.""" seq_num = self._next_sequence_number() diff --git a/snap7/s7commplus/client.py b/snap7/s7commplus/client.py index 44112c9c..341c879f 100644 --- a/snap7/s7commplus/client.py +++ b/snap7/s7commplus/client.py @@ -14,7 +14,7 @@ the client transparently falls back to the legacy S7 protocol for data block read/write operations. -Status: V1 and V2 connections are functional. V3/TLS authentication planned. +Status: V1, V2, and V3 (including TLS) connections are functional. Reference: thomas-v2/S7CommPlusDriver (C#, LGPL-3.0) """ @@ -24,7 +24,7 @@ from typing import Any, Optional from .connection import S7CommPlusConnection -from .protocol import FunctionCode, Ids +from .protocol import FunctionCode, Ids, ProtocolVersion from .vlq import encode_uint32_vlq, decode_uint32_vlq, decode_uint64_vlq from .codec import ( encode_item_address, @@ -148,10 +148,12 @@ def connect( logger.info("Performing PLC legitimation (password authentication)") self._connection.authenticate(password) - # Probe S7CommPlus data operations with a minimal request - if not self._probe_s7commplus_data(): - logger.info("S7CommPlus data operations not supported, falling back to legacy S7 protocol") - self._setup_legacy_fallback() + # Probe S7CommPlus data operations with a minimal request. + # Skip probe for V2+ with TLS: TLS handshake confirms S7CommPlus support. + if self._connection.protocol_version < ProtocolVersion.V2: + if not self._probe_s7commplus_data(): + logger.info("S7CommPlus data operations not supported, falling back to legacy S7 protocol") + self._setup_legacy_fallback() def _probe_s7commplus_data(self) -> bool: """Test if the PLC supports S7CommPlus data operations. diff --git a/snap7/s7commplus/connection.py b/snap7/s7commplus/connection.py index a60b44a0..ae44506d 100644 --- a/snap7/s7commplus/connection.py +++ b/snap7/s7commplus/connection.py @@ -85,8 +85,7 @@ class S7CommPlusConnection: - Version-appropriate authentication (V1/V2/V3/TLS) - Frame send/receive (TLS-encrypted when using V17+ firmware) - Currently implements V1 authentication. V2/V3/TLS authentication - layers are planned for future development. + Supports V1, V2, and V3 (including TLS) authentication. """ def __init__( @@ -202,21 +201,19 @@ def connect( logger.warning("PLC did not provide ServerSessionVersion - session setup incomplete") # Step 6: Version-specific post-setup - if self._protocol_version >= ProtocolVersion.V3: - if not use_tls: - logger.warning( - "PLC reports V3 protocol but TLS is not enabled. Connection may not work without use_tls=True." - ) - elif self._protocol_version == ProtocolVersion.V2: + if self._protocol_version >= ProtocolVersion.V2: if not self._tls_active: from ..error import S7ConnectionError - raise S7ConnectionError("PLC reports V2 protocol but TLS is not active. V2 requires TLS. Use use_tls=True.") + raise S7ConnectionError( + f"PLC reports V{self._protocol_version} protocol but TLS is not active. " + "V2/V3 requires TLS. Use use_tls=True." + ) # Enable IntegrityId tracking for V2+ self._with_integrity_id = True self._integrity_id_read = 0 self._integrity_id_write = 0 - logger.info("V2 IntegrityId tracking enabled") + logger.info(f"V{self._protocol_version} IntegrityId tracking enabled") # V1: No further authentication needed after CreateObject self._connected = True @@ -251,7 +248,7 @@ def authenticate(self, password: str, username: str = "") -> None: raise S7ConnectionError("Not connected") - if not self._tls_active or self._oms_secret is None: + if not self._tls_active: from ..error import S7ConnectionError raise S7ConnectionError("Legitimation requires TLS. Connect with use_tls=True.") @@ -263,11 +260,11 @@ def authenticate(self, password: str, username: str = "") -> None: # Step 2: Build response (auto-detect legacy vs new) from .legitimation import build_legacy_response, build_new_response - if username: + if username and self._oms_secret is not None: # New-style auth with username always uses AES-256-CBC response_data = build_new_response(password, challenge, self._oms_secret, username) self._send_legitimation_new(response_data) - else: + elif self._oms_secret is not None: # Try new-style first, fall back to legacy SHA-1 XOR try: response_data = build_new_response(password, challenge, self._oms_secret, "") @@ -276,6 +273,12 @@ def authenticate(self, password: str, username: str = "") -> None: # cryptography package not available, use legacy response_data = build_legacy_response(password, challenge) self._send_legitimation_legacy(response_data) + else: + # No OMS secret available (export_keying_material not supported), + # fall back to legacy SHA-1 XOR authentication + logger.info("OMS secret not available, using legacy legitimation") + response_data = build_legacy_response(password, challenge) + self._send_legitimation_legacy(response_data) logger.info("PLC legitimation completed successfully") @@ -1013,7 +1016,13 @@ def _setup_ssl_context( """ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.minimum_version = ssl.TLSVersion.TLSv1_3 - ctx.set_ciphers("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") + # TLS 1.3 cipher suites are auto-negotiated on modern OpenSSL; + # set_ciphers() only controls TLS 1.2 and below. We try to set + # preferred ciphers but ignore failures (e.g. OpenSSL 3.x). + try: + ctx.set_ciphers("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") + except ssl.SSLError: + pass if cert_path and key_path: ctx.load_cert_chain(cert_path, key_path) diff --git a/snap7/s7commplus/server.py b/snap7/s7commplus/server.py index 2af0d769..6719bf71 100644 --- a/snap7/s7commplus/server.py +++ b/snap7/s7commplus/server.py @@ -10,7 +10,7 @@ - Internal PLC memory model with thread-safe access - V2 protocol emulation with TLS and IntegrityId tracking -Supports both V1 (no TLS) and V2 (TLS + IntegrityId) emulation. +Supports V1 (no TLS), V2 (TLS + IntegrityId), and V3 (TLS + IntegrityId) emulation. Usage:: @@ -186,17 +186,21 @@ class S7CommPlusServer: Emulates an S7-1200/1500 PLC with: - Internal data block storage with named variables - - S7CommPlus protocol handling (V1 and V2) - - V2 TLS support with IntegrityId tracking + - S7CommPlus protocol handling (V1, V2, V3) + - V2/V3 TLS support with IntegrityId tracking + - Legitimation (password authentication) emulation - Multi-client support (threaded) - CPU state management """ - def __init__(self, protocol_version: int = ProtocolVersion.V1) -> None: + def __init__(self, protocol_version: int = ProtocolVersion.V1, password: Optional[str] = None) -> None: self._data_blocks: dict[int, DataBlock] = {} self._cpu_state = CPUState.RUN self._protocol_version = protocol_version + self._password = password self._next_session_id = 1 + # Per-client authentication state (session_id -> authenticated) + self._authenticated_sessions: dict[int, bool] = {} self._server_socket: Optional[socket.socket] = None self._server_thread: Optional[threading.Thread] = None @@ -205,7 +209,7 @@ def __init__(self, protocol_version: int = ProtocolVersion.V1) -> None: self._lock = threading.Lock() self._event_callback: Optional[Callable[..., None]] = None - # TLS configuration (V2) + # TLS configuration (V2/V3) self._ssl_context: Optional[ssl.SSLContext] = None self._use_tls: bool = False @@ -364,8 +368,12 @@ def _handle_client(self, client_sock: socket.socket, address: tuple[str, int]) - session_id = 0 tls_activated = False # Per-client IntegrityId tracking (V2+) + # IntegrityId tracking starts AFTER the session setup (SetMultiVariables + # echoing ServerSessionVersion). The first SetMultiVariables after + # CreateObject is the session setup and doesn't include IntegrityId. integrity_id_read = 0 integrity_id_write = 0 + integrity_id_active = False while self._running: try: @@ -375,7 +383,9 @@ def _handle_client(self, client_sock: socket.socket, address: tuple[str, int]) - break # Process the S7CommPlus request - response = self._process_request(data, session_id, integrity_id_read, integrity_id_write) + response = self._process_request( + data, session_id, integrity_id_read, integrity_id_write, integrity_id_active + ) if response is not None: # Check if session ID was assigned @@ -412,7 +422,12 @@ def _handle_client(self, client_sock: socket.socket, address: tuple[str, int]) - payload = data[hdr_consumed:] if len(payload) >= 14: func_code = struct.unpack_from(">H", payload, 3)[0] - if func_code in READ_FUNCTION_CODES: + if not integrity_id_active: + # First SetMultiVariables after CreateObject is session setup + # (no IntegrityId). Activate tracking after it. + if func_code == FunctionCode.SET_MULTI_VARIABLES: + integrity_id_active = True + elif func_code in READ_FUNCTION_CODES: integrity_id_read = (integrity_id_read + 1) & 0xFFFFFFFF elif func_code not in ( FunctionCode.INIT_SSL, @@ -522,6 +537,7 @@ def _process_request( session_id: int, integrity_id_read: int = 0, integrity_id_write: int = 0, + integrity_id_active: bool = False, ) -> Optional[bytes]: """Process an S7CommPlus request and return a response.""" if len(data) < 4: @@ -547,11 +563,14 @@ def _process_request( seq_num = struct.unpack_from(">H", payload, 7)[0] req_session_id = struct.unpack_from(">I", payload, 9)[0] - # For V2+, skip IntegrityId after the 14-byte header + # For V2+, skip IntegrityId after the 14-byte header. + # IntegrityId is only present after session setup is complete + # (integrity_id_active=True). The first SetMultiVariables after + # CreateObject is the session setup and doesn't include IntegrityId. request_offset = 14 if ( - self._protocol_version >= ProtocolVersion.V2 - and session_id != 0 + integrity_id_active + and self._protocol_version >= ProtocolVersion.V2 and function_code not in (FunctionCode.INIT_SSL, FunctionCode.CREATE_OBJECT) ): if request_offset < len(payload): @@ -566,14 +585,41 @@ def _process_request( return self._handle_create_object(seq_num, request_data) elif function_code == FunctionCode.DELETE_OBJECT: return self._handle_delete_object(seq_num, req_session_id) + elif function_code == FunctionCode.GET_VAR_SUBSTREAMED: + response = self._handle_get_var_substreamed(seq_num, req_session_id, request_data) + elif function_code == FunctionCode.SET_VARIABLE: + response = self._handle_set_variable(seq_num, req_session_id, request_data) elif function_code == FunctionCode.EXPLORE: - return self._handle_explore(seq_num, req_session_id, request_data) + if not self._check_authenticated(req_session_id): + response = self._build_error_response(seq_num, req_session_id, function_code) + else: + response = self._handle_explore(seq_num, req_session_id, request_data) elif function_code == FunctionCode.GET_MULTI_VARIABLES: - return self._handle_get_multi_variables(seq_num, req_session_id, request_data) + if not self._check_authenticated(req_session_id): + response = self._build_error_response(seq_num, req_session_id, function_code) + else: + response = self._handle_get_multi_variables(seq_num, req_session_id, request_data) elif function_code == FunctionCode.SET_MULTI_VARIABLES: - return self._handle_set_multi_variables(seq_num, req_session_id, request_data) + # Auth check is inside the handler: session setup must bypass auth + response = self._handle_set_multi_variables( + seq_num, req_session_id, request_data, self._check_authenticated(req_session_id) + ) else: - return self._build_error_response(seq_num, req_session_id, function_code) + response = self._build_error_response(seq_num, req_session_id, function_code) + + # For V2+, insert IntegrityId right after the 14-byte response header. + # The client expects IntegrityId at offset 14 in the response, mirroring + # the request format. Only insert when IntegrityId tracking is active. + if ( + integrity_id_active + and self._protocol_version >= ProtocolVersion.V2 + and function_code not in (FunctionCode.INIT_SSL, FunctionCode.CREATE_OBJECT) + and response is not None + and len(response) >= 14 + ): + response = response[:14] + encode_uint32_vlq(0) + response[14:] + + return response def _build_response_header( self, @@ -798,16 +844,30 @@ def _handle_get_multi_variables(self, seq_num: int, session_id: int, request_dat # Terminate error list response += encode_uint32_vlq(0) - # IntegrityId - response += encode_uint32_vlq(0) - return bytes(response) - def _handle_set_multi_variables(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: + def _handle_set_multi_variables( + self, seq_num: int, session_id: int, request_data: bytes, is_authenticated: bool = True + ) -> bytes: """Handle SetMultiVariables -- write variables to data blocks. + Also handles session setup (SetMultiVariables echoing ServerSessionVersion). + Session setup is detected by InObjectId matching the session_id. + Session setup always bypasses auth; data writes require authentication. + Reference: thomas-v2/S7CommPlusDriver/Core/SetMultiVariablesRequest.cs """ + # Check if this is a session setup write (InObjectId = session_id) + if len(request_data) >= 4: + in_object_id = struct.unpack_from(">I", request_data, 0)[0] + if in_object_id == session_id and session_id != 0: + # Session setup: just acknowledge success (no auth required) + return self._build_set_multi_response(seq_num, session_id, []) + + # For data writes, require authentication + if not is_authenticated: + return self._build_error_response(seq_num, session_id, FunctionCode.SET_MULTI_VARIABLES) + response = bytearray() response += struct.pack( ">BHHHHIB", @@ -843,8 +903,98 @@ def _handle_set_multi_variables(self, seq_num: int, session_id: int, request_dat # Terminate error list response += encode_uint32_vlq(0) - # IntegrityId - response += encode_uint32_vlq(0) + return bytes(response) + + def _build_set_multi_response( + self, seq_num: int, session_id: int, errors: list[tuple[int, int]] + ) -> bytes: + """Build a SetMultiVariables response.""" + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.SET_MULTI_VARIABLES, + 0x0000, + seq_num, + session_id, + 0x00, + ) + response += encode_uint64_vlq(0) # ReturnValue: success + for err_item, err_code in errors: + response += encode_uint32_vlq(err_item) + response += encode_uint64_vlq(err_code) + response += encode_uint32_vlq(0) # Terminate error list + return bytes(response) + + def _check_authenticated(self, session_id: int) -> bool: + """Check if a session is authenticated (or no password is required).""" + if self._password is None: + return True + return self._authenticated_sessions.get(session_id, False) + + def _handle_get_var_substreamed(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: + """Handle GetVarSubStreamed -- used for legitimation challenge request. + + Returns a 20-byte random challenge when the client requests + ServerSessionRequest (address 303). + """ + import os + + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.GET_VAR_SUBSTREAMED, + 0x0000, + seq_num, + session_id, + 0x00, + ) + + # ReturnValue: success + response += encode_uint64_vlq(0) + + # Value: 20-byte challenge as BLOB + challenge = os.urandom(20) + response += bytes([0x00, DataType.BLOB]) + response += encode_uint32_vlq(len(challenge)) + response += challenge + + # Trailing padding + response += struct.pack(">I", 0) + + return bytes(response) + + def _handle_set_variable(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: + """Handle SetVariable -- used for legitimation response. + + Accepts the client's authentication response and marks the session + as authenticated. In this emulator, any response is accepted (we + don't verify the actual crypto). + """ + # Mark session as authenticated + self._authenticated_sessions[session_id] = True + logger.debug(f"Session {session_id} authenticated via SetVariable") + + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.SET_VARIABLE, + 0x0000, + seq_num, + session_id, + 0x00, + ) + + # ReturnValue: success + response += encode_uint64_vlq(0) + + # Trailing padding + response += struct.pack(">I", 0) return bytes(response) diff --git a/tests/test_s7commplus_tls.py b/tests/test_s7commplus_tls.py new file mode 100644 index 00000000..873d2a9f --- /dev/null +++ b/tests/test_s7commplus_tls.py @@ -0,0 +1,417 @@ +"""Integration tests for S7CommPlus V2/V3 with TLS. + +Tests the complete TLS connection flow including: +- V2 server + client with TLS and IntegrityId tracking +- V3 server + client with TLS +- Legitimation (password authentication) flow +- Async client with TLS +- Error handling for missing TLS + +Requires the `cryptography` package for self-signed certificate generation. +""" + +import struct +import tempfile +import time +from collections.abc import Generator +from pathlib import Path + +import pytest + +from snap7.s7commplus.server import S7CommPlusServer +from snap7.s7commplus.client import S7CommPlusClient +from snap7.s7commplus.protocol import ProtocolVersion + +try: + import ipaddress + + from cryptography import x509 + from cryptography.x509.oid import NameOID + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + import datetime + + _has_cryptography = True +except ImportError: + _has_cryptography = False + +pytestmark = pytest.mark.skipif(not _has_cryptography, reason="requires cryptography package") + +# Use high ports to avoid conflicts +V2_TEST_PORT = 11130 +V3_TEST_PORT = 11131 +V2_AUTH_PORT = 11132 +V3_AUTH_PORT = 11133 + + +@pytest.fixture(scope="module") +def tls_certs() -> Generator[dict[str, str], None, None]: + """Generate self-signed TLS certificates for testing.""" + with tempfile.TemporaryDirectory() as tmpdir: + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), + ]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)) + .add_extension( + x509.SubjectAlternativeName([ + x509.DNSName("localhost"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + ]), + critical=False, + ) + .sign(key, hashes.SHA256()) + ) + + cert_path = str(Path(tmpdir) / "server.crt") + key_path = str(Path(tmpdir) / "server.key") + + with open(cert_path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + + with open(key_path, "wb") as f: + f.write( + key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.TraditionalOpenSSL, + serialization.NoEncryption(), + ) + ) + + yield {"cert": cert_path, "key": key_path} + + +def _make_server( + protocol_version: int, + password: str | None = None, +) -> S7CommPlusServer: + """Create and configure an S7CommPlus server with test data blocks.""" + srv = S7CommPlusServer(protocol_version=protocol_version, password=password) + + srv.register_db( + 1, + { + "temperature": ("Real", 0), + "pressure": ("Real", 4), + "running": ("Bool", 8), + "count": ("DInt", 10), + }, + ) + srv.register_raw_db(2, bytearray(256)) + + # Pre-populate DB1 + db1 = srv.get_db(1) + assert db1 is not None + struct.pack_into(">f", db1.data, 0, 23.5) + struct.pack_into(">f", db1.data, 4, 1.013) + db1.data[8] = 1 + struct.pack_into(">i", db1.data, 10, 42) + + return srv + + +@pytest.fixture() +def v2_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: + """V2 server with TLS.""" + srv = _make_server(ProtocolVersion.V2) + srv.start(port=V2_TEST_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) + time.sleep(0.1) + yield srv + srv.stop() + + +@pytest.fixture() +def v3_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: + """V3 server with TLS.""" + srv = _make_server(ProtocolVersion.V3) + srv.start(port=V3_TEST_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) + time.sleep(0.1) + yield srv + srv.stop() + + +@pytest.fixture() +def v2_auth_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: + """V2 server with TLS and password authentication.""" + srv = _make_server(ProtocolVersion.V2, password="secret123") + srv.start(port=V2_AUTH_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) + time.sleep(0.1) + yield srv + srv.stop() + + +@pytest.fixture() +def v3_auth_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: + """V3 server with TLS and password authentication.""" + srv = _make_server(ProtocolVersion.V3, password="secret123") + srv.start(port=V3_AUTH_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) + time.sleep(0.1) + yield srv + srv.stop() + + +class TestV2TLS: + """Test V2 protocol with TLS.""" + + def test_connect_disconnect(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + assert client.connected + assert client.session_id != 0 + assert client.protocol_version == ProtocolVersion.V2 + client.disconnect() + assert not client.connected + + def test_read_real(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + try: + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + finally: + client.disconnect() + + def test_write_and_read_back(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + try: + client.db_write(1, 0, struct.pack(">f", 99.9)) + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 99.9) < 0.1 + finally: + client.disconnect() + + def test_multi_read(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + try: + results = client.db_read_multi([ + (1, 0, 4), + (1, 4, 4), + (2, 0, 4), + ]) + assert len(results) == 3 + temp = struct.unpack(">f", results[0])[0] + assert abs(temp - 23.5) < 0.001 + finally: + client.disconnect() + + def test_explore(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + try: + response = client.explore() + assert len(response) > 0 + finally: + client.disconnect() + + +class TestV3TLS: + """Test V3 protocol with TLS.""" + + def test_connect_disconnect(self, v3_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + assert client.connected + assert client.session_id != 0 + assert client.protocol_version == ProtocolVersion.V3 + client.disconnect() + assert not client.connected + + def test_read_real(self, v3_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + try: + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + finally: + client.disconnect() + + def test_write_and_read_back(self, v3_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + try: + client.db_write(1, 0, struct.pack(">f", 88.8)) + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 88.8) < 0.1 + finally: + client.disconnect() + + def test_multi_read(self, v3_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + try: + results = client.db_read_multi([ + (1, 0, 4), + (1, 10, 4), + ]) + assert len(results) == 2 + finally: + client.disconnect() + + def test_data_persists_across_clients(self, v3_server: S7CommPlusServer) -> None: + c1 = S7CommPlusClient() + c1.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + c1.db_write(2, 0, b"\xca\xfe\xba\xbe") + c1.disconnect() + + c2 = S7CommPlusClient() + c2.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + data = c2.db_read(2, 0, 4) + c2.disconnect() + + assert data == b"\xca\xfe\xba\xbe" + + +class TestV2WithoutTLS: + """Test that V2/V3 connections fail without TLS.""" + + def test_v2_server_requires_tls_on_client(self, tls_certs: dict[str, str]) -> None: + """V2 server reports V2 version; client should raise if no TLS.""" + # Start a V2 server WITHOUT TLS (so client can connect but gets V2 version) + srv = _make_server(ProtocolVersion.V2) + srv.start(port=V2_TEST_PORT + 10) + time.sleep(0.1) + try: + client = S7CommPlusClient() + with pytest.raises(Exception): + # The server reports V2 but client didn't use TLS + client.connect("127.0.0.1", port=V2_TEST_PORT + 10) + finally: + srv.stop() + + +class TestLegitimation: + """Test password authentication (legitimation).""" + + def test_v2_connect_with_password(self, v2_auth_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True, password="secret123") + assert client.connected + try: + # Should be able to read after authentication + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + finally: + client.disconnect() + + def test_v3_connect_with_password(self, v3_auth_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_AUTH_PORT, use_tls=True, password="secret123") + assert client.connected + try: + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + finally: + client.disconnect() + + def test_v2_without_password_on_protected_server(self, v2_auth_server: S7CommPlusServer) -> None: + """Connecting without password to a password-protected server should fail data ops.""" + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True) + try: + # Server should reject data operations without authentication + with pytest.raises(Exception): + client.db_read(1, 0, 4) + finally: + client.disconnect() + + def test_v2_write_with_password(self, v2_auth_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True, password="secret123") + try: + client.db_write(1, 0, struct.pack(">f", 55.5)) + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 55.5) < 0.1 + finally: + client.disconnect() + + +@pytest.mark.asyncio +class TestAsyncV2TLS: + """Test async client with V2 TLS.""" + + async def test_connect_disconnect(self, v2_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + client = S7CommPlusAsyncClient() + await client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + assert client.connected + assert client.session_id != 0 + assert client.protocol_version == ProtocolVersion.V2 + assert client.tls_active + await client.disconnect() + assert not client.connected + + async def test_read_real(self, v2_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + data = await client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + + async def test_write_and_read_back(self, v2_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + await client.db_write(1, 0, struct.pack(">f", 77.7)) + data = await client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 77.7) < 0.1 + + +@pytest.mark.asyncio +class TestAsyncV3TLS: + """Test async client with V3 TLS.""" + + async def test_connect_disconnect(self, v3_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + client = S7CommPlusAsyncClient() + await client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + assert client.connected + assert client.protocol_version == ProtocolVersion.V3 + assert client.tls_active + await client.disconnect() + + async def test_read_write(self, v3_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + await client.db_write(2, 0, b"\xde\xad") + data = await client.db_read(2, 0, 2) + assert data == b"\xde\xad" + + +@pytest.mark.asyncio +class TestAsyncLegitimation: + """Test async client with password authentication.""" + + async def test_v2_connect_with_password(self, v2_auth_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True, password="secret123") + assert client.connected + data = await client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001