
Commit 1be9972

Merge pull request #48 from gdcc/add-native-progress
Fix progress bar on native uploads and proxy test implementation
2 parents 1b1397d + 00e3d2a commit 1be9972

8 files changed

Lines changed: 160 additions & 65 deletions


.github/workflows/test.yml

Lines changed: 1 addition & 7 deletions

```diff
@@ -6,16 +6,10 @@ jobs:
   build:
     runs-on: ubuntu-latest

-    services:
-      squid:
-        image: ubuntu/squid:latest
-        ports:
-          - 3128:3128
-
     strategy:
       max-parallel: 4
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ['3.8', '3.9', '3.10', '3.11']

     env:
       PORT: 8080
```
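
The squid service container disappears here because the proxy needed by the tests is now started from Python itself — `proxy.py` is added to the test dependencies in `pyproject.toml` below. A minimal sketch of what such an in-process fixture can look like; the fixture name, port, and flags are assumptions, not taken from this commit:

```python
# Hypothetical pytest fixture; the actual test code is not shown in this commit.
import proxy  # provided by the "proxy.py" package
import pytest


@pytest.fixture
def local_proxy():
    # proxy.Proxy is proxy.py's embedded mode: the server runs in-process and
    # shuts down automatically when the context manager exits.
    with proxy.Proxy(["--hostname", "127.0.0.1", "--port", "3128"]):
        yield "http://127.0.0.1:3128"
```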

README.md

Lines changed: 14 additions & 0 deletions

````diff
@@ -210,10 +210,24 @@ export DVUPLOADER_TESTING=true

 **3. Run the test(s) with pytest**

+Run all tests:
+
 ```bash
 poetry run pytest
 ```

+Run a specific test:
+
+```bash
+poetry run pytest -k test_native_upload_with_large_file
+```
+
+Run all non-expensive tests:
+
+```bash
+poetry run pytest -m "not expensive"
+```
+
 ### Linting

 This repository uses `ruff` to lint the code and `codespell` to check for spelling mistakes. You can run the linters with the following command:
````
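
The `-m "not expensive"` filter presupposes an `expensive` marker that is registered and applied to the slow tests; neither is shown in this diff. A hedged sketch of the usual setup:

```python
# Hypothetical marker wiring; the repository's actual test code is not in this diff.
import pytest


# In conftest.py: register the marker so pytest does not warn about it.
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "expensive: slow tests, e.g. uploads of very large files"
    )


# In a test module: tag slow tests so `-m "not expensive"` deselects them.
@pytest.mark.expensive
def test_native_upload_with_large_file():
    ...
```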

dvuploader/dvuploader.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -105,6 +105,7 @@ def upload(
         persistent_id=persistent_id,
         api_token=api_token,
         replace_existing=replace_existing,
+        proxy=proxy,
     )

     # Sort files by size
@@ -146,6 +147,7 @@ def upload(
                 n_parallel_uploads=n_parallel_uploads,
                 progress=progress,
                 pbars=pbars,
+                proxy=proxy,
             )
         )
     else:
@@ -159,6 +161,7 @@ def upload(
                 pbars=pbars,
                 progress=progress,
                 n_parallel_uploads=n_parallel_uploads,
+                proxy=proxy,
             )
         )

@@ -196,7 +199,8 @@ def _check_duplicates(
     persistent_id: str,
     api_token: str,
     replace_existing: bool,
-):
+    proxy: Optional[str] = None,
+) -> None:
     """
     Checks for duplicate files in the dataset by comparing paths and filenames.

@@ -205,7 +209,7 @@ def _check_duplicates(
         persistent_id (str): The persistent ID of the dataset.
         api_token (str): The API token for accessing the Dataverse repository.
         replace_existing (bool): Whether to replace files that already exist.
-
+        proxy (Optional[str]): The proxy to use for the request.
    Returns:
        None
    """
@@ -214,6 +218,7 @@ def _check_duplicates(
         dataverse_url=dataverse_url,
         persistent_id=persistent_id,
         api_token=api_token,
+        proxy=proxy,
     )

     table = Table(
@@ -252,7 +257,7 @@ def _check_duplicates(
             # calculate checksum
             file.update_checksum_chunked()
             file.apply_checksum()
-            file._unchanged_data = self._check_hashes(file, ds_file)
+            file._unchanged_data = self._check_hashes(file, ds_file)  # type: ignore
             if file._unchanged_data:
                 table.add_row(
                     file.file_name,
```
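
With `proxy` threaded through `upload`, `_check_duplicates`, and both upload paths, a caller can route every request through a single HTTP proxy. A hedged usage sketch — the instance URL, token, and DOI are placeholders, and the `DVUploader`/`File` names follow dvuploader's public API rather than anything shown in this diff:

```python
import dvuploader as dv

uploader = dv.DVUploader(files=[dv.File(filepath="./data.csv")])
uploader.upload(
    dataverse_url="https://demo.dataverse.org",  # placeholder instance
    api_token="XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX",  # placeholder token
    persistent_id="doi:10.70122/FK2/ABCDEF",  # placeholder DOI
    proxy="http://127.0.0.1:3128",  # e.g. the in-process proxy.py used in the tests
)
```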

dvuploader/nativeupload.py

Lines changed: 44 additions & 5 deletions

```diff
@@ -4,7 +4,7 @@
 import tempfile
 from io import BytesIO
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import IO, Dict, List, Optional, Tuple

 import httpx
 import rich
@@ -65,6 +65,36 @@
 ZIP_LIMIT_MESSAGE = "The number of files in the zip archive is over the limit"


+class _ProgressFileWrapper:
+    """
+    Wrap a binary file-like object and update a rich progress bar on reads.
+    httpx's multipart expects a synchronous file-like object exposing .read().
+    """
+
+    def __init__(
+        self,
+        file: IO[bytes],
+        progress: Progress,
+        pbar: TaskID,
+        chunk_size: int = 1024 * 1024,
+    ):
+        self._file = file
+        self._progress = progress
+        self._pbar = pbar
+        self._chunk_size = chunk_size
+
+    def read(self, size: int = -1) -> bytes:
+        if size is None or size < 0:
+            size = self._chunk_size
+        data = self._file.read(size)
+        if data:
+            self._progress.update(self._pbar, advance=len(data))
+        return data
+
+    def __getattr__(self, name):
+        return getattr(self._file, name)
+
+
 init_logging()

@@ -161,6 +191,7 @@ async def native_upload(
         persistent_id=persistent_id,
         dataverse_url=dataverse_url,
         api_token=api_token,
+        proxy=proxy,
     )


@@ -255,7 +286,9 @@ def _reset_progress(
 @tenacity.retry(
     wait=RETRY_STRAT,
     stop=tenacity.stop_after_attempt(MAX_RETRIES),
-    retry=tenacity.retry_if_exception_type((httpx.HTTPStatusError,)),
+    retry=tenacity.retry_if_exception_type(
+        (httpx.HTTPStatusError, httpx.ReadError, httpx.RequestError)
+    ),
 )
 async def _single_native_upload(
     session: httpx.AsyncClient,
@@ -301,10 +334,12 @@ async def _single_native_upload(
     json_data = _get_json_data(file)
     handler = file.get_handler()

+    assert handler is not None, "File handler is required for native upload"
+
     files = {
         "file": (
             file.file_name,
-            handler,
+            _ProgressFileWrapper(handler, progress, pbar),  # type: ignore[arg-type]
             file.mimeType,
         ),
         "jsonData": (
@@ -316,7 +351,7 @@ async def _single_native_upload(

     response = await session.post(
         endpoint,
-        files=files,  # type: ignore
+        files=files,
     )

     if response.status_code == 400 and response.json()["message"].startswith(
@@ -371,6 +406,7 @@ async def _update_metadata(
     dataverse_url: str,
     api_token: str,
     persistent_id: str,
+    proxy: Optional[str],
 ):
     """
     Updates the metadata of the given files in a Dataverse repository.
@@ -390,6 +426,7 @@ async def _update_metadata(
         persistent_id=persistent_id,
         dataverse_url=dataverse_url,
         api_token=api_token,
+        proxy=proxy,
     )

     tasks = []
@@ -505,6 +542,7 @@ def _retrieve_file_ids(
     persistent_id: str,
     dataverse_url: str,
     api_token: str,
+    proxy: Optional[str] = None,
 ) -> Dict[str, str]:
     """
     Retrieves the file IDs of files in a dataset.
@@ -513,7 +551,7 @@ def _retrieve_file_ids(
         persistent_id (str): The persistent identifier of the dataset.
         dataverse_url (str): The URL of the Dataverse repository.
         api_token (str): The API token of the Dataverse repository.
-
+        proxy (str): The proxy to use for the request.
     Returns:
         Dict[str, str]: Dictionary mapping file paths to their IDs.
     """
@@ -523,6 +561,7 @@ def _retrieve_file_ids(
         persistent_id=persistent_id,
         dataverse_url=dataverse_url,
         api_token=api_token,
+        proxy=proxy,
     )

     return _create_file_id_path_mapping(ds_files)
```
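
`_ProgressFileWrapper` is what fixes the progress bar on native uploads: httpx drains the multipart body by repeatedly calling `.read()` on the file object, so intercepting `read()` is the natural place to advance the bar, and the `__getattr__` delegation keeps the wrapper transparent for everything else httpx touches (`seek`, `name`, and so on). A self-contained sketch of the same pattern outside the uploader; the file name and endpoint are placeholders:

```python
import os

import httpx
from rich.progress import Progress


class ProgressReader:
    """Same read-hook pattern as _ProgressFileWrapper above."""

    def __init__(self, file, progress, task_id, chunk_size: int = 1024 * 1024):
        self._file = file
        self._progress = progress
        self._task_id = task_id
        self._chunk_size = chunk_size

    def read(self, size: int = -1) -> bytes:
        if size is None or size < 0:
            # Cap unbounded reads so the bar advances in 1 MiB steps
            # instead of jumping from 0 to 100% in a single read.
            size = self._chunk_size
        data = self._file.read(size)
        if data:
            self._progress.update(self._task_id, advance=len(data))
        return data

    def __getattr__(self, name):
        # Delegate everything else (seek, tell, name, ...) to the wrapped file.
        return getattr(self._file, name)


path = "data.bin"  # placeholder file
with Progress() as progress, open(path, "rb") as fh:
    task = progress.add_task(path, total=os.path.getsize(path))
    httpx.post(
        "https://example.org/upload",  # placeholder endpoint
        files={"file": (path, ProgressReader(fh, progress, task), "application/octet-stream")},
    )
```

Widening the tenacity predicate to include `httpx.ReadError` and `httpx.RequestError` matters for the same reason: a flaky or proxied connection now retries instead of surfacing a transport error mid-read. (`httpx.ReadError` is in fact a subclass of `httpx.RequestError`, so listing both is belt-and-braces.)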

dvuploader/utils.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -4,7 +4,7 @@
 import pathlib
 import re
 import time
-from typing import List
+from typing import List, Optional
 from urllib.parse import urljoin

 import httpx
@@ -59,6 +59,7 @@ def retrieve_dataset_files(
     dataverse_url: str,
     persistent_id: str,
     api_token: str,
+    proxy: Optional[str] = None,
 ):
     """
     Retrieve the files of a specific dataset from a Dataverse repository.
@@ -67,6 +68,7 @@ def retrieve_dataset_files(
         dataverse_url (str): The base URL of the Dataverse repository.
         persistent_id (str): The persistent identifier (PID) of the dataset.
         api_token (str): API token for authentication.
+        proxy (Optional[str]): The proxy to use for the request.

     Returns:
         list: A list of files in the dataset.
@@ -80,6 +82,7 @@ def retrieve_dataset_files(
     response = httpx.get(
         urljoin(dataverse_url, DATASET_ENDPOINT),
         headers={"X-Dataverse-key": api_token},
+        proxy=proxy,
     )

     response.raise_for_status()
```
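
For reference, the proxied call above reduces to a plain `httpx.get` with a `proxy` URL, which recent httpx versions accept directly on the top-level request helpers. A sketch with placeholder values:

```python
import httpx

# Placeholder URL and token; the real request URL is built with urljoin from
# dataverse_url and DATASET_ENDPOINT as in retrieve_dataset_files above.
response = httpx.get(
    "https://demo.dataverse.org/api/datasets/:persistentId/?persistentId=doi:10.70122/FK2/ABCDEF",
    headers={"X-Dataverse-key": "my-api-token"},
    proxy="http://127.0.0.1:3128",  # pass None to disable proxying
)
response.raise_for_status()
```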

pyproject.toml

Lines changed: 1 addition & 0 deletions

```diff
@@ -34,6 +34,7 @@ ipywidgets = "^8.1.1"
 pytest-cov = "^4.1.0"
 pytest-asyncio = "^0.23.3"
 pytest-httpx = "^0.35.0"
+"proxy.py" = "^2.4.4"

 [tool.poetry.group.linting.dependencies]
 codespell = "^2.2.6"
```

tests/conftest.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -60,7 +60,7 @@ def create_dataset(
     response = httpx.post(
         url=url,
         headers={"X-Dataverse-key": api_token},
-        data=open("./tests/fixtures/create_dataset.json", "rb"),  # type: ignore
+        data=open("./tests/fixtures/create_dataset.json", "rb"),  # type: ignore[reportUnboundVariable]
     )

     response.raise_for_status()
```
