From 48ce81559c9e539c21f1c4e4cd7e7ba37384ee73 Mon Sep 17 00:00:00 2001 From: Gijs Molenaar Date: Tue, 24 Mar 2026 18:51:06 +0200 Subject: [PATCH 1/2] Add TLS support to S7CommPlus async client Implement TLS 1.3 for S7CommPlusAsyncClient, bringing feature parity with the sync S7CommPlusConnection for V2 protocol connections: - Add use_tls, tls_cert, tls_key, tls_ca parameters to connect() - Implement _activate_tls() using asyncio start_tls() for in-place transport upgrade - Add authenticate() method with full legitimation support (legacy SHA-1 XOR and new AES-256-CBC modes) - Add V2 post-connection checks (require TLS, enable IntegrityId) - Reset TLS state on disconnect - Fix TLS 1.3 cipher configuration for Python 3.13+ compatibility (use set_ciphersuites instead of set_ciphers for TLS 1.3) The cipher fix also applies to the sync connection.py to prevent the same issue on Python 3.14+. Co-Authored-By: Claude Opus 4.6 --- snap7/s7commplus/async_client.py | 279 ++++++++++++++++++++++++++++++- snap7/s7commplus/connection.py | 6 +- tests/test_async_tls.py | 236 ++++++++++++++++++++++++++ 3 files changed, 515 insertions(+), 6 deletions(-) create mode 100644 tests/test_async_tls.py diff --git a/snap7/s7commplus/async_client.py b/snap7/s7commplus/async_client.py index e6a46fe2..08cca180 100644 --- a/snap7/s7commplus/async_client.py +++ b/snap7/s7commplus/async_client.py @@ -18,6 +18,7 @@ import asyncio import logging +import ssl import struct from typing import Any, Optional @@ -47,7 +48,9 @@ class S7CommPlusAsyncClient: """Async S7CommPlus client for S7-1200/1500 PLCs. - Supports V1 and V2 protocols. V3/TLS planned for future. + Supports all S7CommPlus protocol versions (V1/V2/V3/TLS). The protocol + version is auto-detected from the PLC's CreateObject response during + connection setup. Uses asyncio for all I/O operations and asyncio.Lock for concurrent safety when shared between multiple coroutines. @@ -76,6 +79,11 @@ def __init__(self) -> None: self._integrity_id_write: int = 0 self._with_integrity_id: bool = False + # TLS state + self._tls_active: bool = False + self._oms_secret: Optional[bytes] = None + self._server_session_version: Optional[int] = None + @property def connected(self) -> bool: if self._use_legacy_data and self._legacy_client is not None: @@ -95,15 +103,37 @@ def using_legacy_fallback(self) -> bool: """Whether the client is using legacy S7 protocol for data operations.""" return self._use_legacy_data + @property + def tls_active(self) -> bool: + """Whether TLS is active on the connection.""" + return self._tls_active + + @property + def oms_secret(self) -> Optional[bytes]: + """OMS exporter secret from TLS session (None if TLS not active).""" + return self._oms_secret + async def connect( self, host: str, port: int = 102, rack: int = 0, slot: int = 1, + *, + use_tls: bool = False, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + tls_ca: Optional[str] = None, ) -> None: """Connect to an S7-1200/1500 PLC. + The connection sequence: + 1. COTP connection (same as legacy S7comm) + 2. InitSSL handshake + 3. TLS activation (if use_tls=True, required for V2) + 4. CreateObject to establish S7CommPlus session + 5. Enable IntegrityId tracking (V2+) + If the PLC does not support S7CommPlus data operations, a secondary legacy S7 connection is established transparently for data access. @@ -112,6 +142,10 @@ async def connect( port: TCP port (default 102) rack: PLC rack number slot: PLC slot number + use_tls: Whether to activate TLS after InitSSL. + tls_cert: Path to client TLS certificate (PEM) + tls_key: Path to client private key (PEM) + tls_ca: Path to CA certificate for PLC verification (PEM) """ self._host = host self._port = port @@ -122,18 +156,43 @@ async def connect( self._reader, self._writer = await asyncio.open_connection(host, port) try: - # COTP handshake with S7CommPlus TSAP values + # Step 1: COTP handshake with S7CommPlus TSAP values await self._cotp_connect(S7COMMPLUS_LOCAL_TSAP, S7COMMPLUS_REMOTE_TSAP) - # InitSSL handshake + # Step 2: InitSSL handshake await self._init_ssl() - # S7CommPlus session setup + # Step 3: TLS activation (between InitSSL and CreateObject) + if use_tls: + await self._activate_tls(tls_cert=tls_cert, tls_key=tls_key, tls_ca=tls_ca) + + # Step 4: S7CommPlus session setup await self._create_session() + # Step 5: Version-specific post-setup + if self._protocol_version >= ProtocolVersion.V3: + if not use_tls: + logger.warning( + "PLC reports V3 protocol but TLS is not enabled. Connection may not work without use_tls=True." + ) + elif self._protocol_version == ProtocolVersion.V2: + if not self._tls_active: + from ..error import S7ConnectionError + + raise S7ConnectionError( + "PLC reports V2 protocol but TLS is not active. V2 requires TLS. Use use_tls=True." + ) + # Enable IntegrityId tracking for V2+ + self._with_integrity_id = True + self._integrity_id_read = 0 + self._integrity_id_write = 0 + logger.info("V2 IntegrityId tracking enabled") + self._connected = True logger.info( - f"Async S7CommPlus connected to {host}:{port}, version=V{self._protocol_version}, session={self._session_id}" + f"Async S7CommPlus connected to {host}:{port}, " + f"version=V{self._protocol_version}, session={self._session_id}, " + f"tls={self._tls_active}" ) # Probe S7CommPlus data operations @@ -145,6 +204,213 @@ async def connect( await self.disconnect() raise + async def authenticate(self, password: str, username: str = "") -> None: + """Perform PLC password authentication (legitimation). + + Must be called after connect() and before data operations on + password-protected PLCs. Requires TLS to be active (V2+). + + The method auto-detects legacy vs new legitimation based on + the PLC's firmware version. + + Args: + password: PLC password + username: Username for new-style auth (optional) + + Raises: + S7ConnectionError: If not connected, TLS not active, or auth fails + """ + if not self._connected: + from ..error import S7ConnectionError + + raise S7ConnectionError("Not connected") + + if not self._tls_active or self._oms_secret is None: + from ..error import S7ConnectionError + + raise S7ConnectionError("Legitimation requires TLS. Connect with use_tls=True.") + + # Step 1: Get challenge from PLC + challenge = await self._get_legitimation_challenge() + logger.info(f"Received legitimation challenge ({len(challenge)} bytes)") + + # Step 2: Build response (auto-detect legacy vs new) + from .legitimation import build_legacy_response, build_new_response + + if username: + response_data = build_new_response(password, challenge, self._oms_secret, username) + await self._send_legitimation_new(response_data) + else: + try: + response_data = build_new_response(password, challenge, self._oms_secret, "") + await self._send_legitimation_new(response_data) + except NotImplementedError: + response_data = build_legacy_response(password, challenge) + await self._send_legitimation_legacy(response_data) + + logger.info("PLC legitimation completed successfully") + + async def _activate_tls( + self, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + tls_ca: Optional[str] = None, + ) -> None: + """Activate TLS 1.3 over the COTP connection. + + Called after InitSSL and before CreateObject. Uses asyncio's + start_tls() to upgrade the existing connection to TLS. + + Args: + tls_cert: Path to client TLS certificate (PEM) + tls_key: Path to client private key (PEM) + tls_ca: Path to CA certificate for PLC verification (PEM) + """ + if self._writer is None: + from ..error import S7ConnectionError + + raise S7ConnectionError("Cannot activate TLS: not connected") + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.minimum_version = ssl.TLSVersion.TLSv1_3 + + # TLS 1.3 ciphersuites are configured differently from TLS 1.2 + if hasattr(ctx, "set_ciphersuites"): + ctx.set_ciphersuites("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") + # If set_ciphersuites not available, TLS 1.3 uses its mandatory defaults + + if tls_cert and tls_key: + ctx.load_cert_chain(tls_cert, tls_key) + + if tls_ca: + ctx.load_verify_locations(tls_ca) + else: + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + # Upgrade existing transport to TLS using asyncio start_tls + transport = self._writer.transport + loop = asyncio.get_event_loop() + new_transport = await loop.start_tls( + transport, + transport.get_protocol(), + ctx, + server_hostname=self._host, + ) + + # Update reader/writer to use the TLS transport + self._writer._transport = new_transport + self._tls_active = True + + # Extract OMS exporter secret for legitimation key derivation + if new_transport is None: + from ..error import S7ConnectionError + + raise S7ConnectionError("TLS handshake failed: no transport returned") + + ssl_object = new_transport.get_extra_info("ssl_object") + if ssl_object is not None: + try: + self._oms_secret = ssl_object.export_keying_material("EXPERIMENTAL_OMS", 32, None) + logger.debug("OMS exporter secret extracted from TLS session") + except (AttributeError, ssl.SSLError) as e: + logger.warning(f"Could not extract OMS exporter secret: {e}") + self._oms_secret = None + + logger.info("TLS 1.3 activated on async COTP connection") + + async def _get_legitimation_challenge(self) -> bytes: + """Request legitimation challenge from PLC. + + Sends GetVarSubStreamed with address ServerSessionRequest (303). + + Returns: + Challenge bytes from PLC + """ + from .protocol import LegitimationId, DataType as DT + + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(LegitimationId.SERVER_SESSION_REQUEST) + payload += struct.pack(">I", 0) + + resp_payload = await self._send_request(FunctionCode.GET_VAR_SUBSTREAMED, bytes(payload)) + + offset = 0 + return_value, consumed = decode_uint64_vlq(resp_payload, offset) + offset += consumed + + if return_value != 0: + from ..error import S7ConnectionError + + raise S7ConnectionError(f"GetVarSubStreamed for challenge failed: return_value={return_value}") + + if offset + 2 > len(resp_payload): + from ..error import S7ConnectionError + + raise S7ConnectionError("Challenge response too short") + + _flags = resp_payload[offset] + datatype = resp_payload[offset + 1] + offset += 2 + + if datatype == DT.BLOB: + length, consumed = decode_uint32_vlq(resp_payload, offset) + offset += consumed + return bytes(resp_payload[offset : offset + length]) + else: + count, consumed = decode_uint32_vlq(resp_payload, offset) + offset += consumed + return bytes(resp_payload[offset : offset + count]) + + async def _send_legitimation_new(self, encrypted_response: bytes) -> None: + """Send new-style legitimation response (AES-256-CBC encrypted).""" + from .protocol import LegitimationId, DataType as DT + + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(LegitimationId.LEGITIMATE) + payload += bytes([0x00, DT.BLOB]) + payload += encode_uint32_vlq(len(encrypted_response)) + payload += encrypted_response + payload += struct.pack(">I", 0) + + resp_payload = await self._send_request(FunctionCode.SET_VARIABLE, bytes(payload)) + + if len(resp_payload) >= 1: + return_value, _ = decode_uint64_vlq(resp_payload, 0) + if return_value < 0: + from ..error import S7ConnectionError + + raise S7ConnectionError(f"Legitimation rejected by PLC: return_value={return_value}") + logger.debug(f"New legitimation return_value={return_value}") + + async def _send_legitimation_legacy(self, response: bytes) -> None: + """Send legacy legitimation response (SHA-1 XOR).""" + from .protocol import LegitimationId, DataType as DT + + payload = bytearray() + payload += struct.pack(">I", self._session_id) + payload += encode_uint32_vlq(1) + payload += encode_uint32_vlq(LegitimationId.SERVER_SESSION_RESPONSE) + payload += bytes([0x10, DT.USINT]) # flags=0x10 (array) + payload += encode_uint32_vlq(len(response)) + payload += response + payload += struct.pack(">I", 0) + + resp_payload = await self._send_request(FunctionCode.SET_VARIABLE, bytes(payload)) + + if len(resp_payload) >= 1: + return_value, _ = decode_uint64_vlq(resp_payload, 0) + if return_value < 0: + from ..error import S7ConnectionError + + raise S7ConnectionError(f"Legacy legitimation rejected by PLC: return_value={return_value}") + logger.debug(f"Legacy legitimation return_value={return_value}") + async def _probe_s7commplus_data(self) -> bool: """Test if the PLC supports S7CommPlus data operations.""" try: @@ -198,6 +464,9 @@ async def disconnect(self) -> None: self._with_integrity_id = False self._integrity_id_read = 0 self._integrity_id_write = 0 + self._tls_active = False + self._oms_secret = None + self._server_session_version = None if self._writer: try: diff --git a/snap7/s7commplus/connection.py b/snap7/s7commplus/connection.py index a60b44a0..1b0a013e 100644 --- a/snap7/s7commplus/connection.py +++ b/snap7/s7commplus/connection.py @@ -1013,7 +1013,11 @@ def _setup_ssl_context( """ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.minimum_version = ssl.TLSVersion.TLSv1_3 - ctx.set_ciphers("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") + + # TLS 1.3 ciphersuites are configured differently from TLS 1.2 + if hasattr(ctx, "set_ciphersuites"): + ctx.set_ciphersuites("TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256") + # If set_ciphersuites not available, TLS 1.3 uses its mandatory defaults if cert_path and key_path: ctx.load_cert_chain(cert_path, key_path) diff --git a/tests/test_async_tls.py b/tests/test_async_tls.py new file mode 100644 index 00000000..39393928 --- /dev/null +++ b/tests/test_async_tls.py @@ -0,0 +1,236 @@ +"""Tests for S7CommPlus async client TLS support.""" + +import struct +import tempfile +import time +from collections.abc import Generator + +import pytest + +from snap7.error import S7ConnectionError +from snap7.s7commplus.async_client import S7CommPlusAsyncClient +from snap7.s7commplus.server import S7CommPlusServer +from snap7.s7commplus.protocol import ProtocolVersion + +TEST_PORT_V2 = 11130 +TEST_PORT_V2_TLS = 11131 + + +def _generate_self_signed_cert() -> tuple[str, str]: + """Generate a self-signed certificate and key for testing. + + Returns: + Tuple of (cert_path, key_path) + """ + try: + from cryptography import x509 + from cryptography.x509.oid import NameOID + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + import datetime + except ImportError: + pytest.skip("cryptography package required for TLS tests") + + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), + ]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)) + .add_extension( + x509.SubjectAlternativeName([x509.IPAddress(ipaddress.IPv4Address("127.0.0.1"))]), + critical=False, + ) + .sign(key, hashes.SHA256()) + ) + + cert_file = tempfile.NamedTemporaryFile(suffix=".pem", delete=False) + cert_file.write(cert.public_bytes(serialization.Encoding.PEM)) + cert_file.close() + + key_file = tempfile.NamedTemporaryFile(suffix=".pem", delete=False) + key_file.write(key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.TraditionalOpenSSL, + serialization.NoEncryption(), + )) + key_file.close() + + return cert_file.name, key_file.name + + +import ipaddress # noqa: E402 + + +class TestAsyncClientTLSPreconditions: + """Test authenticate() and TLS precondition checks.""" + + @pytest.mark.asyncio + async def test_authenticate_not_connected(self) -> None: + client = S7CommPlusAsyncClient() + with pytest.raises(S7ConnectionError, match="Not connected"): + await client.authenticate("password") + + @pytest.mark.asyncio + async def test_authenticate_no_tls(self) -> None: + client = S7CommPlusAsyncClient() + client._connected = True + client._tls_active = False + client._oms_secret = None + with pytest.raises(S7ConnectionError, match="requires TLS"): + await client.authenticate("password") + + def test_tls_active_default_false(self) -> None: + client = S7CommPlusAsyncClient() + assert client.tls_active is False + assert client.oms_secret is None + + @pytest.mark.asyncio + async def test_disconnect_resets_tls_state(self) -> None: + client = S7CommPlusAsyncClient() + client._tls_active = True + client._oms_secret = b"\x00" * 32 + await client.disconnect() + assert client.tls_active is False + assert client.oms_secret is None + + +class TestAsyncClientConnectTLSParams: + """Test that connect() accepts TLS parameters.""" + + @pytest.mark.asyncio + async def test_connect_signature_accepts_tls_params(self) -> None: + """Verify the connect method signature includes TLS params.""" + import inspect + + sig = inspect.signature(S7CommPlusAsyncClient.connect) + params = list(sig.parameters.keys()) + assert "use_tls" in params + assert "tls_cert" in params + assert "tls_key" in params + assert "tls_ca" in params + + +@pytest.fixture() +def v2_server() -> Generator[S7CommPlusServer, None, None]: + """Start a V2 server without TLS for protocol negotiation tests.""" + srv = S7CommPlusServer(protocol_version=ProtocolVersion.V2) + srv.register_raw_db(1, bytearray(256)) + db1 = srv.get_db(1) + assert db1 is not None + struct.pack_into(">f", db1.data, 0, 99.9) + srv.start(port=TEST_PORT_V2) + time.sleep(0.1) + yield srv + srv.stop() + + +class TestAsyncClientV2WithoutTLS: + """Test that V2 connection without TLS raises appropriate error.""" + + @pytest.mark.asyncio + async def test_v2_without_tls_raises(self, v2_server: S7CommPlusServer) -> None: + """V2 protocol requires TLS — connecting without should raise.""" + client = S7CommPlusAsyncClient() + with pytest.raises(S7ConnectionError, match="V2.*requires TLS"): + await client.connect("127.0.0.1", port=TEST_PORT_V2) + + +try: + import cryptography # noqa: F401 + + _has_cryptography = True +except ImportError: + _has_cryptography = False + + +@pytest.mark.skipif(not _has_cryptography, reason="requires cryptography package") +class TestAsyncClientV2WithTLS: + """Test async client with V2 + TLS against emulated server.""" + + @pytest.fixture() + def tls_server(self) -> Generator[tuple[S7CommPlusServer, str, str], None, None]: + """Start a V2 TLS server with self-signed cert.""" + cert_path, key_path = _generate_self_signed_cert() + + srv = S7CommPlusServer(protocol_version=ProtocolVersion.V2) + srv.register_raw_db(1, bytearray(256)) + + db1 = srv.get_db(1) + assert db1 is not None + struct.pack_into(">f", db1.data, 0, 42.0) + + srv.start(port=TEST_PORT_V2_TLS, use_tls=True, tls_cert=cert_path, tls_key=key_path) + time.sleep(0.1) + + yield srv, cert_path, key_path + + srv.stop() + + import os + + os.unlink(cert_path) + os.unlink(key_path) + + @pytest.mark.asyncio + async def test_connect_with_tls(self, tls_server: tuple[S7CommPlusServer, str, str]) -> None: + """Connect to V2 server with TLS enabled.""" + srv, cert_path, key_path = tls_server + + client = S7CommPlusAsyncClient() + await client.connect("127.0.0.1", port=TEST_PORT_V2_TLS, use_tls=True, tls_ca=cert_path) + + try: + assert client.connected + assert client.tls_active + assert client.protocol_version == ProtocolVersion.V2 + finally: + await client.disconnect() + + @pytest.mark.asyncio + async def test_integrity_id_tracking_enabled(self, tls_server: tuple[S7CommPlusServer, str, str]) -> None: + """V2 connection should enable IntegrityId tracking.""" + srv, cert_path, key_path = tls_server + + client = S7CommPlusAsyncClient() + await client.connect("127.0.0.1", port=TEST_PORT_V2_TLS, use_tls=True, tls_ca=cert_path) + + try: + assert client._with_integrity_id is True + # Counters may already be non-zero from the probe request + assert client._integrity_id_read >= 0 + assert client._integrity_id_write >= 0 + finally: + await client.disconnect() + + @pytest.mark.asyncio + async def test_protocol_version_is_v2(self, tls_server: tuple[S7CommPlusServer, str, str]) -> None: + """V2 server should report protocol version 2.""" + srv, cert_path, key_path = tls_server + + client = S7CommPlusAsyncClient() + await client.connect("127.0.0.1", port=TEST_PORT_V2_TLS, use_tls=True, tls_ca=cert_path) + + try: + assert client.protocol_version == ProtocolVersion.V2 + finally: + await client.disconnect() + + @pytest.mark.asyncio + async def test_context_manager_tls(self, tls_server: tuple[S7CommPlusServer, str, str]) -> None: + """TLS connection via context manager.""" + srv, cert_path, key_path = tls_server + + async with S7CommPlusAsyncClient() as client: + await client.connect("127.0.0.1", port=TEST_PORT_V2_TLS, use_tls=True, tls_ca=cert_path) + assert client.connected + assert client.tls_active + + assert not client.connected + assert not client.tls_active From 3a990a83b00a053cdcb01c38867718a783d71fcf Mon Sep 17 00:00:00 2001 From: Gijs Molenaar Date: Tue, 24 Mar 2026 19:01:47 +0200 Subject: [PATCH 2/2] Fix ruff formatting Co-Authored-By: Claude Opus 4.6 --- snap7/s7commplus/async_client.py | 4 +--- tests/test_async_tls.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/snap7/s7commplus/async_client.py b/snap7/s7commplus/async_client.py index 08cca180..4ace2241 100644 --- a/snap7/s7commplus/async_client.py +++ b/snap7/s7commplus/async_client.py @@ -179,9 +179,7 @@ async def connect( if not self._tls_active: from ..error import S7ConnectionError - raise S7ConnectionError( - "PLC reports V2 protocol but TLS is not active. V2 requires TLS. Use use_tls=True." - ) + raise S7ConnectionError("PLC reports V2 protocol but TLS is not active. V2 requires TLS. Use use_tls=True.") # Enable IntegrityId tracking for V2+ self._with_integrity_id = True self._integrity_id_read = 0 diff --git a/tests/test_async_tls.py b/tests/test_async_tls.py index 39393928..8c1bd007 100644 --- a/tests/test_async_tls.py +++ b/tests/test_async_tls.py @@ -32,9 +32,11 @@ def _generate_self_signed_cert() -> tuple[str, str]: pytest.skip("cryptography package required for TLS tests") key = rsa.generate_private_key(public_exponent=65537, key_size=2048) - subject = issuer = x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), - ]) + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), + ] + ) cert = ( x509.CertificateBuilder() .subject_name(subject) @@ -55,11 +57,13 @@ def _generate_self_signed_cert() -> tuple[str, str]: cert_file.close() key_file = tempfile.NamedTemporaryFile(suffix=".pem", delete=False) - key_file.write(key.private_bytes( - serialization.Encoding.PEM, - serialization.PrivateFormat.TraditionalOpenSSL, - serialization.NoEncryption(), - )) + key_file.write( + key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.TraditionalOpenSSL, + serialization.NoEncryption(), + ) + ) key_file.close() return cert_file.name, key_file.name