Skip to content

Commit 7f8fc8c

Browse files
committed
Harden request URL and upload handling
1 parent 1722c93 commit 7f8fc8c

7 files changed

Lines changed: 136 additions & 39 deletions

File tree

tests/test_async_client.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,27 @@ async def test_async_client_normalizes_service_and_rejects_missing_ids(self):
365365
with self.assertRaises(ValueError):
366366
await client.cancel_assembly()
367367

368+
with self.assertRaises(ValueError):
369+
await client.get_assembly(assembly_url="https://example.com/assemblies/abc123")
370+
371+
with self.assertRaises(ValueError):
372+
await client.cancel_assembly(assembly_url="https://example.com/assemblies/abc123")
373+
374+
transloadit_session = _RecordingSession({"ok": "ASSEMBLY_COMPLETED"})
375+
transloadit_client = AsyncTransloadit(
376+
"key",
377+
"secret",
378+
service="https://api2.transloadit.com",
379+
session=transloadit_session,
380+
)
381+
await transloadit_client.get_assembly(
382+
assembly_url="https://api2-region.transloadit.com/assemblies/abc123"
383+
)
384+
self.assertEqual(
385+
transloadit_session.calls[0][0],
386+
"https://api2-region.transloadit.com/assemblies/abc123",
387+
)
388+
368389
await client.close()
369390

370391
self.assertFalse(session.closed)
@@ -564,7 +585,7 @@ async def test_async_assembly_wait_raises_on_plain_text_poll_response(self):
564585
)
565586
sleep_mock.assert_awaited_once_with(0)
566587

567-
async def test_async_assembly_wait_returns_plain_text_poll_response(self):
588+
async def test_async_assembly_wait_raises_on_plain_text_success_poll_response(self):
568589
initial_response = Response(
569590
data={
570591
"ok": "ASSEMBLY_PROCESSING",
@@ -586,9 +607,9 @@ async def test_async_assembly_wait_returns_plain_text_poll_response(self):
586607
with mock.patch.object(client.request, "post", new=mock.AsyncMock(return_value=initial_response)) as post_mock:
587608
with mock.patch.object(client, "get_assembly", new=mock.AsyncMock(return_value=plain_response)) as get_mock:
588609
with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) as sleep_mock:
589-
response = await assembly.create(wait=True, resumable=False)
610+
with self.assertRaises(RuntimeError):
611+
await assembly.create(wait=True, resumable=False)
590612

591-
self.assertIs(response, plain_response)
592613
post_mock.assert_awaited_once()
593614
get_mock.assert_awaited_once_with(
594615
assembly_url=f"{self.server.base_url}/assemblies/assembly-123"
@@ -1313,7 +1334,7 @@ async def test_async_assembly_rate_limit_ignores_malformed_error_values(self):
13131334
self.assertFalse(assembly._rate_limit_reached({"error": ["RATE_LIMIT_REACHED"]}))
13141335
self.assertFalse(assembly._rate_limit_reached({"error": {"code": "RATE_LIMIT_REACHED"}}))
13151336

1316-
async def test_async_tus_upload_cancellation_waits_for_thread_to_finish(self):
1337+
async def test_async_tus_upload_cancellation_returns_before_thread_finishes(self):
13171338
client = AsyncTransloadit("key", "secret", service=self.server.base_url)
13181339
assembly = client.new_assembly()
13191340
started = threading.Event()
@@ -1338,13 +1359,14 @@ def blocking_upload(assembly_url, tus_url, retries):
13381359
upload_task.cancel()
13391360
await asyncio.sleep(0.05)
13401361

1341-
self.assertFalse(upload_task.done())
1362+
self.assertTrue(upload_task.done())
13421363
self.assertFalse(finished.is_set())
13431364

1344-
release.set()
13451365
with self.assertRaises(asyncio.CancelledError):
13461366
await upload_task
13471367

1368+
release.set()
1369+
await asyncio.to_thread(finished.wait, 5)
13481370
self.assertTrue(finished.is_set())
13491371

13501372
async def test_async_request_uses_connect_and_read_timeouts_for_uploads(self):
@@ -1406,7 +1428,7 @@ async def test_async_request_filters_none_and_lowercases_booleans_in_extra_data(
14061428
response = await client.request.post(
14071429
"/assemblies",
14081430
data={"foo": "bar"},
1409-
extra_data={"enabled": True, "skip": None},
1431+
extra_data={"enabled": True, "skip": None, "tags": ["a", "b"]},
14101432
files={"file": upload},
14111433
)
14121434

@@ -1415,6 +1437,21 @@ async def test_async_request_filters_none_and_lowercases_booleans_in_extra_data(
14151437
self.assertIn("enabled", fields)
14161438
self.assertNotIn("skip", fields)
14171439
self.assertEqual(fields["enabled"][2], "true")
1440+
tag_values = [field[2] for field in session.calls[0][1]["data"]._fields if field[0]["name"] == "tags"]
1441+
self.assertEqual(tag_values, ["a", "b"])
1442+
1443+
def test_non_closing_upload_stream_reflects_seekability(self):
1444+
class _NonSeekableUpload(io.BytesIO):
1445+
def seekable(self):
1446+
return False
1447+
1448+
class _BrokenSeekableUpload(io.BytesIO):
1449+
def seekable(self):
1450+
raise OSError("seekable failed")
1451+
1452+
self.assertTrue(_NonClosingUploadStream(io.BytesIO(b"payload")).seekable())
1453+
self.assertFalse(_NonClosingUploadStream(_NonSeekableUpload(b"payload")).seekable())
1454+
self.assertFalse(_NonClosingUploadStream(_BrokenSeekableUpload(b"payload")).seekable())
14181455

14191456
async def test_async_request_uses_filename_fallback_for_trailing_slash_stream_name(self):
14201457
session = _RecordingSession({"ok": "ASSEMBLY_COMPLETED"})

tests/test_client.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,19 @@ def test_get_assembly(self, mock):
7272
self.assertEqual(response.data["ok"], "ASSEMBLY_COMPLETED")
7373
self.assertEqual(response.data["assembly_id"], "abcdef12345")
7474

75+
def test_quotes_path_ids(self):
76+
with mock.patch.object(self.transloadit.request, 'get') as get_mock:
77+
self.transloadit.get_assembly(assembly_id='assembly/with?chars')
78+
self.transloadit.get_template('template/with?chars')
79+
80+
self.assertEqual(
81+
get_mock.call_args_list,
82+
[
83+
mock.call('/assemblies/assembly%2Fwith%3Fchars'),
84+
mock.call('/templates/template%2Fwith%3Fchars'),
85+
],
86+
)
87+
7588
@requests_mock.Mocker()
7689
def test_list_assemblies(self, mock):
7790
url = f"{self.transloadit.service}/assemblies"

tests/test_request.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,18 @@ def test_payload_preserves_custom_auth_constraints(self):
5656
self.assertEqual(params["auth"]["max_size"], 1024)
5757
self.assertEqual(params["auth"]["referer"], "https://example.com")
5858

59+
def test_full_url_rejects_external_absolute_urls(self):
60+
self.assertEqual(
61+
self.request._get_full_url(f"{self.transloadit.service}/foo"),
62+
f"{self.transloadit.service}/foo",
63+
)
64+
self.assertEqual(
65+
self.request._get_full_url("https://api2-region.transloadit.com/foo"),
66+
"https://api2-region.transloadit.com/foo",
67+
)
68+
with self.assertRaises(ValueError):
69+
self.request._get_full_url("https://example.com/foo")
70+
5971
@requests_mock.Mocker()
6072
def test_put(self, mock):
6173
url = f"{self.transloadit.service}/foo"

transloadit/async_assembly.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,7 @@ def _do_tus_upload(self, assembly_url, tus_url, retries):
7878
).upload()
7979

8080
async def _do_tus_upload_async(self, assembly_url, tus_url, retries):
81-
upload_task = asyncio.create_task(
82-
asyncio.to_thread(self._do_tus_upload, assembly_url, tus_url, retries)
83-
)
84-
try:
85-
await asyncio.shield(upload_task)
86-
except asyncio.CancelledError:
87-
try:
88-
await asyncio.shield(upload_task)
89-
finally:
90-
raise
81+
await asyncio.to_thread(self._do_tus_upload, assembly_url, tus_url, retries)
9182

9283
async def create(self, wait=False, resumable=True, retries=3):
9384
"""
@@ -111,10 +102,8 @@ async def create(self, wait=False, resumable=True, retries=3):
111102

112103
response_data = self._response_data(response)
113104
if response_data is None:
114-
if response.status_code >= 400:
105+
if response.status_code >= 400 or wait or (resumable and self.files):
115106
raise RuntimeError(f"Unexpected non-JSON response ({response.status_code}).")
116-
if resumable and self.files:
117-
raise RuntimeError("Resumable assembly response is missing upload URLs.")
118107
return response
119108

120109
if self._rate_limit_reached(response_data):
@@ -163,9 +152,7 @@ async def create(self, wait=False, resumable=True, retries=3):
163152
)
164153
poll_data = self._response_data(poll_response)
165154
if poll_data is None:
166-
if poll_response.status_code >= 400:
167-
raise RuntimeError(f"Unexpected non-JSON response ({poll_response.status_code}).")
168-
return poll_response
155+
raise RuntimeError(f"Unexpected non-JSON response ({poll_response.status_code}).")
169156

170157
return poll_response
171158

transloadit/async_request.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import json
99
from types import MappingProxyType
1010
from datetime import datetime, timedelta, timezone
11+
from urllib.parse import urlparse
1112

1213
import aiohttp
1314
from requests.structures import CaseInsensitiveDict
@@ -18,6 +19,10 @@
1819
TIMEOUT = 60
1920

2021

22+
def _is_transloadit_host(hostname):
23+
return hostname == "transloadit.com" or hostname.endswith(".transloadit.com")
24+
25+
2126
def _get_upload_filename(file_stream, fallback):
2227
name = getattr(file_stream, "name", None)
2328
if isinstance(name, (bytes, os.PathLike)):
@@ -60,6 +65,12 @@ def seek(self, *args):
6065
return self._file_stream.seek(*args)
6166

6267
def seekable(self):
68+
seekable = getattr(self._file_stream, "seekable", None)
69+
if callable(seekable):
70+
try:
71+
return seekable()
72+
except (OSError, ValueError):
73+
return False
6374
return hasattr(self._file_stream, "seek")
6475

6576
def tell(self):
@@ -121,14 +132,18 @@ def _timeout(self, files=False):
121132
)
122133

123134
def _normalize_payload(self, data):
124-
normalized = {}
135+
normalized = []
125136
for key, value in data.items():
126137
if value is None:
127138
continue
128-
if isinstance(value, bool):
129-
normalized[key] = "true" if value else "false"
130-
else:
131-
normalized[key] = str(value)
139+
values = value if isinstance(value, (list, tuple)) else [value]
140+
for item in values:
141+
if item is None:
142+
continue
143+
if isinstance(item, bool):
144+
normalized.append((key, "true" if item else "false"))
145+
else:
146+
normalized.append((key, str(item)))
132147
return normalized
133148

134149
async def _read_response_data(self, response):
@@ -144,9 +159,10 @@ async def get(self, path, params=None):
144159
"""
145160
Makes an asynchronous HTTP GET request.
146161
"""
162+
url = self._get_full_url(path)
147163
session = await self._ensure_session()
148164
async with session.get(
149-
self._get_full_url(path),
165+
url,
150166
params=self._to_payload(params),
151167
headers=self._headers(),
152168
timeout=self._timeout(),
@@ -161,14 +177,15 @@ async def post(self, path, data=None, extra_data=None, files=None):
161177
"""
162178
Makes an asynchronous HTTP POST request.
163179
"""
180+
url = self._get_full_url(path)
164181
session = await self._ensure_session()
165182
data = self._to_payload(data)
166183
if extra_data:
167184
data.update(extra_data)
168185

169186
if files:
170187
form = aiohttp.FormData()
171-
for key, value in self._normalize_payload(data).items():
188+
for key, value in self._normalize_payload(data):
172189
form.add_field(key, value)
173190

174191
for key, file_stream in files.items():
@@ -185,7 +202,7 @@ async def post(self, path, data=None, extra_data=None, files=None):
185202
payload = self._normalize_payload(data)
186203

187204
async with session.post(
188-
self._get_full_url(path),
205+
url,
189206
data=payload,
190207
headers=self._headers(),
191208
timeout=self._timeout(files=bool(files)),
@@ -200,10 +217,11 @@ async def put(self, path, data=None):
200217
"""
201218
Makes an asynchronous HTTP PUT request.
202219
"""
220+
url = self._get_full_url(path)
203221
session = await self._ensure_session()
204222
data = self._normalize_payload(self._to_payload(data))
205223
async with session.put(
206-
self._get_full_url(path),
224+
url,
207225
data=data,
208226
headers=self._headers(),
209227
timeout=self._timeout(),
@@ -218,10 +236,11 @@ async def delete(self, path, data=None):
218236
"""
219237
Makes an asynchronous HTTP DELETE request.
220238
"""
239+
url = self._get_full_url(path)
221240
session = await self._ensure_session()
222241
data = self._normalize_payload(self._to_payload(data))
223242
async with session.delete(
224-
self._get_full_url(path),
243+
url,
225244
data=data,
226245
headers=self._headers(),
227246
timeout=self._timeout(),
@@ -252,5 +271,15 @@ def _sign_data(self, message):
252271

253272
def _get_full_url(self, url):
254273
if url.startswith(("http://", "https://")):
274+
service = urlparse(self.transloadit.service)
275+
target = urlparse(url)
276+
same_origin = (target.scheme, target.netloc) == (service.scheme, service.netloc)
277+
transloadit_origin = (
278+
target.scheme == service.scheme
279+
and _is_transloadit_host(service.hostname or "")
280+
and _is_transloadit_host(target.hostname or "")
281+
)
282+
if not (same_origin or transloadit_origin):
283+
raise ValueError("Absolute API URLs must use the configured Transloadit service origin.")
255284
return url
256285
return self.transloadit.service + url

transloadit/client.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import hmac
33
import hashlib
44
import time
5-
from urllib.parse import urlencode, quote_plus
5+
from urllib.parse import quote, quote_plus, urlencode
66

77
from typing import Optional, Union, List
88

@@ -18,6 +18,10 @@ def _stringify_url_param(value: Union[str, int, float, bool]) -> str:
1818
return str(value)
1919

2020

21+
def _quote_path_segment(value: str) -> str:
22+
return quote(str(value), safe="")
23+
24+
2125
class Transloadit:
2226
"""
2327
This class serves as a client interface to the Transloadit API.
@@ -77,7 +81,7 @@ def get_assembly(self, assembly_id: str = None, assembly_url: str = None):
7781
if not (assembly_id or assembly_url):
7882
raise ValueError("Either 'assembly_id' or 'assembly_url' cannot be None.")
7983

80-
url = assembly_url if assembly_url else f"/assemblies/{assembly_id}"
84+
url = assembly_url if assembly_url else f"/assemblies/{_quote_path_segment(assembly_id)}"
8185
return self.request.get(url)
8286

8387
def list_assemblies(self, params: dict = None):
@@ -107,7 +111,7 @@ def cancel_assembly(self, assembly_id: str = None, assembly_url: str = None):
107111
if not (assembly_id or assembly_url):
108112
raise ValueError("Either 'assembly_id' or 'assembly_url' cannot be None.")
109113

110-
url = assembly_url if assembly_url else f"/assemblies/{assembly_id}"
114+
url = assembly_url if assembly_url else f"/assemblies/{_quote_path_segment(assembly_id)}"
111115
return self.request.delete(url)
112116

113117
def get_template(self, template_id: str):
@@ -119,7 +123,7 @@ def get_template(self, template_id: str):
119123
120124
Return an instance of <transloadit.response.Response>
121125
"""
122-
return self.request.get(f"/templates/{template_id}")
126+
return self.request.get(f"/templates/{_quote_path_segment(template_id)}")
123127

124128
def list_templates(self, params: Optional[dict] = None):
125129
"""
@@ -154,7 +158,7 @@ def update_template(self, template_id: str, data: dict):
154158
155159
Return an instance of <transloadit.response.Response>
156160
"""
157-
return self.request.put(f"/templates/{template_id}", data=data)
161+
return self.request.put(f"/templates/{_quote_path_segment(template_id)}", data=data)
158162

159163
def delete_template(self, template_id: str):
160164
"""
@@ -165,7 +169,7 @@ def delete_template(self, template_id: str):
165169
166170
Return an instance of <transloadit.response.Response>
167171
"""
168-
return self.request.delete(f"/templates/{template_id}")
172+
return self.request.delete(f"/templates/{_quote_path_segment(template_id)}")
169173

170174
def get_bill(self, month: int, year: int):
171175
"""

0 commit comments

Comments
 (0)