Skip to content

Commit 60047da

Browse files
gijzelaerrclaude
andauthored
Accept memoryview in setter and getter type annotations (gijzelaerr#647)
The setter functions already worked with memoryview at runtime (using direct struct.pack() slice assignment), but the type annotations only accepted bytearray. This caused mypy errors when passing memoryview objects from ctypes buffers. - Add Buffer type alias (Union[bytearray, memoryview]) to setters and getters - Update all function signatures to accept Buffer - Fix .decode() calls in getters to use bytes() for memoryview compat - Add 17 memoryview compatibility tests Credit: LuTiFlekSSer for identifying the memoryview compatibility issue. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 67713e2 commit 60047da

File tree

4 files changed

+203
-58
lines changed

4 files changed

+203
-58
lines changed

snap7/util/db.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,9 @@ def get_value(self, byte_index: Union[str, int], type_: str) -> ValueType:
635635
return type_to_func[type_](bytearray_, byte_index)
636636
raise ValueError
637637

638-
def set_value(self, byte_index: Union[str, int], type_: str, value: Union[bool, str, float]) -> Optional[bytearray]:
638+
def set_value(
639+
self, byte_index: Union[str, int], type_: str, value: Union[bool, str, float]
640+
) -> Optional[Union[bytearray, memoryview]]:
639641
"""Sets the value for a specific type in the specified byte index.
640642
641643
Args:

snap7/util/getters.py

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import struct
22
from datetime import timedelta, datetime, date
3-
from typing import NoReturn
3+
from typing import NoReturn, Union
44
from logging import getLogger
55

6+
#: Buffer types accepted by getter functions.
7+
#: Both :class:`bytearray` and :class:`memoryview` are supported.
8+
Buffer = Union[bytearray, memoryview]
9+
610
logger = getLogger(__name__)
711

812

9-
def get_bool(bytearray_: bytearray, byte_index: int, bool_index: int) -> bool:
13+
def get_bool(bytearray_: Buffer, byte_index: int, bool_index: int) -> bool:
1014
"""Get the boolean value from location in bytearray
1115
1216
Args:
@@ -28,7 +32,7 @@ def get_bool(bytearray_: bytearray, byte_index: int, bool_index: int) -> bool:
2832
return current_value == index_value
2933

3034

31-
def get_byte(bytearray_: bytearray, byte_index: int) -> bytes:
35+
def get_byte(bytearray_: Buffer, byte_index: int) -> bytes:
3236
"""Get byte value from bytearray.
3337
3438
Notes:
@@ -48,7 +52,7 @@ def get_byte(bytearray_: bytearray, byte_index: int) -> bytes:
4852
return value
4953

5054

51-
def get_word(bytearray_: bytearray, byte_index: int) -> bytearray:
55+
def get_word(bytearray_: Buffer, byte_index: int) -> bytearray:
5256
"""Get word value from bytearray.
5357
5458
Notes:
@@ -73,7 +77,7 @@ def get_word(bytearray_: bytearray, byte_index: int) -> bytearray:
7377
return value
7478

7579

76-
def get_int(bytearray_: bytearray, byte_index: int) -> int:
80+
def get_int(bytearray_: Buffer, byte_index: int) -> int:
7781
"""Get int value from bytearray.
7882
7983
Notes:
@@ -98,7 +102,7 @@ def get_int(bytearray_: bytearray, byte_index: int) -> int:
98102
return value
99103

100104

101-
def get_uint(bytearray_: bytearray, byte_index: int) -> int:
105+
def get_uint(bytearray_: Buffer, byte_index: int) -> int:
102106
"""Get unsigned int value from bytearray.
103107
104108
Notes:
@@ -121,7 +125,7 @@ def get_uint(bytearray_: bytearray, byte_index: int) -> int:
121125
return int(get_word(bytearray_, byte_index))
122126

123127

124-
def get_real(bytearray_: bytearray, byte_index: int) -> float:
128+
def get_real(bytearray_: Buffer, byte_index: int) -> float:
125129
"""Get real value.
126130
127131
Notes:
@@ -145,7 +149,7 @@ def get_real(bytearray_: bytearray, byte_index: int) -> float:
145149
return real
146150

147151

148-
def get_fstring(bytearray_: bytearray, byte_index: int, max_length: int, remove_padding: bool = True) -> str:
152+
def get_fstring(bytearray_: Buffer, byte_index: int, max_length: int, remove_padding: bool = True) -> str:
149153
"""Parse space-padded fixed-length string from bytearray
150154
151155
Notes:
@@ -176,7 +180,7 @@ def get_fstring(bytearray_: bytearray, byte_index: int, max_length: int, remove_
176180
return string
177181

178182

179-
def get_string(bytearray_: bytearray, byte_index: int) -> str:
183+
def get_string(bytearray_: Buffer, byte_index: int) -> str:
180184
"""Parse string from bytearray
181185
182186
Notes:
@@ -210,7 +214,7 @@ def get_string(bytearray_: bytearray, byte_index: int) -> str:
210214
return "".join(data)
211215

212216

213-
def get_dword(bytearray_: bytearray, byte_index: int) -> int:
217+
def get_dword(bytearray_: Buffer, byte_index: int) -> int:
214218
"""Gets the dword from the buffer.
215219
216220
Notes:
@@ -235,7 +239,7 @@ def get_dword(bytearray_: bytearray, byte_index: int) -> int:
235239
return dword
236240

237241

238-
def get_dint(bytearray_: bytearray, byte_index: int) -> int:
242+
def get_dint(bytearray_: Buffer, byte_index: int) -> int:
239243
"""Get dint value from bytearray.
240244
241245
Notes:
@@ -262,7 +266,7 @@ def get_dint(bytearray_: bytearray, byte_index: int) -> int:
262266
return dint
263267

264268

265-
def get_udint(bytearray_: bytearray, byte_index: int) -> int:
269+
def get_udint(bytearray_: Buffer, byte_index: int) -> int:
266270
"""Get unsigned dint value from bytearray.
267271
268272
Notes:
@@ -289,7 +293,7 @@ def get_udint(bytearray_: bytearray, byte_index: int) -> int:
289293
return dint
290294

291295

292-
def get_s5time(bytearray_: bytearray, byte_index: int) -> str:
296+
def get_s5time(bytearray_: Buffer, byte_index: int) -> str:
293297
micro_to_milli = 1000
294298
data_bytearray = bytearray_[byte_index : byte_index + 2]
295299
s5time_data_int_like = list(data_bytearray.hex())
@@ -315,7 +319,7 @@ def get_s5time(bytearray_: bytearray, byte_index: int) -> str:
315319
return "".join(str(s5time))
316320

317321

318-
def get_dt(bytearray_: bytearray, byte_index: int) -> str:
322+
def get_dt(bytearray_: Buffer, byte_index: int) -> str:
319323
"""Get DATE_AND_TIME Value from bytearray as ISO 8601 formatted Date String
320324
Notes:
321325
Datatype `DATE_AND_TIME` consists in 8 bytes in the PLC.
@@ -331,7 +335,7 @@ def get_dt(bytearray_: bytearray, byte_index: int) -> str:
331335
return get_date_time_object(bytearray_, byte_index).isoformat(timespec="microseconds")
332336

333337

334-
def get_date_time_object(bytearray_: bytearray, byte_index: int) -> datetime:
338+
def get_date_time_object(bytearray_: Buffer, byte_index: int) -> datetime:
335339
"""Get DATE_AND_TIME Value from bytearray as python datetime object
336340
Notes:
337341
Datatype `DATE_AND_TIME` consists in 8 bytes in the PLC.
@@ -364,7 +368,7 @@ def bcd_to_byte(byte: int) -> int:
364368
return datetime(year, month, day, hour, min_, sec, microsec)
365369

366370

367-
def get_time(bytearray_: bytearray, byte_index: int) -> str:
371+
def get_time(bytearray_: Buffer, byte_index: int) -> str:
368372
"""Get time value from bytearray.
369373
370374
Notes:
@@ -408,7 +412,7 @@ def get_time(bytearray_: bytearray, byte_index: int) -> str:
408412
return time_str
409413

410414

411-
def get_usint(bytearray_: bytearray, byte_index: int) -> int:
415+
def get_usint(bytearray_: Buffer, byte_index: int) -> int:
412416
"""Get the unsigned small int from the bytearray
413417
414418
Notes:
@@ -434,7 +438,7 @@ def get_usint(bytearray_: bytearray, byte_index: int) -> int:
434438
return value
435439

436440

437-
def get_sint(bytearray_: bytearray, byte_index: int) -> int:
441+
def get_sint(bytearray_: Buffer, byte_index: int) -> int:
438442
"""Get the small int
439443
440444
Notes:
@@ -460,7 +464,7 @@ def get_sint(bytearray_: bytearray, byte_index: int) -> int:
460464
return value
461465

462466

463-
def get_lint(bytearray_: bytearray, byte_index: int) -> int:
467+
def get_lint(bytearray_: Buffer, byte_index: int) -> int:
464468
"""Get the long int
465469
466470
THIS VALUE IS NEITHER TESTED NOR VERIFIED BY A REAL PLC AT THE MOMENT
@@ -490,7 +494,7 @@ def get_lint(bytearray_: bytearray, byte_index: int) -> int:
490494
return int(lint)
491495

492496

493-
def get_lreal(bytearray_: bytearray, byte_index: int) -> float:
497+
def get_lreal(bytearray_: Buffer, byte_index: int) -> float:
494498
"""Get the long real
495499
496500
Datatype `lreal` (long real) consists in 8 bytes in the PLC.
@@ -515,7 +519,7 @@ def get_lreal(bytearray_: bytearray, byte_index: int) -> float:
515519
return float(struct.unpack_from(">d", bytearray_, offset=byte_index)[0])
516520

517521

518-
def get_lword(bytearray_: bytearray, byte_index: int) -> int:
522+
def get_lword(bytearray_: Buffer, byte_index: int) -> int:
519523
"""Get the long word
520524
521525
Notes:
@@ -540,7 +544,7 @@ def get_lword(bytearray_: bytearray, byte_index: int) -> int:
540544
return lword
541545

542546

543-
def get_ulint(bytearray_: bytearray, byte_index: int) -> int:
547+
def get_ulint(bytearray_: Buffer, byte_index: int) -> int:
544548
"""Get ulint value from bytearray.
545549
546550
Notes:
@@ -565,7 +569,7 @@ def get_ulint(bytearray_: bytearray, byte_index: int) -> int:
565569
return lint
566570

567571

568-
def get_tod(bytearray_: bytearray, byte_index: int) -> timedelta:
572+
def get_tod(bytearray_: Buffer, byte_index: int) -> timedelta:
569573
len_bytearray_ = len(bytearray_)
570574
byte_range = byte_index + 4
571575
if len_bytearray_ < byte_range:
@@ -576,7 +580,7 @@ def get_tod(bytearray_: bytearray, byte_index: int) -> timedelta:
576580
return time_val
577581

578582

579-
def get_date(bytearray_: bytearray, byte_index: int = 0) -> date:
583+
def get_date(bytearray_: Buffer, byte_index: int = 0) -> date:
580584
len_bytearray_ = len(bytearray_)
581585
byte_range = byte_index + 2
582586
if len_bytearray_ < byte_range:
@@ -587,7 +591,7 @@ def get_date(bytearray_: bytearray, byte_index: int = 0) -> date:
587591
return date_val
588592

589593

590-
def get_ltime(bytearray_: bytearray, byte_index: int) -> timedelta:
594+
def get_ltime(bytearray_: Buffer, byte_index: int) -> timedelta:
591595
"""Get LTIME value from bytearray.
592596
593597
Notes:
@@ -612,7 +616,7 @@ def get_ltime(bytearray_: bytearray, byte_index: int) -> timedelta:
612616
return timedelta(microseconds=nanoseconds // 1000)
613617

614618

615-
def get_ltod(bytearray_: bytearray, byte_index: int) -> timedelta:
619+
def get_ltod(bytearray_: Buffer, byte_index: int) -> timedelta:
616620
"""Get LTOD (Long Time of Day) value from bytearray.
617621
618622
Notes:
@@ -635,7 +639,7 @@ def get_ltod(bytearray_: bytearray, byte_index: int) -> timedelta:
635639
return result
636640

637641

638-
def get_ldt(bytearray_: bytearray, byte_index: int) -> datetime:
642+
def get_ldt(bytearray_: Buffer, byte_index: int) -> datetime:
639643
"""Get LDT (Long Date and Time) value from bytearray.
640644
641645
Notes:
@@ -655,7 +659,7 @@ def get_ldt(bytearray_: bytearray, byte_index: int) -> datetime:
655659
return epoch + timedelta(microseconds=nanoseconds // 1000)
656660

657661

658-
def get_dtl(bytearray_: bytearray, byte_index: int) -> datetime:
662+
def get_dtl(bytearray_: Buffer, byte_index: int) -> datetime:
659663
time_to_datetime = datetime(
660664
year=int.from_bytes(bytearray_[byte_index : byte_index + 2], byteorder="big"),
661665
month=int(bytearray_[byte_index + 2]),
@@ -670,7 +674,7 @@ def get_dtl(bytearray_: bytearray, byte_index: int) -> datetime:
670674
return time_to_datetime
671675

672676

673-
def get_char(bytearray_: bytearray, byte_index: int) -> str:
677+
def get_char(bytearray_: Buffer, byte_index: int) -> str:
674678
"""Get char value from bytearray.
675679
676680
Notes:
@@ -694,7 +698,7 @@ def get_char(bytearray_: bytearray, byte_index: int) -> str:
694698
return char
695699

696700

697-
def get_wchar(bytearray_: bytearray, byte_index: int) -> str:
701+
def get_wchar(bytearray_: Buffer, byte_index: int) -> str:
698702
"""Get wchar value from bytearray.
699703
700704
Datatype `wchar` in the PLC is represented in 2 bytes. It has to be in utf-16-be format.
@@ -715,10 +719,10 @@ def get_wchar(bytearray_: bytearray, byte_index: int) -> str:
715719
"""
716720
if bytearray_[byte_index] == 0:
717721
return chr(bytearray_[byte_index + 1])
718-
return bytearray_[byte_index : byte_index + 2].decode("utf-16-be")
722+
return bytes(bytearray_[byte_index : byte_index + 2]).decode("utf-16-be")
719723

720724

721-
def get_wstring(bytearray_: bytearray, byte_index: int) -> str:
725+
def get_wstring(bytearray_: Buffer, byte_index: int) -> str:
722726
"""Parse wstring from bytearray
723727
724728
Notes:
@@ -759,8 +763,8 @@ def get_wstring(bytearray_: bytearray, byte_index: int) -> str:
759763
f"expected or is larger than 16382. Bytearray doesn't seem to be a valid string."
760764
)
761765

762-
return bytearray_[wstring_start : wstring_start + wstr_symbols_amount].decode("utf-16-be")
766+
return bytes(bytearray_[wstring_start : wstring_start + wstr_symbols_amount]).decode("utf-16-be")
763767

764768

765-
def get_array(bytearray_: bytearray, byte_index: int) -> NoReturn:
769+
def get_array(bytearray_: Buffer, byte_index: int) -> NoReturn:
766770
raise NotImplementedError

0 commit comments

Comments
 (0)