Skip to content

Commit 52c04a2

Browse files
committed
refactor: use a single configurable timeout value
Replace magic default values to use a single per-client default timeout value. This allows higher layers (e.g. smpmgr) to specify a timeout value when necessary and otherwise rely on the default set when creating the client.
1 parent 74ac9d3 commit 52c04a2

2 files changed

Lines changed: 41 additions & 25 deletions

File tree

smpclient/__init__.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class SMPClient:
6565
Args:
6666
transport: the `SMPTransport` to use
6767
address: the address of the SMP server, see `smpclient.transport` for details
68+
timeout: the default timeout in seconds for SMP requests
6869
6970
Example:
7071
@@ -88,25 +89,30 @@ async def main():
8889
```
8990
"""
9091

91-
def __init__(self, transport: SMPTransport, address: str): # noqa: DOC301
92+
def __init__(
93+
self, transport: SMPTransport, address: str, timeout: float = 60.0
94+
): # noqa: DOC301
9295
self._transport: Final = transport
9396
self._address: Final = address
97+
self._timeout = timeout
9498

95-
async def connect(self, timeout_s: float = 5.0) -> None:
99+
async def connect(self, timeout_s: float | None = None) -> None:
96100
"""Connect to the SMP server.
97101
98102
Args:
99103
timeout_s: the timeout for the connection attempt in seconds
100104
"""
105+
timeout_s = timeout_s if timeout_s is not None else self._timeout
106+
101107
await self._transport.connect(self._address, timeout_s)
102-
await self._initialize()
108+
await self._initialize(timeout_s)
103109

104110
async def disconnect(self) -> None:
105111
"""Disconnect from the SMP server."""
106112
await self._transport.disconnect()
107113

108114
async def request(
109-
self, request: SMPRequest[TRep, TEr1, TEr2], timeout_s: float = 120.000
115+
self, request: SMPRequest[TRep, TEr1, TEr2], timeout_s: float | None = None
110116
) -> TRep | TEr1 | TEr2:
111117
"""Make an `SMPRequest` to the SMP server and return the Response or Error.
112118
@@ -161,6 +167,7 @@ async def request(
161167
```
162168
163169
"""
170+
timeout_s = timeout_s if timeout_s is not None else self._timeout
164171

165172
try:
166173
async with timeout(timeout_s):
@@ -200,8 +207,8 @@ async def upload(
200207
image: bytes,
201208
slot: int = 0,
202209
upgrade: bool = False,
203-
first_timeout_s: float = 40.000,
204-
subsequent_timeout_s: float = 2.500,
210+
first_timeout_s: float = 40.0,
211+
subsequent_timeout_s: float | None = None,
205212
use_sha: bool = True,
206213
) -> AsyncIterator[int]:
207214
"""Iteratively upload an `image` to `slot`, yielding the offset.
@@ -216,6 +223,8 @@ async def upload(
216223
[boot_write_img_confirmed()](https://docs.zephyrproject.org/apidoc/latest/group__mcuboot__api.html#ga95ccc9e1c7460fec16b9ce9ac8ad7a72)
217224
for this purpose.
218225
first_timeout_s: the timeout for the first `ImageUploadWrite` request
226+
which might take longer than subsequent requests (e.g. if a big
227+
chunk of flash memory has to be erased upfront).
219228
subsequent_timeout_s: the timeout for subsequent `ImageUploadWrite` requests
220229
use_sha: `True` to include the SHA256 hash of the image in the first
221230
packet.
@@ -230,6 +239,9 @@ async def upload(
230239
Raises:
231240
SMPUploadError: if the upload routine fails
232241
"""
242+
subsequent_timeout_s = (
243+
subsequent_timeout_s if subsequent_timeout_s is not None else self._timeout
244+
)
233245

234246
response = await self.request(
235247
self._maximize_image_upload_write_packet(
@@ -292,7 +304,7 @@ async def upload_file(
292304
self,
293305
file_data: bytes,
294306
file_path: str,
295-
timeout_s: float = 2.500,
307+
timeout_s: float | None = None,
296308
) -> AsyncIterator[int]:
297309
"""Iteratively upload a `file_data` to `file_path`, yielding the offset.
298310
@@ -307,6 +319,8 @@ async def upload_file(
307319
Raises:
308320
SMPUploadError: if the upload routine fails
309321
"""
322+
timeout_s = timeout_s if timeout_s is not None else self._timeout
323+
310324
response = await self.request(
311325
self._maximize_file_upload_packet(
312326
FileUpload(name=file_path, off=0, data=b"", len=len(file_data)),
@@ -344,7 +358,7 @@ async def upload_file(
344358
async def download_file(
345359
self,
346360
file_path: str,
347-
timeout_s: float = 2.500,
361+
timeout_s: float | None = None,
348362
) -> bytes:
349363
"""Download a file from the SMP server.
350364
@@ -358,6 +372,8 @@ async def download_file(
358372
Raises:
359373
SMPUploadError: if the download routine fails
360374
"""
375+
timeout_s = timeout_s if timeout_s is not None else self._timeout
376+
361377
response = await self.request(FileDownload(off=0, name=file_path), timeout_s=timeout_s)
362378
file_length = 0
363379

@@ -490,18 +506,18 @@ def _maximize_file_upload_packet(self, request: FileUpload, data: bytes) -> File
490506
len=request.len,
491507
)
492508

493-
async def _initialize(self) -> None:
509+
async def _initialize(self, timeout_s: float | None = None) -> None:
494510
"""Gather initialization information from the SMP server."""
511+
timeout_s = timeout_s if timeout_s is not None else self._timeout
495512

496513
try:
497-
async with timeout(2):
498-
mcumgr_parameters = await self.request(MCUMgrParametersRead())
499-
if success(mcumgr_parameters):
500-
logger.debug(f"MCUMgr parameters: {mcumgr_parameters}")
501-
self._transport.initialize(mcumgr_parameters.buf_size)
502-
elif error(mcumgr_parameters):
503-
logger.warning(f"Error reading MCUMgr parameters: {mcumgr_parameters}")
504-
else:
505-
assert_never(mcumgr_parameters)
506-
except asyncio.TimeoutError:
514+
mcumgr_parameters = await self.request(MCUMgrParametersRead(), timeout_s=timeout_s)
515+
if success(mcumgr_parameters):
516+
logger.debug(f"MCUMgr parameters: {mcumgr_parameters}")
517+
self._transport.initialize(mcumgr_parameters.buf_size)
518+
elif error(mcumgr_parameters):
519+
logger.warning(f"Error reading MCUMgr parameters: {mcumgr_parameters}")
520+
else:
521+
assert_never(mcumgr_parameters)
522+
except TimeoutError:
507523
logger.warning("Timeout waiting for MCUMgr parameters")

tests/test_smp_client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,12 @@ def test_constructor() -> None:
8787
@pytest.mark.asyncio
8888
async def test_connect() -> None:
8989
m = SMPMockTransport()
90-
s = SMPClient(m, "address")
90+
s = SMPClient(m, "address", 5.0)
9191
s._initialize = AsyncMock() # type: ignore
9292
await s.connect()
9393

9494
m.connect.assert_awaited_once_with("address", 5.0)
95-
s._initialize.assert_awaited_once_with()
95+
s._initialize.assert_awaited_once_with(5.0)
9696

9797

9898
@pytest.mark.asyncio
@@ -177,7 +177,7 @@ async def test_request() -> None:
177177
@pytest.mark.asyncio
178178
async def test_upload() -> None:
179179
m = SMPMockTransport()
180-
s = SMPClient(m, "address")
180+
s = SMPClient(m, "address", 2.5)
181181

182182
s.request = AsyncMock() # type: ignore
183183

@@ -241,7 +241,7 @@ async def test_upload() -> None:
241241
off=415,
242242
data=image[415 : 415 + 474],
243243
),
244-
timeout_s=2.500,
244+
timeout_s=2.5,
245245
)
246246

247247
# assert that upload() raises SMPUploadError
@@ -388,7 +388,7 @@ async def mock_request(
388388
@pytest.mark.asyncio
389389
async def test_upload_file() -> None:
390390
m = SMPMockTransport()
391-
s = SMPClient(m, "address")
391+
s = SMPClient(m, "address", 2.5)
392392

393393
s.request = AsyncMock() # type: ignore
394394

@@ -619,7 +619,7 @@ async def mock_request(
619619
@pytest.mark.asyncio
620620
async def test_download_file() -> None:
621621
m = SMPMockTransport()
622-
s = SMPClient(m, "address")
622+
s = SMPClient(m, "address", 2.5)
623623

624624
s.request = AsyncMock() # type: ignore
625625

0 commit comments

Comments
 (0)