Skip to content

Commit 8f8b385

Browse files
committed
use moto for tests
1 parent 23429f5 commit 8f8b385

2 files changed

Lines changed: 93 additions & 118 deletions

File tree

api/import_export/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,9 @@ def __exit__(
6161
logger.warning("Aborted multipart upload due to error: %s", exc_val)
6262
return
6363

64-
# Upload any remaining data in the buffer
65-
if self._buffer.tell() > 0:
64+
# Upload any remaining data in the buffer (or an empty part if no data)
65+
# S3 requires at least one part to complete a multipart upload
66+
if self._buffer.tell() > 0 or not self._parts:
6667
self._upload_part()
6768

6869
assert self._upload_id
Lines changed: 90 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,36 @@
1-
from unittest.mock import MagicMock, call
2-
3-
import pytest
1+
import boto3
2+
from moto import mock_s3 # type: ignore[import-untyped]
3+
from pytest_mock import MockerFixture
44

55
from import_export.utils import S3MultipartUploadWriter
66

77

8-
@pytest.fixture
9-
def s3_client() -> MagicMock:
10-
client = MagicMock()
11-
client.create_multipart_upload.return_value = {"UploadId": "test-upload-id"}
12-
client.upload_part.return_value = {"ETag": "test-etag"}
13-
return client
14-
15-
16-
def test_s3_multipart_upload_writer__single_part__completes_upload(
17-
s3_client: MagicMock,
18-
) -> None:
8+
@mock_s3 # type: ignore[misc]
9+
def test_s3_multipart_upload_writer__single_part__completes_upload() -> None:
1910
# Given
2011
bucket_name = "test-bucket"
2112
key = "test-key"
2213
data = b"small data"
2314

15+
s3_resource = boto3.resource("s3", region_name="eu-west-2")
16+
s3_resource.create_bucket(
17+
Bucket=bucket_name,
18+
CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
19+
)
20+
s3_client = boto3.client("s3", region_name="eu-west-2")
21+
2422
# When
2523
with S3MultipartUploadWriter(s3_client, bucket_name, key) as writer:
2624
writer.write(data)
2725

2826
# Then
29-
s3_client.create_multipart_upload.assert_called_once_with(
30-
Bucket=bucket_name,
31-
Key=key,
32-
)
33-
s3_client.upload_part.assert_called_once_with(
34-
Bucket=bucket_name,
35-
Key=key,
36-
PartNumber=1,
37-
UploadId="test-upload-id",
38-
Body=data,
39-
)
40-
s3_client.complete_multipart_upload.assert_called_once_with(
41-
Bucket=bucket_name,
42-
Key=key,
43-
UploadId="test-upload-id",
44-
MultipartUpload={"Parts": [{"PartNumber": 1, "ETag": "test-etag"}]},
45-
)
46-
s3_client.abort_multipart_upload.assert_not_called()
27+
result = s3_client.get_object(Bucket=bucket_name, Key=key)
28+
assert result["Body"].read() == data
4729

4830

31+
@mock_s3 # type: ignore[misc]
4932
def test_s3_multipart_upload_writer__multiple_parts__uploads_each_part(
50-
s3_client: MagicMock,
33+
mocker: MockerFixture,
5134
) -> None:
5235
# Given
5336
bucket_name = "test-bucket"
@@ -58,11 +41,13 @@ def test_s3_multipart_upload_writer__multiple_parts__uploads_each_part(
5841
second_chunk = b"b" * chunk_size
5942
final_chunk = b"final"
6043

61-
s3_client.upload_part.side_effect = [
62-
{"ETag": "etag-1"},
63-
{"ETag": "etag-2"},
64-
{"ETag": "etag-3"},
65-
]
44+
s3_resource = boto3.resource("s3", region_name="eu-west-2")
45+
s3_resource.create_bucket(
46+
Bucket=bucket_name,
47+
CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
48+
)
49+
s3_client = boto3.client("s3", region_name="eu-west-2")
50+
upload_part_spy = mocker.spy(s3_client, "upload_part")
6651

6752
# When
6853
with S3MultipartUploadWriter(s3_client, bucket_name, key) as writer:
@@ -71,134 +56,123 @@ def test_s3_multipart_upload_writer__multiple_parts__uploads_each_part(
7156
writer.write(final_chunk)
7257

7358
# Then
74-
assert s3_client.upload_part.call_count == 3
75-
s3_client.upload_part.assert_has_calls(
76-
[
77-
call(
78-
Bucket=bucket_name,
79-
Key=key,
80-
PartNumber=1,
81-
UploadId="test-upload-id",
82-
Body=first_chunk,
83-
),
84-
call(
85-
Bucket=bucket_name,
86-
Key=key,
87-
PartNumber=2,
88-
UploadId="test-upload-id",
89-
Body=second_chunk,
90-
),
91-
call(
92-
Bucket=bucket_name,
93-
Key=key,
94-
PartNumber=3,
95-
UploadId="test-upload-id",
96-
Body=final_chunk,
97-
),
98-
]
99-
)
100-
s3_client.complete_multipart_upload.assert_called_once_with(
101-
Bucket=bucket_name,
102-
Key=key,
103-
UploadId="test-upload-id",
104-
MultipartUpload={
105-
"Parts": [
106-
{"PartNumber": 1, "ETag": "etag-1"},
107-
{"PartNumber": 2, "ETag": "etag-2"},
108-
{"PartNumber": 3, "ETag": "etag-3"},
109-
]
110-
},
111-
)
59+
result = s3_client.get_object(Bucket=bucket_name, Key=key)
60+
assert result["Body"].read() == first_chunk + second_chunk + final_chunk
61+
62+
# Verify exactly 3 parts were uploaded
63+
assert upload_part_spy.call_count == 3
64+
# Verify part numbers are sequential
65+
part_numbers = [
66+
call.kwargs["PartNumber"] for call in upload_part_spy.call_args_list
67+
]
68+
assert part_numbers == [1, 2, 3]
11269

11370

114-
def test_s3_multipart_upload_writer__accumulates_small_writes__uploads_when_threshold_reached(
115-
s3_client: MagicMock,
71+
@mock_s3 # type: ignore[misc]
72+
def test_s3_multipart_upload_writer__accumulates_small_writes__uploads_correctly(
73+
mocker: MockerFixture,
11674
) -> None:
11775
# Given
11876
bucket_name = "test-bucket"
11977
key = "test-key"
12078
small_chunk = b"x" * 1000 # 1KB
12179
writes_to_reach_threshold = (S3MultipartUploadWriter.MIN_PART_SIZE // 1000) + 1
12280

123-
s3_client.upload_part.side_effect = [
124-
{"ETag": "etag-1"},
125-
{"ETag": "etag-2"},
126-
]
81+
s3_resource = boto3.resource("s3", region_name="eu-west-2")
82+
s3_resource.create_bucket(
83+
Bucket=bucket_name,
84+
CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
85+
)
86+
s3_client = boto3.client("s3", region_name="eu-west-2")
87+
upload_part_spy = mocker.spy(s3_client, "upload_part")
12788

12889
# When
12990
with S3MultipartUploadWriter(s3_client, bucket_name, key) as writer:
13091
for _ in range(writes_to_reach_threshold):
13192
writer.write(small_chunk)
132-
# Write one more small chunk to have remaining data
13393
writer.write(b"final")
13494

13595
# Then
136-
# Should have uploaded one part when threshold was reached,
137-
# and one final part with remaining data on exit
138-
assert s3_client.upload_part.call_count == 2
96+
result = s3_client.get_object(Bucket=bucket_name, Key=key)
97+
expected_data = (small_chunk * writes_to_reach_threshold) + b"final"
98+
assert result["Body"].read() == expected_data
13999

100+
# Verify buffering: one part when threshold reached, one final part on exit
101+
assert upload_part_spy.call_count == 2
140102

141-
def test_s3_multipart_upload_writer__error_during_write__aborts_upload(
142-
s3_client: MagicMock,
143-
) -> None:
103+
104+
@mock_s3 # type: ignore[misc]
105+
def test_s3_multipart_upload_writer__error_during_write__aborts_upload() -> None:
144106
# Given
145107
bucket_name = "test-bucket"
146108
key = "test-key"
147109

110+
s3_resource = boto3.resource("s3", region_name="eu-west-2")
111+
s3_resource.create_bucket(
112+
Bucket=bucket_name,
113+
CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
114+
)
115+
s3_client = boto3.client("s3", region_name="eu-west-2")
116+
148117
# When
149-
with pytest.raises(ValueError, match="test error"):
118+
try:
150119
with S3MultipartUploadWriter(s3_client, bucket_name, key) as writer:
151120
writer.write(b"some data")
152121
raise ValueError("test error")
122+
except ValueError:
123+
pass
153124

154-
# Then
155-
s3_client.abort_multipart_upload.assert_called_once_with(
156-
Bucket=bucket_name,
157-
Key=key,
158-
UploadId="test-upload-id",
159-
)
160-
s3_client.complete_multipart_upload.assert_not_called()
125+
# Then - the object should not exist (upload was aborted)
126+
objects = s3_client.list_objects_v2(Bucket=bucket_name)
127+
assert objects.get("KeyCount", 0) == 0
161128

162129

163-
def test_s3_multipart_upload_writer__no_data__completes_with_no_parts(
164-
s3_client: MagicMock,
165-
) -> None:
130+
@mock_s3 # type: ignore[misc]
131+
def test_s3_multipart_upload_writer__no_data__completes_with_empty_object() -> None:
166132
# Given
167133
bucket_name = "test-bucket"
168134
key = "test-key"
169135

136+
s3_resource = boto3.resource("s3", region_name="eu-west-2")
137+
s3_resource.create_bucket(
138+
Bucket=bucket_name,
139+
CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
140+
)
141+
s3_client = boto3.client("s3", region_name="eu-west-2")
142+
170143
# When
171144
with S3MultipartUploadWriter(s3_client, bucket_name, key):
172145
pass # No writes
173146

174147
# Then
175-
s3_client.upload_part.assert_not_called()
176-
s3_client.complete_multipart_upload.assert_called_once_with(
177-
Bucket=bucket_name,
178-
Key=key,
179-
UploadId="test-upload-id",
180-
MultipartUpload={"Parts": []},
181-
)
148+
result = s3_client.get_object(Bucket=bucket_name, Key=key)
149+
assert result["Body"].read() == b""
182150

183151

152+
@mock_s3 # type: ignore[misc]
184153
def test_s3_multipart_upload_writer__multiple_small_writes__buffers_correctly(
185-
s3_client: MagicMock,
154+
mocker: MockerFixture,
186155
) -> None:
187156
# Given
188157
bucket_name = "test-bucket"
189158
key = "test-key"
190159

160+
s3_resource = boto3.resource("s3", region_name="eu-west-2")
161+
s3_resource.create_bucket(
162+
Bucket=bucket_name,
163+
CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
164+
)
165+
s3_client = boto3.client("s3", region_name="eu-west-2")
166+
upload_part_spy = mocker.spy(s3_client, "upload_part")
167+
191168
# When
192169
with S3MultipartUploadWriter(s3_client, bucket_name, key) as writer:
193170
writer.write(b"hello ")
194171
writer.write(b"world")
195172

196173
# Then
197-
# Both writes should be buffered and uploaded as single part
198-
s3_client.upload_part.assert_called_once_with(
199-
Bucket=bucket_name,
200-
Key=key,
201-
PartNumber=1,
202-
UploadId="test-upload-id",
203-
Body=b"hello world",
204-
)
174+
result = s3_client.get_object(Bucket=bucket_name, Key=key)
175+
assert result["Body"].read() == b"hello world"
176+
177+
# Verify both writes were buffered and uploaded as a single part
178+
assert upload_part_spy.call_count == 1

0 commit comments

Comments
 (0)