-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathrequest.py
More file actions
118 lines (103 loc) · 3.34 KB
/
request.py
File metadata and controls
118 lines (103 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import aiohttp
from typing import Any, Optional
from ..models import ModelDefinition
from ..errors import (
QueueSubmitError,
QueueStatusError,
QueueResultError,
FileTooLargeError,
MAX_FILE_SIZE,
)
from .._user_agent import build_user_agent
from ..process.request import file_input_to_bytes
from .types import JobSubmitResponse, JobStatusResponse
async def submit_job(
    session: aiohttp.ClientSession,
    base_url: str,
    api_key: str,
    model: ModelDefinition,
    inputs: dict[str, Any],
    integration: Optional[str] = None,
) -> JobSubmitResponse:
    """Submit a job to the queue.

    POST /v1/jobs/{model}

    Builds a multipart form from *inputs* and posts it to the model's job
    endpoint. Inputs whose value is ``None`` are skipped. The file-like
    fields (``data``, ``start``, ``end``, ``reference_image``) are resolved
    to raw bytes with ``file_input_to_bytes`` and size-checked; every other
    input is sent as its ``str()`` form.

    Args:
        session: aiohttp session used for the POST; it is also passed to
            ``file_input_to_bytes`` when resolving file inputs.
        base_url: API root. It is joined with ``/v1/jobs/...`` verbatim,
            so a trailing slash would produce ``//`` in the path.
        api_key: Value sent in the ``X-API-KEY`` header.
        model: Target model. ``model.name`` selects the endpoint and a
            truthy ``model.max_file_size`` overrides ``MAX_FILE_SIZE``.
        inputs: Form fields to send.
        integration: Optional integration tag forwarded to
            ``build_user_agent`` for the ``User-Agent`` header.

    Returns:
        The JSON response body unpacked into a ``JobSubmitResponse``.

    Raises:
        FileTooLargeError: A file input is larger than the size limit.
        QueueSubmitError: The response is not ``ok``; the message carries
            the status code and response text.
    """
    # The per-file size limit depends only on the model, so compute it once.
    limit = model.max_file_size or MAX_FILE_SIZE

    form_data = aiohttp.FormData()
    for key, value in inputs.items():
        if value is None:
            continue
        if key in ("data", "start", "end", "reference_image"):
            content, content_type = await file_input_to_bytes(value, session)
            if len(content) > limit:
                raise FileTooLargeError(len(content), limit, key)
            # Pass filename explicitly. For raw bytes, aiohttp 3.x currently
            # defaults the filename to the field name but warns
            # (DeprecationWarning). v4 will no longer treat bare bytes as a
            # file part. Using the field name keeps today's wire format.
            form_data.add_field(
                key, content, content_type=content_type, filename=key
            )
        else:
            form_data.add_field(key, str(value))

    endpoint = f"{base_url}/v1/jobs/{model.name}"
    async with session.post(
        endpoint,
        headers={
            "X-API-KEY": api_key,
            "User-Agent": build_user_agent(integration),
        },
        data=form_data,
    ) as response:
        if not response.ok:
            error_text = await response.text()
            raise QueueSubmitError(
                f"Failed to submit job: {response.status} - {error_text}",
                data={"status": response.status},
            )
        data = await response.json()
        return JobSubmitResponse(**data)
async def get_job_status(
    session: aiohttp.ClientSession,
    base_url: str,
    api_key: str,
    job_id: str,
    integration: Optional[str] = None,
) -> JobStatusResponse:
    """Fetch the current status of a queued job.

    GET /v1/jobs/{job_id}

    Args:
        session: aiohttp session used to issue the request.
        base_url: API root, joined with ``/v1/jobs/...`` verbatim.
        api_key: Value sent in the ``X-API-KEY`` header.
        job_id: Identifier of the job, interpolated into the path as-is.
        integration: Optional integration tag forwarded to
            ``build_user_agent`` for the ``User-Agent`` header.

    Returns:
        The JSON response body unpacked into a ``JobStatusResponse``.

    Raises:
        QueueStatusError: The response is not ``ok``; the message carries
            the status code and response text.
    """
    status_url = f"{base_url}/v1/jobs/{job_id}"
    request_headers = {
        "X-API-KEY": api_key,
        "User-Agent": build_user_agent(integration),
    }
    async with session.get(status_url, headers=request_headers) as resp:
        if resp.ok:
            payload = await resp.json()
            return JobStatusResponse(**payload)
        body = await resp.text()
        raise QueueStatusError(
            f"Failed to get job status: {resp.status} - {body}",
            data={"status": resp.status},
        )
async def get_job_content(
    session: aiohttp.ClientSession,
    base_url: str,
    api_key: str,
    job_id: str,
    integration: Optional[str] = None,
) -> bytes:
    """Download the result payload of a job.

    GET /v1/jobs/{job_id}/content

    Args:
        session: aiohttp session used to issue the request.
        base_url: API root, joined with ``/v1/jobs/...`` verbatim.
        api_key: Value sent in the ``X-API-KEY`` header.
        job_id: Identifier of the job, interpolated into the path as-is.
        integration: Optional integration tag forwarded to
            ``build_user_agent`` for the ``User-Agent`` header.

    Returns:
        The entire response body, read into memory as bytes.

    Raises:
        QueueResultError: The response is not ``ok``; the message carries
            the status code and response text.
    """
    content_url = f"{base_url}/v1/jobs/{job_id}/content"
    auth_headers = {
        "X-API-KEY": api_key,
        "User-Agent": build_user_agent(integration),
    }
    async with session.get(content_url, headers=auth_headers) as resp:
        if resp.ok:
            return await resp.read()
        reason = await resp.text()
        raise QueueResultError(
            f"Failed to get job content: {resp.status} - {reason}",
            data={"status": resp.status},
        )