-
Notifications
You must be signed in to change notification settings - Fork 74
Expand file tree
/
Copy pathstart_batch_job_orchestration.py
More file actions
191 lines (176 loc) · 7.85 KB
/
start_batch_job_orchestration.py
File metadata and controls
191 lines (176 loc) · 7.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import argparse
import asyncio
import os
from datetime import timedelta
import aioredis
from model_engine_server.api.dependencies import get_monitoring_metrics_gateway
from model_engine_server.common.config import hmi_config
from model_engine_server.common.dtos.model_endpoints import BrokerType
from model_engine_server.common.env_vars import CIRCLECI
from model_engine_server.core.config import infra_config
from model_engine_server.db.base import get_session_async_null_pool
from model_engine_server.domain.entities import BatchJobSerializationFormat
from model_engine_server.domain.gateways import TaskQueueGateway
from model_engine_server.infra.gateways import (
ABSFilesystemGateway,
ASBInferenceAutoscalingMetricsGateway,
CeleryTaskQueueGateway,
LiveAsyncModelEndpointInferenceGateway,
LiveBatchJobProgressGateway,
LiveModelEndpointInfraGateway,
LiveModelEndpointsSchemaGateway,
LiveStreamingModelEndpointInferenceGateway,
LiveSyncModelEndpointInferenceGateway,
RedisInferenceAutoscalingMetricsGateway,
S3FilesystemGateway,
)
from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import (
ASBQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import (
FakeQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import (
LiveEndpointResourceGateway,
)
from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
QueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import (
SQSQueueEndpointResourceDelegate,
)
from model_engine_server.infra.repositories import (
DbBatchJobRecordRepository,
DbModelEndpointRecordRepository,
RedisModelEndpointCacheRepository,
)
from model_engine_server.infra.services import (
LiveBatchJobOrchestrationService,
LiveModelEndpointService,
)
async def run_batch_job(
    job_id: str,
    owner: str,
    input_path: str,
    serialization_format: BatchJobSerializationFormat,
    timeout_seconds: float,
):
    """Wire up the provider-specific gateways/repositories and run one batch job.

    Builds the full dependency graph (queue delegates, task-queue gateways,
    endpoint/infra gateways, repositories) appropriate for the configured
    cloud provider, then delegates to LiveBatchJobOrchestrationService.

    Args:
        job_id: ID of the batch job record to run.
        owner: ID of the user who owns the batch job.
        input_path: Path to the batch job's input data.
        serialization_format: Serialization format of the job's input data.
        timeout_seconds: Maximum time the batch job is allowed to run.
    """
    session = get_session_async_null_pool()
    pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url)
    redis = aioredis.Redis(connection_pool=pool)
    try:
        sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS)
        gcppubsub_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.GCPPUBSUB)
        servicebus_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SERVICEBUS)
        monitoring_metrics_gateway = get_monitoring_metrics_gateway()
        model_endpoint_record_repo = DbModelEndpointRecordRepository(
            monitoring_metrics_gateway=monitoring_metrics_gateway, session=session, read_only=False
        )
        # Pick the queue delegate: fake in CI, Azure Service Bus on Azure,
        # SQS everywhere else.
        queue_delegate: QueueEndpointResourceDelegate
        if CIRCLECI:
            queue_delegate = FakeQueueEndpointResourceDelegate()
        elif infra_config().cloud_provider == "azure":
            queue_delegate = ASBQueueEndpointResourceDelegate()
        else:
            queue_delegate = SQSQueueEndpointResourceDelegate(
                sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
            )
        inference_autoscaling_metrics_gateway = (
            ASBInferenceAutoscalingMetricsGateway()
            if infra_config().cloud_provider == "azure"
            else RedisInferenceAutoscalingMetricsGateway(redis_client=redis)
        )
        resource_gateway = LiveEndpointResourceGateway(
            queue_delegate=queue_delegate,
            inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway,
        )
        # Select the Celery broker matching the cloud provider; inference and
        # infra traffic share the same broker on every provider today.
        inference_task_queue_gateway: TaskQueueGateway
        infra_task_queue_gateway: TaskQueueGateway
        if infra_config().cloud_provider == "azure":
            inference_task_queue_gateway = servicebus_task_queue_gateway
            infra_task_queue_gateway = servicebus_task_queue_gateway
        elif infra_config().cloud_provider == "gcp":
            inference_task_queue_gateway = gcppubsub_task_queue_gateway
            infra_task_queue_gateway = gcppubsub_task_queue_gateway
        else:
            inference_task_queue_gateway = sqs_task_queue_gateway
            infra_task_queue_gateway = sqs_task_queue_gateway
        model_endpoint_infra_gateway = LiveModelEndpointInfraGateway(
            resource_gateway=resource_gateway,
            task_queue_gateway=infra_task_queue_gateway,
        )
        model_endpoint_cache_repo = RedisModelEndpointCacheRepository(
            redis_client=redis,
        )
        async_model_endpoint_inference_gateway = LiveAsyncModelEndpointInferenceGateway(
            task_queue_gateway=inference_task_queue_gateway
        )
        streaming_model_endpoint_inference_gateway = LiveStreamingModelEndpointInferenceGateway(
            monitoring_metrics_gateway=monitoring_metrics_gateway,
            use_asyncio=(not CIRCLECI),
        )
        sync_model_endpoint_inference_gateway = LiveSyncModelEndpointInferenceGateway(
            monitoring_metrics_gateway=monitoring_metrics_gateway,
            use_asyncio=(not CIRCLECI),
        )
        filesystem_gateway = (
            ABSFilesystemGateway()
            if infra_config().cloud_provider == "azure"
            else S3FilesystemGateway()
        )
        model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway(
            filesystem_gateway=filesystem_gateway
        )
        model_endpoint_service = LiveModelEndpointService(
            model_endpoint_record_repository=model_endpoint_record_repo,
            model_endpoint_infra_gateway=model_endpoint_infra_gateway,
            model_endpoint_cache_repository=model_endpoint_cache_repo,
            async_model_endpoint_inference_gateway=async_model_endpoint_inference_gateway,
            streaming_model_endpoint_inference_gateway=streaming_model_endpoint_inference_gateway,
            sync_model_endpoint_inference_gateway=sync_model_endpoint_inference_gateway,
            model_endpoints_schema_gateway=model_endpoints_schema_gateway,
            inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway,
            can_scale_http_endpoint_from_zero_flag=False,  # shouldn't matter since we only use this to create async endpoints
        )
        batch_job_record_repository = DbBatchJobRecordRepository(session=session, read_only=False)
        batch_job_progress_gateway = LiveBatchJobProgressGateway(
            filesystem_gateway=filesystem_gateway
        )
        batch_job_orchestration_service = LiveBatchJobOrchestrationService(
            model_endpoint_service=model_endpoint_service,
            batch_job_record_repository=batch_job_record_repository,
            batch_job_progress_gateway=batch_job_progress_gateway,
            async_model_endpoint_inference_gateway=async_model_endpoint_inference_gateway,
            filesystem_gateway=filesystem_gateway,
        )
        await batch_job_orchestration_service.run_batch_job(
            job_id=job_id,
            owner=owner,
            input_path=input_path,
            serialization_format=serialization_format,
            timeout=timedelta(seconds=timeout_seconds),
        )
    finally:
        # Fix: the original leaked the Redis client/pool. Close them so the
        # sockets are released and no "unclosed connection" warnings fire at
        # interpreter shutdown (aioredis 2.x: close the client, then
        # disconnect the pool).
        await redis.close()
        await pool.disconnect()
def entrypoint():
    """Parse command-line arguments and run the batch job to completion."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--job-id", "-j", required=True, help="The ID of the batch job to run.")
    parser.add_argument(
        "--owner", "-o", required=True, help="The ID of the user who owns the batch job."
    )
    parser.add_argument("--input-path", "-i", required=True, help="The path to the input data.")
    parser.add_argument(
        "--serialization-format",
        "-s",
        required=True,
        help="The serialization format of the input data.",
    )
    # type=float lets argparse reject non-numeric values with a proper usage
    # error instead of an uncaught ValueError from a later manual float() call.
    parser.add_argument(
        "--timeout-seconds", "-t", required=True, type=float, help="The timeout in seconds."
    )
    args = parser.parse_args()

    # Raises ValueError if the given string is not a valid serialization format.
    serialization_fmt = BatchJobSerializationFormat(args.serialization_format)
    asyncio.run(
        run_batch_job(
            job_id=args.job_id,
            owner=args.owner,
            input_path=args.input_path,
            serialization_format=serialization_fmt,
            timeout_seconds=args.timeout_seconds,
        )
    )
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == "__main__":
    entrypoint()