Skip to content

Commit 98c0bcb

Browse files
authored
Fix mypy errors in strict mode (#28)
1 parent d7ec0bb commit 98c0bcb

30 files changed

Lines changed: 204 additions & 179 deletions

.pre-commit-config.yaml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ repos:
1616
- id: check-docstring-first
1717
- id: detect-private-key
1818

19-
# - repo: https://github.com/pre-commit/mirrors-mypy
20-
# rev: 'v1.10.1'
21-
# hooks:
22-
# - id: mypy
19+
- repo: https://github.com/pre-commit/mirrors-mypy
20+
rev: 'v1.10.1'
21+
hooks:
22+
- id: mypy
23+
additional_dependencies:
24+
- pytest
2325

2426
- repo: https://github.com/astral-sh/ruff-pre-commit
2527
rev: 'v0.5.1'

src/python_s7comm/async_client.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import datetime
2-
from typing import List, Optional, Tuple
32

43
from .s7comm import AsyncS7Comm, enums
54
from .s7comm.packets.variable_address import VariableAddress
@@ -50,31 +49,32 @@ async def get_cpu_state(self) -> enums.CPUStatus:
5049
async def read_area(self, address: str) -> bytes:
5150
return await self.s7comm.read_area(VariableAddress.from_string(address))
5251

53-
async def write_area(self, address: str, data: bytes) -> int:
54-
return await self.s7comm.write_area(address=VariableAddress.from_string(address), data=data)
52+
async def write_area(self, address: str, data: bytes) -> None:
53+
await self.s7comm.write_area(address=VariableAddress.from_string(address), data=data)
5554

56-
async def get_order_code(self) -> Optional[str]:
55+
async def get_order_code(self) -> str | None:
5756
response = await self.s7comm.read_szl(szl_id=0x0011, szl_index=0x0000)
58-
szl_data = SZLResponseData.parse(response.data)
57+
szl_data = SZLResponseData.parse(response.data.data)
5958
for date_tree in szl_data.szl_data_tree_list:
6059
module_identification = ModuleIdentificationDataTree.parse(date_tree)
6160
if module_identification.index == ModuleIdentificationIndex.MODULE_IDENTIFICATION:
6261
return module_identification.order_number
6362
return None
6463

6564
async def read_szl(self, szl_id: int, szl_index: int = 0x0000) -> SZLResponseData:
66-
return await self.s7comm.read_szl(szl_id=szl_id, szl_index=szl_index)
65+
response_data = await self.s7comm.read_szl(szl_id=szl_id, szl_index=szl_index)
66+
return SZLResponseData.parse(response_data.data.data)
6767

6868
async def read_szl_list(self) -> SZLResponseData:
6969
response_data = await self.s7comm.read_szl(szl_id=0x0000, szl_index=0x0000)
7070
return SZLResponseData.parse(response_data.data.data)
7171

72-
async def read_multi_vars(self, items: List[str]) -> List[bytes]:
72+
async def read_multi_vars(self, items: list[str]) -> list[bytes]:
7373
vars_ = [VariableAddress.from_string(item) for item in items]
7474
response = await self.s7comm.read_multi_vars(items=vars_)
7575
return response.values()
7676

77-
async def write_multi_vars(self, items: List[Tuple[str, bytes]]) -> bool:
77+
async def write_multi_vars(self, items: list[tuple[str, bytes]]) -> bool:
7878
vars_ = [(VariableAddress.from_string(address), data) for address, data in items]
7979
response = await self.s7comm.write_multi_vars(items=vars_)
8080
return response.check_result()
@@ -85,7 +85,7 @@ async def set_plc_system_datetime(self) -> int:
8585
async def delete(self, block_type: str, block_num: int) -> int:
8686
raise NotImplementedError
8787

88-
async def full_upload(self, _type: str, block_num: int) -> Tuple[bytearray, int]:
88+
async def full_upload(self, _type: str, block_num: int) -> tuple[bytearray, int]:
8989
raise NotImplementedError
9090

9191
async def upload(self, block_num: int) -> bytearray:
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .async_client import AsyncS7Comm
22
from .client import S7Comm
3+
from .exceptions import PacketLostError, StalePacketError
34

45

5-
__all__ = ["S7Comm", "AsyncS7Comm"]
6+
__all__ = ["S7Comm", "AsyncS7Comm", "PacketLostError", "StalePacketError"]

src/python_s7comm/s7comm/async_client.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
from typing import cast
55

66
from .base import BaseS7Comm
7-
from .enums import HeaderError, ItemReturnCode, MessageType, SubfunctionCode, UserdataFunction
7+
from .enums import HeaderError, ItemReturnCode, MessageType, SubfunctionCode, UserdataFunction, UserdataLastPDU
88
from .error_messages import ERROR_MESSAGES
99
from .exceptions import PacketLostError, StalePacketError
1010
from .packets import (
1111
RequestPLCStop,
1212
S7AckDataHeader,
1313
S7Packet,
1414
SetupCommunicationRequest,
15+
SZLResponseData,
1516
UserDataContinuationRequest,
1617
UserDataRequest,
1718
UserDataResponse,
@@ -38,6 +39,7 @@ def __init__(
3839
transport: AsyncTransport | None = None,
3940
):
4041
super().__init__(pdu_length=pdu_length)
42+
self.transport: AsyncTransport
4143
if transport is None:
4244
self.transport = AsyncCOTP(tpdu_size=tpdu_size, source_tsap=source_tsap, dest_tsap=dest_tsap)
4345
else:
@@ -132,7 +134,7 @@ async def read_area(self, address: VariableAddress) -> bytes:
132134
result += response_item.data
133135
return result
134136

135-
async def write_area(self, address: VariableAddress, data: bytes) -> int:
137+
async def write_area(self, address: VariableAddress, data: bytes) -> WriteVariableResponse:
136138
"""
137139
Writes data to PLC memory area with automatic splitting into multiple requests
138140
if data exceeds PDU size.
@@ -177,7 +179,7 @@ async def write_multi_vars(self, items: list[tuple[VariableAddress, bytes]]) ->
177179
raise ValueError("Invalid response class")
178180
return response
179181

180-
async def read_szl(self, szl_id: int, szl_index: int) -> S7Packet:
182+
async def read_szl(self, szl_id: int, szl_index: int) -> UserDataResponse:
181183
data = struct.pack("!HH", szl_id, szl_index)
182184
request = UserDataRequest.create(
183185
function_group=UserdataFunction.CPU_FUNCTION,
@@ -186,6 +188,8 @@ async def read_szl(self, szl_id: int, szl_index: int) -> S7Packet:
186188
data=data,
187189
)
188190
response = await self.send(request=request)
191+
if not isinstance(response, UserDataResponse):
192+
raise ValueError("Invalid response class")
189193
return response
190194

191195
async def plc_stop(self) -> S7Packet:

src/python_s7comm/s7comm/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def _create_packet(request: S7Packet, pdu_reference: int) -> bytes:
4747

4848
def _validate_pdu_reference(self, response: S7Packet) -> None:
4949
"""Raises if PDU reference is invalid."""
50+
assert response.header is not None
5051
if response.header.pdu_reference > self._pdu_reference:
5152
raise PacketLostError(f"Expected {self._pdu_reference}, got {response.header.pdu_reference}")
5253
elif response.header.pdu_reference < self._pdu_reference:

src/python_s7comm/s7comm/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
transport: Transport | None = None,
3838
) -> None:
3939
super().__init__(pdu_length=pdu_length)
40+
self.transport: Transport
4041
if transport is None:
4142
self.transport = COTP(tpdu_size=tpdu_size, source_tsap=source_tsap, dest_tsap=dest_tsap)
4243
else:

src/python_s7comm/s7comm/packets/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
1+
from .data_item import DataItem
12
from .headers import S7AckDataHeader, S7Header
23
from .packet import S7Packet
34
from .plc_command import RequestPLCStop
45
from .rw_variable import VariableReadRequest, VariableWriteRequest
56
from .setup_communication import SetupCommunicationParameter, SetupCommunicationRequest
7+
from .szl import SZLResponseData
68
from .user_data import UserDataContinuationRequest, UserDataRequest, UserDataResponse
79
from .variable_address import VariableAddress
810

911

1012
__all__ = [
13+
"DataItem",
1114
"S7AckDataHeader",
1215
"S7Parameter",
1316
"RequestPLCStop",
1417
"SetupCommunicationRequest",
1518
"SetupCommunicationParameter",
1619
"S7Packet",
20+
"SZLResponseData",
1721
"UserDataContinuationRequest",
1822
"UserDataRequest",
1923
"UserDataResponse",

src/python_s7comm/s7comm/packets/error.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,7 @@ def serialize_parameter(self) -> bytes:
1616

1717
def serialize_data(self) -> bytes:
1818
return b""
19+
20+
@classmethod
21+
def parse(cls, packet: bytes) -> "S7Error":
22+
raise NotImplementedError("S7Error is created from header, not parsed from bytes")

src/python_s7comm/s7comm/packets/packet.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, ClassVar, Protocol
1+
from typing import Any, ClassVar, Protocol, Self
22

33
from .headers import S7AckDataHeader, S7Header
44

@@ -7,16 +7,23 @@ class S7Data(Protocol):
77
def serialize(self) -> bytes: ...
88

99

10+
class S7Parameter(Protocol):
11+
def serialize(self) -> bytes: ...
12+
13+
1014
class S7Packet(Protocol):
1115
MESSAGE_TYPE: ClassVar
1216
header: S7Header | None
13-
parameter: Any # Optional[S7Parameter]
14-
data: Any # Optional[S7Data] # Union[DataItem, List[DataItem], List[enums.ItemReturnCode], None]
17+
parameter: Any # S7Parameter | None
18+
data: Any # S7Data | list[S7Data] | None
1519

1620
def serialize_parameter(self) -> bytes: ...
1721

1822
def serialize_data(self) -> bytes: ...
1923

24+
@classmethod
25+
def parse(cls, packet: bytes) -> "S7Packet": ...
26+
2027

2128
class S7Response(S7Packet):
2229
def create_packet(self, pdu_reference: int) -> bytes:

src/python_s7comm/s7comm/packets/parser.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
from .user_data import UserDataResponse
1515

1616

17-
response_code_to_packet: dict[FunctionCode, S7Packet] = {
17+
response_code_to_packet: dict[FunctionCode, type[S7Packet]] = {
1818
FunctionCode.SetupCommunication: SetupCommunicationRequest,
1919
FunctionCode.ReadVariable: ReadVariableResponse,
2020
FunctionCode.WriteVariable: WriteVariableResponse,
2121
}
2222

23-
request_code_to_packet: dict[FunctionCode, S7Packet] = {
23+
request_code_to_packet: dict[FunctionCode, type[S7Packet]] = {
2424
FunctionCode.SetupCommunication: SetupCommunicationRequest,
2525
FunctionCode.ReadVariable: VariableReadRequest,
2626
FunctionCode.WriteVariable: VariableWriteRequest,
@@ -47,12 +47,12 @@ def parse(packet: bytes) -> S7Packet:
4747

4848
# parse parameter
4949
parameter = packet[header.LENGTH :]
50-
response = response_code_to_packet.get(function_code).parse(parameter)
50+
response = response_code_to_packet[function_code].parse(parameter)
5151
elif message_type == MessageType.JobRequest:
5252
header = S7Header.parse(packet)
5353
function_code = struct.unpack_from("!B", packet, header.LENGTH)[0]
5454
parameter = packet[header.LENGTH :]
55-
response = request_code_to_packet.get(function_code).parse(parameter)
55+
response = request_code_to_packet[function_code].parse(parameter)
5656
elif message_type == MessageType.Userdata:
5757
header = S7Header.parse(packet)
5858
response = UserDataResponse.parse(packet[header.LENGTH :])

0 commit comments

Comments
 (0)