diff --git a/pyproject.toml b/pyproject.toml index b865de58..ca23f81a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,9 +40,10 @@ discovery = ["pnio-dcp"] [tool.setuptools.package-data] snap7 = ["py.typed"] +s7 = ["py.typed"] [tool.setuptools.packages.find] -include = ["snap7*"] +include = ["snap7*", "s7*"] [project.scripts] snap7-server = "snap7.server:mainloop" diff --git a/s7/__init__.py b/s7/__init__.py new file mode 100644 index 00000000..1cfaa189 --- /dev/null +++ b/s7/__init__.py @@ -0,0 +1,35 @@ +"""Unified S7 communication library. + +Provides protocol-agnostic access to Siemens S7 PLCs with automatic +protocol discovery (S7CommPlus vs legacy S7). + +Usage:: + + from s7 import Client + + client = Client() + client.connect("192.168.1.10", 0, 1) + data = client.db_read(1, 0, 4) +""" + +from .client import Client +from .async_client import AsyncClient +from .server import Server +from ._protocol import Protocol + +from snap7.type import Area, Block, WordLen, SrvEvent, SrvArea +from snap7.util.db import Row, DB + +__all__ = [ + "Client", + "AsyncClient", + "Server", + "Protocol", + "Area", + "Block", + "WordLen", + "SrvEvent", + "SrvArea", + "Row", + "DB", +] diff --git a/s7/_protocol.py b/s7/_protocol.py new file mode 100644 index 00000000..f1ca0328 --- /dev/null +++ b/s7/_protocol.py @@ -0,0 +1,19 @@ +"""Protocol enum for unified S7 client.""" + +from enum import Enum + + +class Protocol(Enum): + """S7 communication protocol selection. + + Used to control protocol auto-discovery in the unified client. + + Attributes: + AUTO: Try S7CommPlus first, fall back to legacy S7 if unsupported. + LEGACY: Use legacy S7 protocol only (S7-300/400, basic S7-1200/1500). + S7COMMPLUS: Use S7CommPlus protocol only (S7-1200/1500 with full access). + """ + + AUTO = "auto" + LEGACY = "legacy" + S7COMMPLUS = "s7commplus" diff --git a/s7/async_client.py b/s7/async_client.py new file mode 100644 index 00000000..df558f1f --- /dev/null +++ b/s7/async_client.py @@ -0,0 +1,246 @@ +"""Unified async S7 client with protocol auto-discovery. + +Provides a single async client that automatically selects the best protocol +(S7CommPlus or legacy S7) for communicating with Siemens S7 PLCs. + +Usage:: + + from s7 import AsyncClient + + async with AsyncClient() as client: + await client.connect("192.168.1.10", 0, 1) + data = await client.db_read(1, 0, 4) +""" + +import logging +from typing import Any, Optional + +from snap7.async_client import AsyncClient as LegacyAsyncClient +from snap7.s7commplus.async_client import S7CommPlusAsyncClient + +from ._protocol import Protocol + +logger = logging.getLogger(__name__) + + +class AsyncClient: + """Unified async S7 client with protocol auto-discovery. + + Async counterpart of :class:`s7.Client`. Automatically selects the + best protocol for the target PLC using asyncio for non-blocking I/O. + + Methods not explicitly defined are delegated to the underlying + legacy async client via ``__getattr__``. + + Examples:: + + from s7 import AsyncClient + + async with AsyncClient() as client: + await client.connect("192.168.1.10", 0, 1) + data = await client.db_read(1, 0, 4) + print(client.protocol) + """ + + def __init__(self) -> None: + self._legacy: Optional[LegacyAsyncClient] = None + self._plus: Optional[S7CommPlusAsyncClient] = None + self._protocol: Protocol = Protocol.AUTO + self._host: str = "" + self._port: int = 102 + self._rack: int = 0 + self._slot: int = 1 + + @property + def protocol(self) -> Protocol: + """The protocol currently in use for DB operations.""" + return self._protocol + + @property + def connected(self) -> bool: + """Whether the client is connected to a PLC.""" + if self._legacy is not None and self._legacy.connected: + return True + if self._plus is not None and self._plus.connected: + return True + return False + + async def connect( + self, + address: str, + rack: int = 0, + slot: int = 1, + tcp_port: int = 102, + *, + protocol: Protocol = Protocol.AUTO, + ) -> "AsyncClient": + """Connect to an S7 PLC. + + Args: + address: PLC IP address or hostname. + rack: PLC rack number. + slot: PLC slot number. + tcp_port: TCP port (default 102). + protocol: Protocol selection. AUTO tries S7CommPlus first, + then falls back to legacy S7. + + Returns: + self, for method chaining. + """ + self._host = address + self._port = tcp_port + self._rack = rack + self._slot = slot + + if protocol in (Protocol.AUTO, Protocol.S7COMMPLUS): + if await self._try_s7commplus(address, tcp_port, rack, slot): + self._protocol = Protocol.S7COMMPLUS + logger.info(f"Async connected to {address}:{tcp_port} using S7CommPlus") + else: + if protocol == Protocol.S7COMMPLUS: + raise RuntimeError( + f"S7CommPlus connection to {address}:{tcp_port} failed and protocol=S7COMMPLUS was explicitly requested" + ) + self._protocol = Protocol.LEGACY + logger.info(f"S7CommPlus not available, using legacy S7 for {address}:{tcp_port}") + else: + self._protocol = Protocol.LEGACY + + # Always connect legacy client + self._legacy = LegacyAsyncClient() + await self._legacy.connect(address, rack, slot, tcp_port) + logger.info(f"Async legacy S7 connected to {address}:{tcp_port}") + + return self + + async def _try_s7commplus( + self, + address: str, + tcp_port: int, + rack: int, + slot: int, + ) -> bool: + """Attempt async S7CommPlus connection and probe data operations.""" + plus = S7CommPlusAsyncClient() + try: + await plus.connect(host=address, port=tcp_port, rack=rack, slot=slot) + except Exception as e: + logger.debug(f"Async S7CommPlus connection failed: {e}") + return False + + if plus.using_legacy_fallback: + logger.debug("S7CommPlus connected but data ops not supported, disconnecting") + try: + await plus.disconnect() + except Exception: + pass + return False + + self._plus = plus + return True + + async def disconnect(self) -> int: + """Disconnect from the PLC. + + Returns: + 0 on success. + """ + if self._plus is not None: + try: + await self._plus.disconnect() + except Exception: + pass + self._plus = None + + if self._legacy is not None: + try: + await self._legacy.disconnect() + except Exception: + pass + self._legacy = None + + self._protocol = Protocol.AUTO + return 0 + + async def db_read(self, db_number: int, start: int, size: int) -> bytearray: + """Read data from a data block. + + Args: + db_number: DB number to read from. + start: Start byte offset. + size: Number of bytes to read. + + Returns: + Data read from the DB. + """ + if self._protocol == Protocol.S7COMMPLUS and self._plus is not None: + return bytearray(await self._plus.db_read(db_number, start, size)) + if self._legacy is not None: + return await self._legacy.db_read(db_number, start, size) + raise RuntimeError("Not connected") + + async def db_write(self, db_number: int, start: int, data: bytearray | bytes) -> int: + """Write data to a data block. + + Args: + db_number: DB number to write to. + start: Start byte offset. + data: Data to write. + + Returns: + 0 on success. + """ + if self._protocol == Protocol.S7COMMPLUS and self._plus is not None: + await self._plus.db_write(db_number, start, bytes(data)) + return 0 + if self._legacy is not None: + return await self._legacy.db_write(db_number, start, bytearray(data)) + raise RuntimeError("Not connected") + + async def db_read_multi(self, items: list[tuple[int, int, int]]) -> list[bytearray]: + """Read multiple data block regions. + + Args: + items: List of (db_number, start_offset, size) tuples. + + Returns: + List of data for each item. + """ + if self._protocol == Protocol.S7COMMPLUS and self._plus is not None: + return [bytearray(r) for r in await self._plus.db_read_multi(items)] + if self._legacy is not None: + results = [] + for db, start, size in items: + results.append(await self._legacy.db_read(db, start, size)) + return results + raise RuntimeError("Not connected") + + async def explore(self) -> bytes: + """Browse the PLC object tree (S7CommPlus only). + + Returns: + Raw response payload. + """ + if self._plus is not None: + return await self._plus.explore() + raise RuntimeError("explore() requires S7CommPlus protocol") + + def __getattr__(self, name: str) -> Any: + """Delegate unknown attributes to the legacy async client.""" + if name.startswith("_"): + raise AttributeError(name) + legacy = self.__dict__.get("_legacy") + if legacy is not None: + return getattr(legacy, name) + raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") + + async def __aenter__(self) -> "AsyncClient": + return self + + async def __aexit__(self, *args: Any) -> None: + await self.disconnect() + + def __repr__(self) -> str: + status = "connected" if self.connected else "disconnected" + proto = self._protocol.value if self.connected else "none" + return f"" diff --git a/s7/client.py b/s7/client.py new file mode 100644 index 00000000..8e143895 --- /dev/null +++ b/s7/client.py @@ -0,0 +1,314 @@ +"""Unified S7 client with protocol auto-discovery. + +Provides a single client that automatically selects the best protocol +(S7CommPlus or legacy S7) for communicating with Siemens S7 PLCs. + +Usage:: + + from s7 import Client + + client = Client() + client.connect("192.168.1.10", 0, 1) + data = client.db_read(1, 0, 4) + client.disconnect() +""" + +import logging +from typing import Any, Optional + +from snap7.client import Client as LegacyClient +from snap7.s7commplus.client import S7CommPlusClient + +from ._protocol import Protocol + +logger = logging.getLogger(__name__) + + +class Client: + """Unified S7 client with protocol auto-discovery. + + Automatically selects the best protocol for the target PLC: + + - **S7CommPlus**: For S7-1200/1500 PLCs with full engineering access + - **Legacy S7**: For S7-300/400 and S7-1200/1500 without S7CommPlus support + + When ``protocol=Protocol.AUTO`` (default), the client tries S7CommPlus + first and falls back to legacy S7 transparently. + + Exposes the full legacy S7 client API. Methods not explicitly defined + (block operations, PLC control, memory areas, etc.) are delegated to + the underlying legacy client via ``__getattr__``. + + Examples:: + + from s7 import Client + + # Auto-discover protocol + client = Client() + client.connect("192.168.1.10", 0, 1) + data = client.db_read(1, 0, 4) + print(client.protocol) # Protocol.LEGACY or Protocol.S7COMMPLUS + + # Force legacy protocol + client = Client() + client.connect("192.168.1.10", 0, 1, protocol=Protocol.LEGACY) + + # S7CommPlus with TLS + client = Client() + client.connect("192.168.1.10", 0, 1, use_tls=True, password="secret") + """ + + def __init__(self) -> None: + self._legacy: Optional[LegacyClient] = None + self._plus: Optional[S7CommPlusClient] = None + self._protocol: Protocol = Protocol.AUTO + self._host: str = "" + self._port: int = 102 + self._rack: int = 0 + self._slot: int = 1 + + @property + def protocol(self) -> Protocol: + """The protocol currently in use for DB operations.""" + return self._protocol + + @property + def connected(self) -> bool: + """Whether the client is connected to a PLC.""" + if self._legacy is not None and self._legacy.connected: + return True + if self._plus is not None and self._plus.connected: + return True + return False + + def connect( + self, + address: str, + rack: int = 0, + slot: int = 1, + tcp_port: int = 102, + *, + protocol: Protocol = Protocol.AUTO, + use_tls: bool = False, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + tls_ca: Optional[str] = None, + password: Optional[str] = None, + ) -> "Client": + """Connect to an S7 PLC. + + Args: + address: PLC IP address or hostname. + rack: PLC rack number. + slot: PLC slot number. + tcp_port: TCP port (default 102). + protocol: Protocol selection. AUTO tries S7CommPlus first, + then falls back to legacy S7. + use_tls: Enable TLS for S7CommPlus V2+ connections. + tls_cert: Path to client TLS certificate (PEM). + tls_key: Path to client private key (PEM). + tls_ca: Path to CA certificate for PLC verification (PEM). + password: PLC password for S7CommPlus legitimation. + + Returns: + self, for method chaining. + """ + self._host = address + self._port = tcp_port + self._rack = rack + self._slot = slot + + if protocol in (Protocol.AUTO, Protocol.S7COMMPLUS): + if self._try_s7commplus( + address, + tcp_port, + rack, + slot, + use_tls=use_tls, + tls_cert=tls_cert, + tls_key=tls_key, + tls_ca=tls_ca, + password=password, + ): + self._protocol = Protocol.S7COMMPLUS + logger.info(f"Connected to {address}:{tcp_port} using S7CommPlus") + else: + if protocol == Protocol.S7COMMPLUS: + raise RuntimeError( + f"S7CommPlus connection to {address}:{tcp_port} failed and protocol=S7COMMPLUS was explicitly requested" + ) + self._protocol = Protocol.LEGACY + logger.info(f"S7CommPlus not available, using legacy S7 for {address}:{tcp_port}") + else: + self._protocol = Protocol.LEGACY + + # Always connect legacy client (needed for block ops, PLC control, etc.) + self._legacy = LegacyClient() + self._legacy.connect(address, rack, slot, tcp_port) + logger.info(f"Legacy S7 connected to {address}:{tcp_port}") + + return self + + def _try_s7commplus( + self, + address: str, + tcp_port: int, + rack: int, + slot: int, + *, + use_tls: bool = False, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + tls_ca: Optional[str] = None, + password: Optional[str] = None, + ) -> bool: + """Attempt S7CommPlus connection and probe data operations. + + Returns True if S7CommPlus data operations are fully supported. + On failure, cleans up the S7CommPlus connection. + """ + plus = S7CommPlusClient() + try: + plus.connect( + host=address, + port=tcp_port, + rack=rack, + slot=slot, + use_tls=use_tls, + tls_cert=tls_cert, + tls_key=tls_key, + tls_ca=tls_ca, + password=password, + ) + except Exception as e: + logger.debug(f"S7CommPlus connection failed: {e}") + return False + + # The S7CommPlus client probes data ops internally and sets + # using_legacy_fallback if they don't work. We don't want to use + # its built-in fallback (that would create a duplicate legacy + # connection), so we check and disconnect if it fell back. + if plus.using_legacy_fallback: + logger.debug("S7CommPlus connected but data ops not supported, disconnecting") + try: + plus.disconnect() + except Exception: + pass + return False + + self._plus = plus + return True + + def disconnect(self) -> int: + """Disconnect from the PLC. + + Returns: + 0 on success. + """ + if self._plus is not None: + try: + self._plus.disconnect() + except Exception: + pass + self._plus = None + + if self._legacy is not None: + try: + self._legacy.disconnect() + except Exception: + pass + self._legacy = None + + self._protocol = Protocol.AUTO + return 0 + + def db_read(self, db_number: int, start: int, size: int) -> bytearray: + """Read data from a data block. + + Uses S7CommPlus when available, otherwise legacy S7. + + Args: + db_number: DB number to read from. + start: Start byte offset. + size: Number of bytes to read. + + Returns: + Data read from the DB. + """ + if self._protocol == Protocol.S7COMMPLUS and self._plus is not None: + return bytearray(self._plus.db_read(db_number, start, size)) + if self._legacy is not None: + return self._legacy.db_read(db_number, start, size) + raise RuntimeError("Not connected") + + def db_write(self, db_number: int, start: int, data: bytearray | bytes) -> int: + """Write data to a data block. + + Uses S7CommPlus when available, otherwise legacy S7. + + Args: + db_number: DB number to write to. + start: Start byte offset. + data: Data to write. + + Returns: + 0 on success. + """ + if self._protocol == Protocol.S7COMMPLUS and self._plus is not None: + self._plus.db_write(db_number, start, bytes(data)) + return 0 + if self._legacy is not None: + return self._legacy.db_write(db_number, start, bytearray(data)) + raise RuntimeError("Not connected") + + def db_read_multi(self, items: list[tuple[int, int, int]]) -> list[bytearray]: + """Read multiple data block regions. + + Uses S7CommPlus native multi-read when available, otherwise + performs individual reads via legacy S7. + + Args: + items: List of (db_number, start_offset, size) tuples. + + Returns: + List of data for each item. + """ + if self._protocol == Protocol.S7COMMPLUS and self._plus is not None: + return [bytearray(r) for r in self._plus.db_read_multi(items)] + if self._legacy is not None: + return [self._legacy.db_read(db, start, size) for db, start, size in items] + raise RuntimeError("Not connected") + + def explore(self) -> bytes: + """Browse the PLC object tree (S7CommPlus only). + + Returns: + Raw response payload. + + Raises: + RuntimeError: If S7CommPlus is not active. + """ + if self._plus is not None: + return self._plus.explore() + raise RuntimeError("explore() requires S7CommPlus protocol") + + def __getattr__(self, name: str) -> Any: + """Delegate unknown attributes to the legacy client.""" + # Avoid infinite recursion for attributes accessed during __init__ + if name.startswith("_"): + raise AttributeError(name) + legacy = self.__dict__.get("_legacy") + if legacy is not None: + return getattr(legacy, name) + raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") + + def __enter__(self) -> "Client": + return self + + def __exit__(self, *args: Any) -> None: + self.disconnect() + + def __repr__(self) -> str: + status = "connected" if self.connected else "disconnected" + proto = self._protocol.value if self.connected else "none" + return f"" diff --git a/s7/py.typed b/s7/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/s7/server.py b/s7/server.py new file mode 100644 index 00000000..e5ec4341 --- /dev/null +++ b/s7/server.py @@ -0,0 +1,181 @@ +"""Unified S7 server combining legacy S7 and S7CommPlus protocols. + +Provides a single server that can handle both legacy S7 and S7CommPlus +client connections simultaneously. + +Usage:: + + from s7 import Server + + server = Server() + server.register_area(SrvArea.DB, 1, db_data) + server.start(tcp_port=102) +""" + +import logging +from typing import Any, Optional + +from snap7.server import Server as LegacyServer +from snap7.s7commplus.server import S7CommPlusServer, DataBlock + +logger = logging.getLogger(__name__) + + +class Server: + """Unified S7 server combining legacy S7 and S7CommPlus. + + Wraps both a legacy S7 server and an S7CommPlus server, allowing + them to run side by side. Legacy clients connect on the primary port, + S7CommPlus clients on an optional secondary port. + + Methods not explicitly defined are delegated to the underlying + legacy server via ``__getattr__``. + + Examples:: + + from s7 import Server + from snap7.type import SrvArea + + server = Server() + + # Register memory areas for legacy S7 clients + db_data = bytearray(100) + server.register_area(SrvArea.DB, 1, db_data) + + # Register data blocks for S7CommPlus clients + server.register_db(1, {"temperature": ("Real", 0)}) + + # Start both servers + server.start(tcp_port=102, s7commplus_port=11020) + + # Stop both + server.stop() + """ + + def __init__(self, log: bool = True) -> None: + """Initialize unified server. + + Args: + log: Enable event logging for the legacy server. + """ + self._legacy: LegacyServer = LegacyServer(log=log) + self._plus: S7CommPlusServer = S7CommPlusServer() + self._plus_running: bool = False + + def start( + self, + tcp_port: int = 102, + s7commplus_port: Optional[int] = None, + use_tls: bool = False, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + ) -> int: + """Start the server(s). + + Args: + tcp_port: Port for legacy S7 clients (default 102). + s7commplus_port: Optional port for S7CommPlus clients. + If None, only the legacy server is started. + use_tls: Enable TLS for S7CommPlus server. + tls_cert: Path to TLS certificate (PEM). + tls_key: Path to TLS private key (PEM). + + Returns: + 0 on success. + """ + result = self._legacy.start(tcp_port=tcp_port) + + if s7commplus_port is not None: + self._plus.start( + port=s7commplus_port, + use_tls=use_tls, + tls_cert=tls_cert, + tls_key=tls_key, + ) + self._plus_running = True + logger.info(f"S7CommPlus server started on port {s7commplus_port}") + + return result + + def stop(self) -> int: + """Stop both servers. + + Returns: + 0 on success. + """ + if self._plus_running: + try: + self._plus.stop() + except Exception: + pass + self._plus_running = False + + return self._legacy.stop() + + # -- S7CommPlus server methods -- + + def register_db( + self, + db_number: int, + variables: dict[str, tuple[str, int]], + size: int = 1024, + ) -> DataBlock: + """Register a data block on the S7CommPlus server. + + Args: + db_number: Data block number. + variables: Dict of {name: (type_name, byte_offset)}. + size: DB size in bytes (default 1024). + + Returns: + The registered DataBlock. + """ + return self._plus.register_db(db_number, variables, size) + + def register_raw_db(self, db_number: int, data: bytearray) -> DataBlock: + """Register a raw data block on the S7CommPlus server. + + Args: + db_number: Data block number. + data: Raw DB data. + + Returns: + The registered DataBlock. + """ + return self._plus.register_raw_db(db_number, data) + + def get_db(self, db_number: int) -> Optional[DataBlock]: + """Get a registered S7CommPlus data block. + + Args: + db_number: Data block number. + + Returns: + The DataBlock, or None if not registered. + """ + return self._plus.get_db(db_number) + + @property + def s7commplus_server(self) -> S7CommPlusServer: + """Direct access to the underlying S7CommPlus server.""" + return self._plus + + @property + def legacy_server(self) -> LegacyServer: + """Direct access to the underlying legacy S7 server.""" + return self._legacy + + def __getattr__(self, name: str) -> Any: + """Delegate unknown attributes to the legacy server.""" + if name.startswith("_"): + raise AttributeError(name) + legacy = self.__dict__.get("_legacy") + if legacy is not None: + return getattr(legacy, name) + raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") + + def __enter__(self) -> "Server": + return self + + def __exit__(self, *args: Any) -> None: + self.stop() diff --git a/snap7/s7commplus/async_client.py b/snap7/s7commplus/async_client.py index e6a46fe2..e58e9fc4 100644 --- a/snap7/s7commplus/async_client.py +++ b/snap7/s7commplus/async_client.py @@ -4,6 +4,10 @@ Provides the same API as S7CommPlusClient but using asyncio for non-blocking I/O. Uses asyncio.Lock for concurrent safety. +Supports all S7CommPlus protocol versions (V1/V2/V3/TLS). The protocol +version is auto-detected from the PLC's CreateObject response during +connection setup. + When a PLC does not support S7CommPlus data operations, the client transparently falls back to the legacy S7 protocol for data block read/write operations (using synchronous calls in an executor). @@ -18,6 +22,7 @@ import asyncio import logging +import ssl import struct from typing import Any, Optional @@ -25,6 +30,7 @@ DataType, ElementID, FunctionCode, + LegitimationId, ObjectId, Opcode, ProtocolVersion, @@ -35,6 +41,7 @@ from .codec import encode_header, decode_header, encode_typed_value, encode_object_qualifier from .vlq import encode_uint32_vlq, decode_uint32_vlq, decode_uint64_vlq from .client import _build_read_payload, _parse_read_response, _build_write_payload, _parse_write_response +from .connection import _element_size logger = logging.getLogger(__name__) @@ -47,7 +54,7 @@ class S7CommPlusAsyncClient: """Async S7CommPlus client for S7-1200/1500 PLCs. - Supports V1 and V2 protocols. V3/TLS planned for future. + Supports V1, V2, and V3 protocols (including TLS). Uses asyncio for all I/O operations and asyncio.Lock for concurrent safety when shared between multiple coroutines. @@ -76,6 +83,13 @@ def __init__(self) -> None: self._integrity_id_write: int = 0 self._with_integrity_id: bool = False + # TLS state + self._tls_active: bool = False + self._oms_secret: Optional[bytes] = None + + # Session setup + self._server_session_version: Optional[int] = None + @property def connected(self) -> bool: if self._use_legacy_data and self._legacy_client is not None: @@ -95,12 +109,22 @@ def using_legacy_fallback(self) -> bool: """Whether the client is using legacy S7 protocol for data operations.""" return self._use_legacy_data + @property + def tls_active(self) -> bool: + """Whether TLS encryption is active on this connection.""" + return self._tls_active + async def connect( self, host: str, port: int = 102, rack: int = 0, slot: int = 1, + use_tls: bool = False, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + tls_ca: Optional[str] = None, + password: Optional[str] = None, ) -> None: """Connect to an S7-1200/1500 PLC. @@ -112,6 +136,11 @@ async def connect( port: TCP port (default 102) rack: PLC rack number slot: PLC slot number + use_tls: Whether to activate TLS (required for V2/V3) + tls_cert: Path to client TLS certificate (PEM) + tls_key: Path to client private key (PEM) + tls_ca: Path to CA certificate for PLC verification (PEM) + password: PLC password for legitimation (V2+ with TLS) """ self._host = host self._port = port @@ -128,14 +157,43 @@ async def connect( # InitSSL handshake await self._init_ssl() + # TLS activation (between InitSSL and CreateObject) + if use_tls: + await self._activate_tls(tls_cert=tls_cert, tls_key=tls_key, tls_ca=tls_ca) + # S7CommPlus session setup await self._create_session() + # Echo ServerSessionVersion back to complete handshake + if self._server_session_version is not None: + await self._setup_session() + else: + logger.warning("PLC did not provide ServerSessionVersion - session setup incomplete") + + # Version-specific post-setup + if self._protocol_version >= ProtocolVersion.V2: + if not self._tls_active: + raise RuntimeError( + f"PLC reports V{self._protocol_version} protocol but TLS is not active. " + "V2/V3 requires TLS. Use use_tls=True." + ) + self._with_integrity_id = True + self._integrity_id_read = 0 + self._integrity_id_write = 0 + logger.info(f"V{self._protocol_version} IntegrityId tracking enabled") + self._connected = True logger.info( - f"Async S7CommPlus connected to {host}:{port}, version=V{self._protocol_version}, session={self._session_id}" + f"Async S7CommPlus connected to {host}:{port}, " + f"version=V{self._protocol_version}, session={self._session_id}, " + f"tls={self._tls_active}" ) + # Handle legitimation for password-protected PLCs + if password is not None and self._tls_active: + logger.info("Performing PLC legitimation (password authentication)") + await self.authenticate(password) + # Probe S7CommPlus data operations if not await self._probe_s7commplus_data(): logger.info("S7CommPlus data operations not supported, falling back to legacy S7 protocol") @@ -145,6 +203,116 @@ async def connect( await self.disconnect() raise + async def authenticate(self, password: str, username: str = "") -> None: + """Perform PLC password authentication (legitimation). + + Must be called after connect() and before data operations on + password-protected PLCs. Requires TLS to be active (V2+). + + Args: + password: PLC password + username: Username for new-style auth (optional) + """ + if not self._connected: + raise RuntimeError("Not connected") + + if not self._tls_active: + raise RuntimeError("Legitimation requires TLS. Connect with use_tls=True.") + + # Step 1: Get challenge from PLC + challenge = await self._get_legitimation_challenge() + logger.info(f"Received legitimation challenge ({len(challenge)} bytes)") + + # Step 2: Build response (auto-detect legacy vs new) + from .legitimation import build_legacy_response, build_new_response + + if username and self._oms_secret is not None: + response_data = build_new_response(password, challenge, self._oms_secret, username) + await self._send_legitimation_new(response_data) + elif self._oms_secret is not None: + try: + response_data = build_new_response(password, challenge, self._oms_secret, "") + await self._send_legitimation_new(response_data) + except NotImplementedError: + response_data = build_legacy_response(password, challenge) + await self._send_legitimation_legacy(response_data) + else: + logger.info("OMS secret not available, using legacy legitimation") + response_data = build_legacy_response(password, challenge) + await self._send_legitimation_legacy(response_data) + + logger.info("PLC legitimation completed successfully") + + async def _get_legitimation_challenge(self) -> bytes: + """Request legitimation challenge from PLC.""" + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(LegitimationId.SERVER_SESSION_REQUEST) + payload += struct.pack(">I", 0) + + resp_payload = await self._send_request(FunctionCode.GET_VAR_SUBSTREAMED, bytes(payload)) + + offset = 0 + return_value, consumed = decode_uint64_vlq(resp_payload, offset) + offset += consumed + + if return_value != 0: + raise RuntimeError(f"GetVarSubStreamed for challenge failed: return_value={return_value}") + + if offset + 2 > len(resp_payload): + raise RuntimeError("Challenge response too short") + + _flags = resp_payload[offset] + datatype = resp_payload[offset + 1] + offset += 2 + + if datatype == DataType.BLOB: + length, consumed = decode_uint32_vlq(resp_payload, offset) + offset += consumed + return bytes(resp_payload[offset : offset + length]) + else: + count, consumed = decode_uint32_vlq(resp_payload, offset) + offset += consumed + return bytes(resp_payload[offset : offset + count]) + + async def _send_legitimation_new(self, encrypted_response: bytes) -> None: + """Send new-style legitimation response (AES-256-CBC encrypted).""" + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(LegitimationId.LEGITIMATE) + payload += bytes([0x00, DataType.BLOB]) + payload += encode_uint32_vlq(len(encrypted_response)) + payload += encrypted_response + payload += struct.pack(">I", 0) + + resp_payload = await self._send_request(FunctionCode.SET_VARIABLE, bytes(payload)) + + if len(resp_payload) >= 1: + return_value, _ = decode_uint64_vlq(resp_payload, 0) + if return_value < 0: + raise RuntimeError(f"Legitimation rejected by PLC: return_value={return_value}") + + async def _send_legitimation_legacy(self, response: bytes) -> None: + """Send legacy legitimation response (SHA-1 XOR).""" + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(LegitimationId.SERVER_SESSION_RESPONSE) + payload += bytes([0x10, DataType.USINT]) + payload += encode_uint32_vlq(len(response)) + payload += response + payload += struct.pack(">I", 0) + + resp_payload = await self._send_request(FunctionCode.SET_VARIABLE, bytes(payload)) + + if len(resp_payload) >= 1: + return_value, _ = decode_uint64_vlq(resp_payload, 0) + if return_value < 0: + raise RuntimeError(f"Legacy legitimation rejected by PLC: return_value={return_value}") + async def _probe_s7commplus_data(self) -> bool: """Test if the PLC supports S7CommPlus data operations.""" try: @@ -198,6 +366,9 @@ async def disconnect(self) -> None: self._with_integrity_id = False self._integrity_id_read = 0 self._integrity_id_write = 0 + self._tls_active = False + self._oms_secret = None + self._server_session_version = None if self._writer: try: @@ -407,6 +578,65 @@ async def _init_ssl(self) -> None: logger.debug(f"InitSSL response received, version=V{version}") + async def _activate_tls( + self, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + tls_ca: Optional[str] = None, + ) -> None: + """Activate TLS 1.3 over the COTP connection. + + Called after InitSSL and before CreateObject. Wraps the underlying + asyncio streams with TLS. + """ + if self._writer is None: + raise RuntimeError("Cannot activate TLS: not connected") + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.minimum_version = ssl.TLSVersion.TLSv1_3 + # TLS 1.3 cipher suites are auto-negotiated on modern OpenSSL; + # set_ciphers() only controls TLS 1.2 and below. + try: + ctx.set_ciphers("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") + except ssl.SSLError: + pass + + if tls_cert and tls_key: + ctx.load_cert_chain(tls_cert, tls_key) + + if tls_ca: + ctx.load_verify_locations(tls_ca) + else: + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + # Upgrade the connection to TLS. + # StreamWriter.start_tls() is the clean API (Python 3.11+). + # For Python 3.10, fall back to loop.start_tls() with the existing protocol. + if hasattr(self._writer, "start_tls"): + await self._writer.start_tls(ctx, server_hostname=self._host) + else: + transport = self._writer.transport + protocol = transport.get_protocol() + loop = asyncio.get_event_loop() + new_transport = await loop.start_tls(transport, protocol, ctx, server_hostname=self._host) + # Update writer's internal transport reference + self._writer._transport = new_transport + + self._tls_active = True + + # Extract OMS exporter secret for legitimation key derivation + ssl_object = self._writer.transport.get_extra_info("ssl_object") + if ssl_object is not None: + try: + self._oms_secret = ssl_object.export_keying_material("EXPERIMENTAL_OMS", 32, None) + logger.debug("OMS exporter secret extracted from TLS session") + except (AttributeError, ssl.SSLError) as e: + logger.warning(f"Could not extract OMS exporter secret: {e}") + self._oms_secret = None + + logger.info("TLS activated on async COTP connection") + async def _create_session(self) -> None: """Send CreateObject to establish S7CommPlus session.""" seq_num = self._next_sequence_number() @@ -455,7 +685,7 @@ async def _create_session(self) -> None: request += bytes([ElementID.TERMINATING_OBJECT]) request += struct.pack(">I", 0) - # Frame header + trailer + # Frame header + trailer (always V1 for CreateObject) frame = encode_header(ProtocolVersion.V1, len(request)) + request frame += struct.pack(">BBH", 0x72, ProtocolVersion.V1, 0x0000) await self._send_cotp_dt(frame) @@ -470,6 +700,178 @@ async def _create_session(self) -> None: self._session_id = struct.unpack_from(">I", response, 9)[0] self._protocol_version = version + logger.debug(f"Session created: id=0x{self._session_id:08X}, version=V{version}") + + # Parse response payload to extract ServerSessionVersion + self._parse_create_object_response(response[14:]) + + def _parse_create_object_response(self, payload: bytes) -> None: + """Parse CreateObject response to extract ServerSessionVersion.""" + offset = 0 + while offset < len(payload): + tag = payload[offset] + + if tag == ElementID.ATTRIBUTE: + offset += 1 + if offset >= len(payload): + break + attr_id, consumed = decode_uint32_vlq(payload, offset) + offset += consumed + + if attr_id == ObjectId.SERVER_SESSION_VERSION: + if offset + 2 > len(payload): + break + _flags = payload[offset] + datatype = payload[offset + 1] + offset += 2 + if datatype in (DataType.UDINT, DataType.DWORD): + value, consumed = decode_uint32_vlq(payload, offset) + offset += consumed + self._server_session_version = value + logger.info(f"ServerSessionVersion = {value}") + return + else: + logger.debug(f"ServerSessionVersion has unexpected type {datatype:#04x}") + else: + if offset + 2 > len(payload): + break + _flags = payload[offset] + datatype = payload[offset + 1] + offset += 2 + offset = self._skip_typed_value(payload, offset, datatype, _flags) + + elif tag == ElementID.START_OF_OBJECT: + offset += 1 + if offset + 4 > len(payload): + break + offset += 4 # RelationId + _, consumed = decode_uint32_vlq(payload, offset) + offset += consumed # ClassId + _, consumed = decode_uint32_vlq(payload, offset) + offset += consumed # ClassFlags + _, consumed = decode_uint32_vlq(payload, offset) + offset += consumed # AttributeId + + elif tag == ElementID.TERMINATING_OBJECT: + offset += 1 + elif tag == 0x00: + offset += 1 + else: + offset += 1 + + logger.debug("ServerSessionVersion not found in CreateObject response") + + def _skip_typed_value(self, data: bytes, offset: int, datatype: int, flags: int) -> int: + """Skip over a typed value in the PObject tree.""" + is_array = bool(flags & 0x10) + + if is_array: + if offset >= len(data): + return offset + count, consumed = decode_uint32_vlq(data, offset) + offset += consumed + elem_size = _element_size(datatype) + if elem_size > 0: + offset += count * elem_size + else: + for _ in range(count): + if offset >= len(data): + break + _, consumed = decode_uint32_vlq(data, offset) + offset += consumed + return offset + + if datatype == DataType.NULL: + return offset + elif datatype in (DataType.BOOL, DataType.USINT, DataType.BYTE, DataType.SINT): + return offset + 1 + elif datatype in (DataType.UINT, DataType.WORD, DataType.INT): + return offset + 2 + elif datatype in (DataType.UDINT, DataType.DWORD, DataType.AID, DataType.DINT): + _, consumed = decode_uint32_vlq(data, offset) + return offset + consumed + elif datatype in (DataType.ULINT, DataType.LWORD, DataType.LINT): + _, consumed = decode_uint64_vlq(data, offset) + return offset + consumed + elif datatype == DataType.REAL: + return offset + 4 + elif datatype == DataType.LREAL: + return offset + 8 + elif datatype == DataType.TIMESTAMP: + return offset + 8 + elif datatype == DataType.TIMESPAN: + _, consumed = decode_uint64_vlq(data, offset) + return offset + consumed + elif datatype == DataType.RID: + return offset + 4 + elif datatype in (DataType.BLOB, DataType.WSTRING): + length, consumed = decode_uint32_vlq(data, offset) + return offset + consumed + length + elif datatype == DataType.STRUCT: + count, consumed = decode_uint32_vlq(data, offset) + offset += consumed + for _ in range(count): + if offset + 2 > len(data): + break + sub_flags = data[offset] + sub_type = data[offset + 1] + offset += 2 + offset = self._skip_typed_value(data, offset, sub_type, sub_flags) + return offset + else: + return offset + + async def _setup_session(self) -> None: + """Send SetMultiVariables to echo ServerSessionVersion back to the PLC.""" + if self._server_session_version is None: + return + + seq_num = self._next_sequence_number() + + request = struct.pack( + ">BHHHHIB", + Opcode.REQUEST, + 0x0000, + FunctionCode.SET_MULTI_VARIABLES, + 0x0000, + seq_num, + self._session_id, + 0x36, + ) + + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(ObjectId.SERVER_SESSION_VERSION) + payload += encode_uint32_vlq(1) + payload += bytes([0x00, DataType.UDINT]) + payload += encode_uint32_vlq(self._server_session_version) + payload += bytes([0x00]) + payload += encode_object_qualifier() + payload += struct.pack(">I", 0) + + request += bytes(payload) + + frame = encode_header(self._protocol_version, len(request)) + request + frame += struct.pack(">BBH", 0x72, self._protocol_version, 0x0000) + await self._send_cotp_dt(frame) + + response_data = await self._recv_cotp_dt() + version, data_length, consumed = decode_header(response_data) + response = response_data[consumed : consumed + data_length] + + if len(response) < 14: + raise RuntimeError("SetupSession response too short") + + resp_payload = response[14:] + if len(resp_payload) >= 1: + return_value, _ = decode_uint64_vlq(resp_payload, 0) + if return_value != 0: + logger.warning(f"SetupSession: PLC returned error {return_value}") + else: + logger.info("Session setup completed successfully") + async def _delete_session(self) -> None: """Send DeleteObject to close the session.""" seq_num = self._next_sequence_number() diff --git a/snap7/s7commplus/client.py b/snap7/s7commplus/client.py index 44112c9c..341c879f 100644 --- a/snap7/s7commplus/client.py +++ b/snap7/s7commplus/client.py @@ -14,7 +14,7 @@ the client transparently falls back to the legacy S7 protocol for data block read/write operations. -Status: V1 and V2 connections are functional. V3/TLS authentication planned. +Status: V1, V2, and V3 (including TLS) connections are functional. Reference: thomas-v2/S7CommPlusDriver (C#, LGPL-3.0) """ @@ -24,7 +24,7 @@ from typing import Any, Optional from .connection import S7CommPlusConnection -from .protocol import FunctionCode, Ids +from .protocol import FunctionCode, Ids, ProtocolVersion from .vlq import encode_uint32_vlq, decode_uint32_vlq, decode_uint64_vlq from .codec import ( encode_item_address, @@ -148,10 +148,12 @@ def connect( logger.info("Performing PLC legitimation (password authentication)") self._connection.authenticate(password) - # Probe S7CommPlus data operations with a minimal request - if not self._probe_s7commplus_data(): - logger.info("S7CommPlus data operations not supported, falling back to legacy S7 protocol") - self._setup_legacy_fallback() + # Probe S7CommPlus data operations with a minimal request. + # Skip probe for V2+ with TLS: TLS handshake confirms S7CommPlus support. + if self._connection.protocol_version < ProtocolVersion.V2: + if not self._probe_s7commplus_data(): + logger.info("S7CommPlus data operations not supported, falling back to legacy S7 protocol") + self._setup_legacy_fallback() def _probe_s7commplus_data(self) -> bool: """Test if the PLC supports S7CommPlus data operations. diff --git a/snap7/s7commplus/connection.py b/snap7/s7commplus/connection.py index a60b44a0..ae44506d 100644 --- a/snap7/s7commplus/connection.py +++ b/snap7/s7commplus/connection.py @@ -85,8 +85,7 @@ class S7CommPlusConnection: - Version-appropriate authentication (V1/V2/V3/TLS) - Frame send/receive (TLS-encrypted when using V17+ firmware) - Currently implements V1 authentication. V2/V3/TLS authentication - layers are planned for future development. + Supports V1, V2, and V3 (including TLS) authentication. """ def __init__( @@ -202,21 +201,19 @@ def connect( logger.warning("PLC did not provide ServerSessionVersion - session setup incomplete") # Step 6: Version-specific post-setup - if self._protocol_version >= ProtocolVersion.V3: - if not use_tls: - logger.warning( - "PLC reports V3 protocol but TLS is not enabled. Connection may not work without use_tls=True." - ) - elif self._protocol_version == ProtocolVersion.V2: + if self._protocol_version >= ProtocolVersion.V2: if not self._tls_active: from ..error import S7ConnectionError - raise S7ConnectionError("PLC reports V2 protocol but TLS is not active. V2 requires TLS. Use use_tls=True.") + raise S7ConnectionError( + f"PLC reports V{self._protocol_version} protocol but TLS is not active. " + "V2/V3 requires TLS. Use use_tls=True." + ) # Enable IntegrityId tracking for V2+ self._with_integrity_id = True self._integrity_id_read = 0 self._integrity_id_write = 0 - logger.info("V2 IntegrityId tracking enabled") + logger.info(f"V{self._protocol_version} IntegrityId tracking enabled") # V1: No further authentication needed after CreateObject self._connected = True @@ -251,7 +248,7 @@ def authenticate(self, password: str, username: str = "") -> None: raise S7ConnectionError("Not connected") - if not self._tls_active or self._oms_secret is None: + if not self._tls_active: from ..error import S7ConnectionError raise S7ConnectionError("Legitimation requires TLS. Connect with use_tls=True.") @@ -263,11 +260,11 @@ def authenticate(self, password: str, username: str = "") -> None: # Step 2: Build response (auto-detect legacy vs new) from .legitimation import build_legacy_response, build_new_response - if username: + if username and self._oms_secret is not None: # New-style auth with username always uses AES-256-CBC response_data = build_new_response(password, challenge, self._oms_secret, username) self._send_legitimation_new(response_data) - else: + elif self._oms_secret is not None: # Try new-style first, fall back to legacy SHA-1 XOR try: response_data = build_new_response(password, challenge, self._oms_secret, "") @@ -276,6 +273,12 @@ def authenticate(self, password: str, username: str = "") -> None: # cryptography package not available, use legacy response_data = build_legacy_response(password, challenge) self._send_legitimation_legacy(response_data) + else: + # No OMS secret available (export_keying_material not supported), + # fall back to legacy SHA-1 XOR authentication + logger.info("OMS secret not available, using legacy legitimation") + response_data = build_legacy_response(password, challenge) + self._send_legitimation_legacy(response_data) logger.info("PLC legitimation completed successfully") @@ -1013,7 +1016,13 @@ def _setup_ssl_context( """ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.minimum_version = ssl.TLSVersion.TLSv1_3 - ctx.set_ciphers("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") + # TLS 1.3 cipher suites are auto-negotiated on modern OpenSSL; + # set_ciphers() only controls TLS 1.2 and below. We try to set + # preferred ciphers but ignore failures (e.g. OpenSSL 3.x). + try: + ctx.set_ciphers("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") + except ssl.SSLError: + pass if cert_path and key_path: ctx.load_cert_chain(cert_path, key_path) diff --git a/snap7/s7commplus/server.py b/snap7/s7commplus/server.py index 2af0d769..6719bf71 100644 --- a/snap7/s7commplus/server.py +++ b/snap7/s7commplus/server.py @@ -10,7 +10,7 @@ - Internal PLC memory model with thread-safe access - V2 protocol emulation with TLS and IntegrityId tracking -Supports both V1 (no TLS) and V2 (TLS + IntegrityId) emulation. +Supports V1 (no TLS), V2 (TLS + IntegrityId), and V3 (TLS + IntegrityId) emulation. Usage:: @@ -186,17 +186,21 @@ class S7CommPlusServer: Emulates an S7-1200/1500 PLC with: - Internal data block storage with named variables - - S7CommPlus protocol handling (V1 and V2) - - V2 TLS support with IntegrityId tracking + - S7CommPlus protocol handling (V1, V2, V3) + - V2/V3 TLS support with IntegrityId tracking + - Legitimation (password authentication) emulation - Multi-client support (threaded) - CPU state management """ - def __init__(self, protocol_version: int = ProtocolVersion.V1) -> None: + def __init__(self, protocol_version: int = ProtocolVersion.V1, password: Optional[str] = None) -> None: self._data_blocks: dict[int, DataBlock] = {} self._cpu_state = CPUState.RUN self._protocol_version = protocol_version + self._password = password self._next_session_id = 1 + # Per-client authentication state (session_id -> authenticated) + self._authenticated_sessions: dict[int, bool] = {} self._server_socket: Optional[socket.socket] = None self._server_thread: Optional[threading.Thread] = None @@ -205,7 +209,7 @@ def __init__(self, protocol_version: int = ProtocolVersion.V1) -> None: self._lock = threading.Lock() self._event_callback: Optional[Callable[..., None]] = None - # TLS configuration (V2) + # TLS configuration (V2/V3) self._ssl_context: Optional[ssl.SSLContext] = None self._use_tls: bool = False @@ -364,8 +368,12 @@ def _handle_client(self, client_sock: socket.socket, address: tuple[str, int]) - session_id = 0 tls_activated = False # Per-client IntegrityId tracking (V2+) + # IntegrityId tracking starts AFTER the session setup (SetMultiVariables + # echoing ServerSessionVersion). The first SetMultiVariables after + # CreateObject is the session setup and doesn't include IntegrityId. integrity_id_read = 0 integrity_id_write = 0 + integrity_id_active = False while self._running: try: @@ -375,7 +383,9 @@ def _handle_client(self, client_sock: socket.socket, address: tuple[str, int]) - break # Process the S7CommPlus request - response = self._process_request(data, session_id, integrity_id_read, integrity_id_write) + response = self._process_request( + data, session_id, integrity_id_read, integrity_id_write, integrity_id_active + ) if response is not None: # Check if session ID was assigned @@ -412,7 +422,12 @@ def _handle_client(self, client_sock: socket.socket, address: tuple[str, int]) - payload = data[hdr_consumed:] if len(payload) >= 14: func_code = struct.unpack_from(">H", payload, 3)[0] - if func_code in READ_FUNCTION_CODES: + if not integrity_id_active: + # First SetMultiVariables after CreateObject is session setup + # (no IntegrityId). Activate tracking after it. + if func_code == FunctionCode.SET_MULTI_VARIABLES: + integrity_id_active = True + elif func_code in READ_FUNCTION_CODES: integrity_id_read = (integrity_id_read + 1) & 0xFFFFFFFF elif func_code not in ( FunctionCode.INIT_SSL, @@ -522,6 +537,7 @@ def _process_request( session_id: int, integrity_id_read: int = 0, integrity_id_write: int = 0, + integrity_id_active: bool = False, ) -> Optional[bytes]: """Process an S7CommPlus request and return a response.""" if len(data) < 4: @@ -547,11 +563,14 @@ def _process_request( seq_num = struct.unpack_from(">H", payload, 7)[0] req_session_id = struct.unpack_from(">I", payload, 9)[0] - # For V2+, skip IntegrityId after the 14-byte header + # For V2+, skip IntegrityId after the 14-byte header. + # IntegrityId is only present after session setup is complete + # (integrity_id_active=True). The first SetMultiVariables after + # CreateObject is the session setup and doesn't include IntegrityId. request_offset = 14 if ( - self._protocol_version >= ProtocolVersion.V2 - and session_id != 0 + integrity_id_active + and self._protocol_version >= ProtocolVersion.V2 and function_code not in (FunctionCode.INIT_SSL, FunctionCode.CREATE_OBJECT) ): if request_offset < len(payload): @@ -566,14 +585,41 @@ def _process_request( return self._handle_create_object(seq_num, request_data) elif function_code == FunctionCode.DELETE_OBJECT: return self._handle_delete_object(seq_num, req_session_id) + elif function_code == FunctionCode.GET_VAR_SUBSTREAMED: + response = self._handle_get_var_substreamed(seq_num, req_session_id, request_data) + elif function_code == FunctionCode.SET_VARIABLE: + response = self._handle_set_variable(seq_num, req_session_id, request_data) elif function_code == FunctionCode.EXPLORE: - return self._handle_explore(seq_num, req_session_id, request_data) + if not self._check_authenticated(req_session_id): + response = self._build_error_response(seq_num, req_session_id, function_code) + else: + response = self._handle_explore(seq_num, req_session_id, request_data) elif function_code == FunctionCode.GET_MULTI_VARIABLES: - return self._handle_get_multi_variables(seq_num, req_session_id, request_data) + if not self._check_authenticated(req_session_id): + response = self._build_error_response(seq_num, req_session_id, function_code) + else: + response = self._handle_get_multi_variables(seq_num, req_session_id, request_data) elif function_code == FunctionCode.SET_MULTI_VARIABLES: - return self._handle_set_multi_variables(seq_num, req_session_id, request_data) + # Auth check is inside the handler: session setup must bypass auth + response = self._handle_set_multi_variables( + seq_num, req_session_id, request_data, self._check_authenticated(req_session_id) + ) else: - return self._build_error_response(seq_num, req_session_id, function_code) + response = self._build_error_response(seq_num, req_session_id, function_code) + + # For V2+, insert IntegrityId right after the 14-byte response header. + # The client expects IntegrityId at offset 14 in the response, mirroring + # the request format. Only insert when IntegrityId tracking is active. + if ( + integrity_id_active + and self._protocol_version >= ProtocolVersion.V2 + and function_code not in (FunctionCode.INIT_SSL, FunctionCode.CREATE_OBJECT) + and response is not None + and len(response) >= 14 + ): + response = response[:14] + encode_uint32_vlq(0) + response[14:] + + return response def _build_response_header( self, @@ -798,16 +844,30 @@ def _handle_get_multi_variables(self, seq_num: int, session_id: int, request_dat # Terminate error list response += encode_uint32_vlq(0) - # IntegrityId - response += encode_uint32_vlq(0) - return bytes(response) - def _handle_set_multi_variables(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: + def _handle_set_multi_variables( + self, seq_num: int, session_id: int, request_data: bytes, is_authenticated: bool = True + ) -> bytes: """Handle SetMultiVariables -- write variables to data blocks. + Also handles session setup (SetMultiVariables echoing ServerSessionVersion). + Session setup is detected by InObjectId matching the session_id. + Session setup always bypasses auth; data writes require authentication. + Reference: thomas-v2/S7CommPlusDriver/Core/SetMultiVariablesRequest.cs """ + # Check if this is a session setup write (InObjectId = session_id) + if len(request_data) >= 4: + in_object_id = struct.unpack_from(">I", request_data, 0)[0] + if in_object_id == session_id and session_id != 0: + # Session setup: just acknowledge success (no auth required) + return self._build_set_multi_response(seq_num, session_id, []) + + # For data writes, require authentication + if not is_authenticated: + return self._build_error_response(seq_num, session_id, FunctionCode.SET_MULTI_VARIABLES) + response = bytearray() response += struct.pack( ">BHHHHIB", @@ -843,8 +903,98 @@ def _handle_set_multi_variables(self, seq_num: int, session_id: int, request_dat # Terminate error list response += encode_uint32_vlq(0) - # IntegrityId - response += encode_uint32_vlq(0) + return bytes(response) + + def _build_set_multi_response( + self, seq_num: int, session_id: int, errors: list[tuple[int, int]] + ) -> bytes: + """Build a SetMultiVariables response.""" + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.SET_MULTI_VARIABLES, + 0x0000, + seq_num, + session_id, + 0x00, + ) + response += encode_uint64_vlq(0) # ReturnValue: success + for err_item, err_code in errors: + response += encode_uint32_vlq(err_item) + response += encode_uint64_vlq(err_code) + response += encode_uint32_vlq(0) # Terminate error list + return bytes(response) + + def _check_authenticated(self, session_id: int) -> bool: + """Check if a session is authenticated (or no password is required).""" + if self._password is None: + return True + return self._authenticated_sessions.get(session_id, False) + + def _handle_get_var_substreamed(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: + """Handle GetVarSubStreamed -- used for legitimation challenge request. + + Returns a 20-byte random challenge when the client requests + ServerSessionRequest (address 303). + """ + import os + + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.GET_VAR_SUBSTREAMED, + 0x0000, + seq_num, + session_id, + 0x00, + ) + + # ReturnValue: success + response += encode_uint64_vlq(0) + + # Value: 20-byte challenge as BLOB + challenge = os.urandom(20) + response += bytes([0x00, DataType.BLOB]) + response += encode_uint32_vlq(len(challenge)) + response += challenge + + # Trailing padding + response += struct.pack(">I", 0) + + return bytes(response) + + def _handle_set_variable(self, seq_num: int, session_id: int, request_data: bytes) -> bytes: + """Handle SetVariable -- used for legitimation response. + + Accepts the client's authentication response and marks the session + as authenticated. In this emulator, any response is accepted (we + don't verify the actual crypto). + """ + # Mark session as authenticated + self._authenticated_sessions[session_id] = True + logger.debug(f"Session {session_id} authenticated via SetVariable") + + response = bytearray() + response += struct.pack( + ">BHHHHIB", + Opcode.RESPONSE, + 0x0000, + FunctionCode.SET_VARIABLE, + 0x0000, + seq_num, + session_id, + 0x00, + ) + + # ReturnValue: success + response += encode_uint64_vlq(0) + + # Trailing padding + response += struct.pack(">I", 0) return bytes(response) diff --git a/tests/test_s7_unified.py b/tests/test_s7_unified.py new file mode 100644 index 00000000..a58e7376 --- /dev/null +++ b/tests/test_s7_unified.py @@ -0,0 +1,443 @@ +"""Tests for the unified s7 package.""" + +import pytest +from unittest.mock import MagicMock, patch + +from s7 import Client, Server, Protocol, Area, Block, WordLen + + +class TestProtocol: + """Test Protocol enum.""" + + def test_enum_values(self) -> None: + assert Protocol.AUTO.value == "auto" + assert Protocol.LEGACY.value == "legacy" + assert Protocol.S7COMMPLUS.value == "s7commplus" + + +class TestClientInit: + """Test Client initialization.""" + + def test_default_state(self) -> None: + client = Client() + assert client.protocol == Protocol.AUTO + assert client.connected is False + + def test_repr_disconnected(self) -> None: + client = Client() + assert "disconnected" in repr(client) + + +class TestClientLegacy: + """Test Client with legacy protocol (mocked).""" + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_connect_legacy_fallback(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """When S7CommPlus fails, client falls back to legacy.""" + # S7CommPlus connect raises + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("Connection refused") + mock_plus_cls.return_value = mock_plus + + # Legacy connects fine + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + result = client.connect("192.168.1.10", 0, 1) + + assert result is client + assert client.protocol == Protocol.LEGACY + mock_legacy.connect.assert_called_once_with("192.168.1.10", 0, 1, 102) + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_connect_explicit_legacy(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """When protocol=LEGACY, S7CommPlus is not attempted.""" + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1, protocol=Protocol.LEGACY) + + assert client.protocol == Protocol.LEGACY + mock_plus_cls.assert_not_called() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_db_read_delegates_to_legacy(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """db_read delegates to legacy when protocol is LEGACY.""" + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy.db_read.return_value = bytearray([1, 2, 3, 4]) + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + client = Client() + client.connect("192.168.1.10", 0, 1) + data = client.db_read(1, 0, 4) + + assert data == bytearray([1, 2, 3, 4]) + mock_legacy.db_read.assert_called_once_with(1, 0, 4) + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_db_write_delegates_to_legacy(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy.db_write.return_value = 0 + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + client = Client() + client.connect("192.168.1.10", 0, 1) + result = client.db_write(1, 0, bytearray([1, 2, 3, 4])) + + assert result == 0 + mock_legacy.db_write.assert_called_once() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_getattr_delegates_to_legacy(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """Methods not on unified client delegate to legacy.""" + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy.get_cpu_info.return_value = "cpu_info" + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + client = Client() + client.connect("192.168.1.10", 0, 1) + info = client.get_cpu_info() + + assert info == "cpu_info" + mock_legacy.get_cpu_info.assert_called_once() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_disconnect(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + client = Client() + client.connect("192.168.1.10", 0, 1) + result = client.disconnect() + + assert result == 0 + mock_legacy.disconnect.assert_called_once() + + +class TestClientS7CommPlus: + """Test Client with S7CommPlus protocol (mocked).""" + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_connect_s7commplus(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """When S7CommPlus succeeds, protocol is S7COMMPLUS.""" + mock_plus = MagicMock() + mock_plus.using_legacy_fallback = False + mock_plus.connected = True + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + + assert client.protocol == Protocol.S7COMMPLUS + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_db_read_uses_s7commplus(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """db_read uses S7CommPlus when available.""" + mock_plus = MagicMock() + mock_plus.using_legacy_fallback = False + mock_plus.connected = True + mock_plus.db_read.return_value = b"\x01\x02\x03\x04" + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + data = client.db_read(1, 0, 4) + + assert data == bytearray([1, 2, 3, 4]) + assert isinstance(data, bytearray) + mock_plus.db_read.assert_called_once_with(1, 0, 4) + mock_legacy.db_read.assert_not_called() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_db_write_uses_s7commplus(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_plus = MagicMock() + mock_plus.using_legacy_fallback = False + mock_plus.connected = True + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + result = client.db_write(1, 0, bytearray([1, 2, 3, 4])) + + assert result == 0 + mock_plus.db_write.assert_called_once_with(1, 0, b"\x01\x02\x03\x04") + mock_legacy.db_write.assert_not_called() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_legacy_methods_still_delegate(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """Even with S7CommPlus, legacy-only methods go to legacy client.""" + mock_plus = MagicMock() + mock_plus.using_legacy_fallback = False + mock_plus.connected = True + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy.list_blocks.return_value = "blocks" + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + blocks = client.list_blocks() + + assert blocks == "blocks" + mock_legacy.list_blocks.assert_called_once() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_s7commplus_fallback_detected(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """If S7CommPlus connects but falls back, unified client uses LEGACY.""" + mock_plus = MagicMock() + mock_plus.using_legacy_fallback = True # Data ops not supported + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + + assert client.protocol == Protocol.LEGACY + # S7CommPlus should have been disconnected + mock_plus.disconnect.assert_called_once() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_explore_requires_s7commplus(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """explore() raises when S7CommPlus is not active.""" + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + + with pytest.raises(RuntimeError, match="S7CommPlus"): + client.explore() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_explicit_s7commplus_fails(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + """protocol=S7COMMPLUS raises if S7CommPlus fails.""" + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + client = Client() + with pytest.raises(RuntimeError, match="explicitly requested"): + client.connect("192.168.1.10", 0, 1, protocol=Protocol.S7COMMPLUS) + + +class TestClientContextManager: + """Test context manager protocol.""" + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_context_manager(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + with Client() as client: + client.connect("192.168.1.10", 0, 1) + assert client.connected + + mock_legacy.disconnect.assert_called_once() + + +class TestClientDbReadMulti: + """Test db_read_multi.""" + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_multi_read_s7commplus(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_plus = MagicMock() + mock_plus.using_legacy_fallback = False + mock_plus.connected = True + mock_plus.db_read_multi.return_value = [b"\x01\x02", b"\x03\x04"] + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + results = client.db_read_multi([(1, 0, 2), (1, 2, 2)]) + + assert len(results) == 2 + assert all(isinstance(r, bytearray) for r in results) + mock_plus.db_read_multi.assert_called_once() + + @patch("s7.client.S7CommPlusClient") + @patch("s7.client.LegacyClient") + def test_multi_read_legacy(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_plus = MagicMock() + mock_plus.connect.side_effect = RuntimeError("fail") + mock_plus_cls.return_value = mock_plus + + mock_legacy = MagicMock() + mock_legacy.connected = True + mock_legacy.db_read.side_effect = [bytearray([1, 2]), bytearray([3, 4])] + mock_legacy_cls.return_value = mock_legacy + + client = Client() + client.connect("192.168.1.10", 0, 1) + results = client.db_read_multi([(1, 0, 2), (1, 2, 2)]) + + assert len(results) == 2 + assert mock_legacy.db_read.call_count == 2 + + +class TestImports: + """Test that s7 package exports are accessible.""" + + def test_types_exported(self) -> None: + assert Area is not None + assert Block is not None + assert WordLen is not None + + def test_protocol_exported(self) -> None: + assert Protocol.AUTO is not None + + +class TestServer: + """Test unified Server (mocked).""" + + @patch("s7.server.S7CommPlusServer") + @patch("s7.server.LegacyServer") + def test_start_legacy_only(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_legacy = MagicMock() + mock_legacy.start.return_value = 0 + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus_cls.return_value = mock_plus + + server = Server() + result = server.start(tcp_port=11020) + + assert result == 0 + mock_legacy.start.assert_called_once_with(tcp_port=11020) + mock_plus.start.assert_not_called() + + @patch("s7.server.S7CommPlusServer") + @patch("s7.server.LegacyServer") + def test_start_both(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_legacy = MagicMock() + mock_legacy.start.return_value = 0 + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus_cls.return_value = mock_plus + + server = Server() + server.start(tcp_port=11020, s7commplus_port=11021) + + mock_legacy.start.assert_called_once_with(tcp_port=11020) + mock_plus.start.assert_called_once_with(port=11021, use_tls=False, tls_cert=None, tls_key=None) + + @patch("s7.server.S7CommPlusServer") + @patch("s7.server.LegacyServer") + def test_stop_both(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_legacy = MagicMock() + mock_legacy.start.return_value = 0 + mock_legacy.stop.return_value = 0 + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus_cls.return_value = mock_plus + + server = Server() + server.start(tcp_port=11020, s7commplus_port=11021) + result = server.stop() + + assert result == 0 + mock_plus.stop.assert_called_once() + mock_legacy.stop.assert_called_once() + + @patch("s7.server.S7CommPlusServer") + @patch("s7.server.LegacyServer") + def test_getattr_delegates(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_legacy = MagicMock() + mock_legacy.get_status.return_value = ("running", "ok", 0) + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus_cls.return_value = mock_plus + + server = Server() + status = server.get_status() + + assert status == ("running", "ok", 0) + + @patch("s7.server.S7CommPlusServer") + @patch("s7.server.LegacyServer") + def test_context_manager(self, mock_legacy_cls: MagicMock, mock_plus_cls: MagicMock) -> None: + mock_legacy = MagicMock() + mock_legacy.start.return_value = 0 + mock_legacy.stop.return_value = 0 + mock_legacy_cls.return_value = mock_legacy + + mock_plus = MagicMock() + mock_plus_cls.return_value = mock_plus + + with Server() as server: + server.start(tcp_port=11020) + + mock_legacy.stop.assert_called_once() diff --git a/tests/test_s7commplus_tls.py b/tests/test_s7commplus_tls.py new file mode 100644 index 00000000..873d2a9f --- /dev/null +++ b/tests/test_s7commplus_tls.py @@ -0,0 +1,417 @@ +"""Integration tests for S7CommPlus V2/V3 with TLS. + +Tests the complete TLS connection flow including: +- V2 server + client with TLS and IntegrityId tracking +- V3 server + client with TLS +- Legitimation (password authentication) flow +- Async client with TLS +- Error handling for missing TLS + +Requires the `cryptography` package for self-signed certificate generation. +""" + +import struct +import tempfile +import time +from collections.abc import Generator +from pathlib import Path + +import pytest + +from snap7.s7commplus.server import S7CommPlusServer +from snap7.s7commplus.client import S7CommPlusClient +from snap7.s7commplus.protocol import ProtocolVersion + +try: + import ipaddress + + from cryptography import x509 + from cryptography.x509.oid import NameOID + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + import datetime + + _has_cryptography = True +except ImportError: + _has_cryptography = False + +pytestmark = pytest.mark.skipif(not _has_cryptography, reason="requires cryptography package") + +# Use high ports to avoid conflicts +V2_TEST_PORT = 11130 +V3_TEST_PORT = 11131 +V2_AUTH_PORT = 11132 +V3_AUTH_PORT = 11133 + + +@pytest.fixture(scope="module") +def tls_certs() -> Generator[dict[str, str], None, None]: + """Generate self-signed TLS certificates for testing.""" + with tempfile.TemporaryDirectory() as tmpdir: + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), + ]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)) + .add_extension( + x509.SubjectAlternativeName([ + x509.DNSName("localhost"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + ]), + critical=False, + ) + .sign(key, hashes.SHA256()) + ) + + cert_path = str(Path(tmpdir) / "server.crt") + key_path = str(Path(tmpdir) / "server.key") + + with open(cert_path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + + with open(key_path, "wb") as f: + f.write( + key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.TraditionalOpenSSL, + serialization.NoEncryption(), + ) + ) + + yield {"cert": cert_path, "key": key_path} + + +def _make_server( + protocol_version: int, + password: str | None = None, +) -> S7CommPlusServer: + """Create and configure an S7CommPlus server with test data blocks.""" + srv = S7CommPlusServer(protocol_version=protocol_version, password=password) + + srv.register_db( + 1, + { + "temperature": ("Real", 0), + "pressure": ("Real", 4), + "running": ("Bool", 8), + "count": ("DInt", 10), + }, + ) + srv.register_raw_db(2, bytearray(256)) + + # Pre-populate DB1 + db1 = srv.get_db(1) + assert db1 is not None + struct.pack_into(">f", db1.data, 0, 23.5) + struct.pack_into(">f", db1.data, 4, 1.013) + db1.data[8] = 1 + struct.pack_into(">i", db1.data, 10, 42) + + return srv + + +@pytest.fixture() +def v2_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: + """V2 server with TLS.""" + srv = _make_server(ProtocolVersion.V2) + srv.start(port=V2_TEST_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) + time.sleep(0.1) + yield srv + srv.stop() + + +@pytest.fixture() +def v3_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: + """V3 server with TLS.""" + srv = _make_server(ProtocolVersion.V3) + srv.start(port=V3_TEST_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) + time.sleep(0.1) + yield srv + srv.stop() + + +@pytest.fixture() +def v2_auth_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: + """V2 server with TLS and password authentication.""" + srv = _make_server(ProtocolVersion.V2, password="secret123") + srv.start(port=V2_AUTH_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) + time.sleep(0.1) + yield srv + srv.stop() + + +@pytest.fixture() +def v3_auth_server(tls_certs: dict[str, str]) -> Generator[S7CommPlusServer, None, None]: + """V3 server with TLS and password authentication.""" + srv = _make_server(ProtocolVersion.V3, password="secret123") + srv.start(port=V3_AUTH_PORT, use_tls=True, tls_cert=tls_certs["cert"], tls_key=tls_certs["key"]) + time.sleep(0.1) + yield srv + srv.stop() + + +class TestV2TLS: + """Test V2 protocol with TLS.""" + + def test_connect_disconnect(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + assert client.connected + assert client.session_id != 0 + assert client.protocol_version == ProtocolVersion.V2 + client.disconnect() + assert not client.connected + + def test_read_real(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + try: + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + finally: + client.disconnect() + + def test_write_and_read_back(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + try: + client.db_write(1, 0, struct.pack(">f", 99.9)) + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 99.9) < 0.1 + finally: + client.disconnect() + + def test_multi_read(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + try: + results = client.db_read_multi([ + (1, 0, 4), + (1, 4, 4), + (2, 0, 4), + ]) + assert len(results) == 3 + temp = struct.unpack(">f", results[0])[0] + assert abs(temp - 23.5) < 0.001 + finally: + client.disconnect() + + def test_explore(self, v2_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + try: + response = client.explore() + assert len(response) > 0 + finally: + client.disconnect() + + +class TestV3TLS: + """Test V3 protocol with TLS.""" + + def test_connect_disconnect(self, v3_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + assert client.connected + assert client.session_id != 0 + assert client.protocol_version == ProtocolVersion.V3 + client.disconnect() + assert not client.connected + + def test_read_real(self, v3_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + try: + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + finally: + client.disconnect() + + def test_write_and_read_back(self, v3_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + try: + client.db_write(1, 0, struct.pack(">f", 88.8)) + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 88.8) < 0.1 + finally: + client.disconnect() + + def test_multi_read(self, v3_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + try: + results = client.db_read_multi([ + (1, 0, 4), + (1, 10, 4), + ]) + assert len(results) == 2 + finally: + client.disconnect() + + def test_data_persists_across_clients(self, v3_server: S7CommPlusServer) -> None: + c1 = S7CommPlusClient() + c1.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + c1.db_write(2, 0, b"\xca\xfe\xba\xbe") + c1.disconnect() + + c2 = S7CommPlusClient() + c2.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + data = c2.db_read(2, 0, 4) + c2.disconnect() + + assert data == b"\xca\xfe\xba\xbe" + + +class TestV2WithoutTLS: + """Test that V2/V3 connections fail without TLS.""" + + def test_v2_server_requires_tls_on_client(self, tls_certs: dict[str, str]) -> None: + """V2 server reports V2 version; client should raise if no TLS.""" + # Start a V2 server WITHOUT TLS (so client can connect but gets V2 version) + srv = _make_server(ProtocolVersion.V2) + srv.start(port=V2_TEST_PORT + 10) + time.sleep(0.1) + try: + client = S7CommPlusClient() + with pytest.raises(Exception): + # The server reports V2 but client didn't use TLS + client.connect("127.0.0.1", port=V2_TEST_PORT + 10) + finally: + srv.stop() + + +class TestLegitimation: + """Test password authentication (legitimation).""" + + def test_v2_connect_with_password(self, v2_auth_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True, password="secret123") + assert client.connected + try: + # Should be able to read after authentication + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + finally: + client.disconnect() + + def test_v3_connect_with_password(self, v3_auth_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V3_AUTH_PORT, use_tls=True, password="secret123") + assert client.connected + try: + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + finally: + client.disconnect() + + def test_v2_without_password_on_protected_server(self, v2_auth_server: S7CommPlusServer) -> None: + """Connecting without password to a password-protected server should fail data ops.""" + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True) + try: + # Server should reject data operations without authentication + with pytest.raises(Exception): + client.db_read(1, 0, 4) + finally: + client.disconnect() + + def test_v2_write_with_password(self, v2_auth_server: S7CommPlusServer) -> None: + client = S7CommPlusClient() + client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True, password="secret123") + try: + client.db_write(1, 0, struct.pack(">f", 55.5)) + data = client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 55.5) < 0.1 + finally: + client.disconnect() + + +@pytest.mark.asyncio +class TestAsyncV2TLS: + """Test async client with V2 TLS.""" + + async def test_connect_disconnect(self, v2_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + client = S7CommPlusAsyncClient() + await client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + assert client.connected + assert client.session_id != 0 + assert client.protocol_version == ProtocolVersion.V2 + assert client.tls_active + await client.disconnect() + assert not client.connected + + async def test_read_real(self, v2_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + data = await client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001 + + async def test_write_and_read_back(self, v2_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=V2_TEST_PORT, use_tls=True) + await client.db_write(1, 0, struct.pack(">f", 77.7)) + data = await client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 77.7) < 0.1 + + +@pytest.mark.asyncio +class TestAsyncV3TLS: + """Test async client with V3 TLS.""" + + async def test_connect_disconnect(self, v3_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + client = S7CommPlusAsyncClient() + await client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + assert client.connected + assert client.protocol_version == ProtocolVersion.V3 + assert client.tls_active + await client.disconnect() + + async def test_read_write(self, v3_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=V3_TEST_PORT, use_tls=True) + await client.db_write(2, 0, b"\xde\xad") + data = await client.db_read(2, 0, 2) + assert data == b"\xde\xad" + + +@pytest.mark.asyncio +class TestAsyncLegitimation: + """Test async client with password authentication.""" + + async def test_v2_connect_with_password(self, v2_auth_server: S7CommPlusServer) -> None: + from snap7.s7commplus.async_client import S7CommPlusAsyncClient + + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=V2_AUTH_PORT, use_tls=True, password="secret123") + assert client.connected + data = await client.db_read(1, 0, 4) + value = struct.unpack(">f", data)[0] + assert abs(value - 23.5) < 0.001