-
Notifications
You must be signed in to change notification settings - Fork 622
Expand file tree
/
Copy pathapi_deployment_views.py
More file actions
408 lines (357 loc) · 16.6 KB
/
api_deployment_views.py
File metadata and controls
408 lines (357 loc) · 16.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
import json
import logging
import uuid
from typing import Any
from configuration.models import Configuration
from django.db.models import QuerySet
from django.http import HttpResponse
from permissions.co_owner_views import CoOwnerManagementMixin
from permissions.permission import IsOwner, IsOwnerOrSharedUserOrSharedToOrg
from plugins import get_plugin
from prompt_studio.prompt_studio_registry_v2.models import PromptStudioRegistry
from rest_framework import serializers, status, views, viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import Serializer
from tool_instance_v2.models import ToolInstance
from utils.enums import CeleryTaskState
from workflow_manager.workflow_v2.dto import ExecutionResponse
from api_v2.api_deployment_dto_registry import ApiDeploymentDTORegistry
from api_v2.constants import ApiExecution
from api_v2.deployment_helper import DeploymentHelper
from api_v2.dto import DeploymentExecutionDTO
from api_v2.exceptions import (
NoActiveAPIKeyError,
RateLimitExceeded,
contains_tool_not_found_error,
)
from api_v2.models import APIDeployment
from api_v2.rate_limiter import APIDeploymentRateLimiter
from api_v2.serializers import (
APIDeploymentListSerializer,
APIDeploymentSerializer,
DeploymentResponseSerializer,
ExecutionQuerySerializer,
ExecutionRequestSerializer,
SharedUserListSerializer,
)
# Resolve the optional notification plugin once at import time. When the
# plugin is installed, ResourceType becomes available at module scope and is
# used by APIDeploymentViewSet.partial_update for sharing notifications.
notification_plugin = get_plugin("notification")
if notification_plugin:
    from plugins.notification.constants import ResourceType

# Module-level logger, named after this module per project convention.
logger = logging.getLogger(__name__)
class DeploymentExecution(views.APIView):
    """Public endpoint that runs and polls API-deployment workflow executions.

    POST triggers a workflow execution for the deployment identified by
    ``org_name``/``api_name``; GET polls a prior execution's status/result.
    Both verbs are guarded by ``DeploymentHelper.validate_api_key``, which
    validates the caller's API key and injects ``deployment_execution_dto``.
    """

    def initialize_request(self, request: Request, *args: Any, **kwargs: Any) -> Request:
        """To remove csrf request for public API.

        Args:
            request (Request): incoming HTTP request

        Returns:
            Request: request flagged so CSRF checks are skipped
        """
        # Public, API-key-authenticated endpoint: session CSRF does not apply.
        request.csrf_processing_done = True
        return super().initialize_request(request, *args, **kwargs)

    @DeploymentHelper.validate_api_key
    def post(
        self,
        request: Request,
        org_name: str,
        api_name: str,
        deployment_execution_dto: DeploymentExecutionDTO,
    ) -> Response:
        """Execute the deployment's workflow on the submitted files.

        Validates the request payload, acquires a rate-limit slot for the
        organization, runs the workflow via
        ``DeploymentHelper.execute_workflow`` and maps the outcome to an
        HTTP status: 200 on success, 422 on execution error, 500 when the
        tool is missing from the container registry.

        Raises:
            RateLimitExceeded: when the organization has no free slot.
        """
        api: APIDeployment = deployment_execution_dto.api
        api_key: str = deployment_execution_dto.api_key
        organization = api.organization
        serializer = ExecutionRequestSerializer(
            data=request.data, context={"api": api, "api_key": api_key}
        )
        serializer.is_valid(raise_exception=True)
        file_objs = serializer.validated_data.get(ApiExecution.FILES_FORM_DATA, [])
        presigned_urls = serializer.validated_data.get(ApiExecution.PRESIGNED_URLS, [])
        timeout = serializer.validated_data.get(ApiExecution.TIMEOUT_FORM_DATA)
        include_metadata = serializer.validated_data.get(ApiExecution.INCLUDE_METADATA)
        include_metrics = serializer.validated_data.get(ApiExecution.INCLUDE_METRICS)
        use_file_history = serializer.validated_data.get(ApiExecution.USE_FILE_HISTORY)
        tag_names = serializer.validated_data.get(ApiExecution.TAGS)
        llm_profile_id = serializer.validated_data.get(ApiExecution.LLM_PROFILE_ID)
        hitl_queue_name = serializer.validated_data.get(ApiExecution.HITL_QUEUE_NAME)
        hitl_packet_id = serializer.validated_data.get(ApiExecution.HITL_PACKET_ID)
        custom_data = serializer.validated_data.get(ApiExecution.CUSTOM_DATA)
        # Presigned URLs are fetched server-side and merged into file_objs.
        if presigned_urls:
            DeploymentHelper.load_presigned_files(presigned_urls, file_objs)
        # Generate execution ID for rate limiting
        execution_id = str(uuid.uuid4())
        # Check and acquire rate limit slot
        can_proceed, limit_info = APIDeploymentRateLimiter.check_and_acquire(
            organization, execution_id
        )
        if not can_proceed:
            logger.warning(
                f"Rate limit exceeded for org {organization.organization_id}: {limit_info}"
            )
            raise RateLimitExceeded(
                current_usage=limit_info["current_usage"],
                limit=limit_info["limit"],
                limit_type=limit_info["limit_type"],
            )
        # Execute workflow with blanket exception handling
        try:
            response = DeploymentHelper.execute_workflow(
                organization_name=org_name,
                api=api,
                file_objs=file_objs,
                timeout=timeout,
                include_metadata=include_metadata,
                include_metrics=include_metrics,
                use_file_history=use_file_history,
                tag_names=tag_names,
                llm_profile_id=llm_profile_id,
                hitl_queue_name=hitl_queue_name,
                hitl_packet_id=hitl_packet_id,
                custom_data=custom_data,
                request_headers=dict(request.headers),
                execution_id=execution_id,
            )
        except Exception as error:
            # Release slot on any failure during workflow setup/execution
            # so a crashed run does not permanently consume quota.
            APIDeploymentRateLimiter.release_slot(organization, execution_id)
            logger.exception(f"Workflow execution failed: {error}")
            raise
        # Determine response status based on execution result
        execution_status = response.get("execution_status", "")
        has_error = response.get("error") or execution_status == "ERROR"
        if has_error:
            # Check for tool not found in registry error - return 500 Internal Server Error
            # This is a server-side deployment state issue, not a client-actionable error
            if contains_tool_not_found_error(response):
                logger.error(
                    "API deployment failed: Tool not found in container registry"
                )
                return Response(
                    {"message": response},
                    status=status.HTTP_500_INTERNAL_SERVER_ERROR,
                )
            # Other errors - return 422 Unprocessable Entity
            logger.error("API deployment execution failed")
            return Response(
                {"message": response},
                status=status.HTTP_422_UNPROCESSABLE_ENTITY,
            )
        # Success
        return Response({"message": response}, status=status.HTTP_200_OK)

    @DeploymentHelper.validate_api_key
    def get(
        self,
        request: Request,
        org_name: str,
        api_name: str,
        deployment_execution_dto: DeploymentExecutionDTO,
    ) -> Response:
        """Poll the status and result of a previously triggered execution.

        Maps execution state to an HTTP status: 200 when completed, 406 when
        the result was already acknowledged, 500 when the tool is missing
        from the container registry, 422 for errored or not-yet-complete
        executions. On completion, optionally strips highlight data,
        metadata and metrics from the result before returning it.
        """
        serializer = ExecutionQuerySerializer(data=request.query_params)
        serializer.is_valid(raise_exception=True)
        execution_id = serializer.validated_data.get(ApiExecution.EXECUTION_ID)
        include_metadata = serializer.validated_data.get(ApiExecution.INCLUDE_METADATA)
        include_metrics = serializer.validated_data.get(ApiExecution.INCLUDE_METRICS)
        # Fetch execution status
        response: ExecutionResponse = DeploymentHelper.get_execution_status(execution_id)
        # Handle result already acknowledged
        if response.result_acknowledged:
            return Response(
                data={
                    "status": response.execution_status,
                    "message": "Result already acknowledged",
                },
                status=status.HTTP_406_NOT_ACCEPTABLE,
            )
        # Determine response status based on execution state
        execution_status_value = response.execution_status
        # Check for tool not found in registry error - return 500 Internal Server Error
        # This is a server-side deployment state issue, not a client-actionable error
        if contains_tool_not_found_error(response):
            logger.error("Execution failed: Tool not found in container registry")
            return Response(
                data={
                    "status": execution_status_value,
                    "message": response.result,
                },
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        # Check for ERROR status - return 422 Unprocessable Entity
        if execution_status_value == "ERROR":
            return Response(
                data={
                    "status": execution_status_value,
                    "message": response.result,
                },
                status=status.HTTP_422_UNPROCESSABLE_ENTITY,
            )
        # Process completed execution
        response_status = status.HTTP_422_UNPROCESSABLE_ENTITY
        if execution_status_value == CeleryTaskState.COMPLETED.value:
            response_status = status.HTTP_200_OK
            # Check if highlight data should be removed using configuration registry
            api_deployment = deployment_execution_dto.api
            organization = api_deployment.organization if api_deployment else None
            enable_highlight = False  # Safe default if the key is unavailable (e.g., OSS)
            # Check if the configuration key exists (Cloud deployment) or use settings (OSS)
            # NOTE(review): function-level import — presumably avoids a
            # circular import with the configuration app; confirm.
            from configuration.config_registry import ConfigurationRegistry

            if ConfigurationRegistry.is_config_key_available(
                "ENABLE_HIGHLIGHT_API_DEPLOYMENT"
            ):
                enable_highlight = Configuration.get_value_by_organization(
                    config_key="ENABLE_HIGHLIGHT_API_DEPLOYMENT",
                    organization=organization,
                )
            if not enable_highlight:
                response.remove_result_metadata_keys(["highlight_data"])
                response.remove_result_metadata_keys(["extracted_text"])
            if not include_metadata:
                response.remove_result_metadata_keys()
            if not include_metrics:
                response.remove_result_metrics()
        return Response(
            data={
                "status": response.execution_status,
                "message": response.result,
            },
            status=response_status,
        )
class APIDeploymentViewSet(CoOwnerManagementMixin, viewsets.ModelViewSet):
    """CRUD and management endpoints for ``APIDeployment`` records.

    On top of the standard model viewset this adds co-owner management
    (via ``CoOwnerManagementMixin``), Postman-collection export,
    prompt-studio-tool lookup, shared-user listing and sharing-notification
    dispatch on partial updates.
    """

    # APIDeployment field used as the human-readable resource name in
    # notifications sent by the co-owner/notification machinery.
    notification_resource_name_field = "display_name"

    def get_notification_resource_type(self, resource: Any) -> str | None:
        """Return the notification resource-type identifier for API deployments."""
        from plugins.notification.constants import ResourceType

        return ResourceType.API_DEPLOYMENT.value  # type: ignore

    def get_permissions(self) -> list[Any]:
        """Restrict mutating/sharing actions to owners; allow shared access otherwise."""
        if self.action in [
            "destroy",
            "partial_update",
            "update",
            "add_co_owner",
            "remove_co_owner",
        ]:
            return [IsOwner()]
        return [IsOwnerOrSharedUserOrSharedToOrg()]

    def get_queryset(self) -> QuerySet:
        """Return deployments visible to the requesting user.

        Supports an optional ``workflow`` query parameter to narrow the
        result to a single workflow's deployments.
        """
        queryset = APIDeployment.objects.for_user(self.request.user)
        # Filter by workflow ID if provided
        workflow_filter = self.request.query_params.get("workflow", None)
        if workflow_filter:
            queryset = queryset.filter(workflow_id=workflow_filter)
        return queryset

    def get_serializer_class(self) -> type[serializers.Serializer]:
        """Use the lightweight list serializer for ``list``; the full one otherwise."""
        if self.action in ["list"]:
            return APIDeploymentListSerializer
        return APIDeploymentSerializer

    @action(detail=True, methods=["get"])
    def fetch_one(self, request: Request, pk: str | None = None) -> Response:
        """Custom action to fetch a single instance."""
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return Response(serializer.data)

    def create(
        self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any]
    ) -> Response:
        """Create a deployment and return it along with a newly minted API key.

        Returns:
            Response: 201 with the serialized deployment plus its ``api_key``.
        """
        serializer: Serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        # Every new deployment gets an API key so it is callable immediately.
        api_key = DeploymentHelper.create_api_key(serializer=serializer, request=request)
        response_serializer = DeploymentResponseSerializer(
            {"api_key": api_key.api_key, **serializer.data}
        )
        headers = self.get_success_headers(serializer.data)
        return Response(
            response_serializer.data,
            status=status.HTTP_201_CREATED,
            headers=headers,
        )

    @action(detail=False, methods=["get"])
    def by_prompt_studio_tool(self, request: Request) -> Response:
        """Get API deployments for a specific prompt studio tool.

        Resolves the tool's prompt-studio registry entry, finds workflows
        containing tool instances for that registry ID, and returns the
        caller's deployments for those workflows. Requires a ``tool_id``
        query parameter (400 when missing).
        """
        tool_id = request.query_params.get("tool_id")
        if not tool_id:
            return Response(
                {"error": "tool_id parameter is required"},
                status=status.HTTP_400_BAD_REQUEST,
            )
        try:
            # Find the prompt studio registry for this custom tool
            registry = PromptStudioRegistry.objects.get(custom_tool__tool_id=tool_id)
            # Find workflows that contain tool instances with this prompt registry ID
            tool_instances = ToolInstance.objects.filter(
                tool_id=str(registry.prompt_registry_id)
            )
            workflow_ids = tool_instances.values_list("workflow_id", flat=True).distinct()
            # Get API deployments for these workflows
            deployments = APIDeployment.objects.filter(
                workflow_id__in=workflow_ids, created_by=request.user
            )
            serializer = APIDeploymentListSerializer(
                deployments, many=True, context={"request": request}
            )
            return Response(serializer.data, status=status.HTTP_200_OK)
        except PromptStudioRegistry.DoesNotExist:
            # No registry for this tool: treated as "no deployments", not an error.
            return Response([], status=status.HTTP_200_OK)
        except Exception as e:
            logger.error(f"Error fetching API deployments for tool {tool_id}: {e}")
            return Response(
                {"error": "Failed to fetch API deployments"},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

    @action(detail=True, methods=["get"])
    def download_postman_collection(
        self, request: Request, pk: str | None = None
    ) -> Response:
        """Downloads a Postman Collection of the API deployment instance.

        Raises:
            NoActiveAPIKeyError: when the deployment has no active API key.
        """
        instance = self.get_object()
        api_key_inst = instance.api_keys.filter(is_active=True).first()
        if not api_key_inst:
            logger.error(f"No active API key set for deployment {instance.pk}")
            raise NoActiveAPIKeyError(deployment_name=instance.display_name)
        dto_class = ApiDeploymentDTORegistry.get_dto()
        postman_collection = dto_class.create(
            instance=instance, api_key=api_key_inst.api_key
        )
        response = HttpResponse(
            json.dumps(postman_collection.to_dict()), content_type="application/json"
        )
        # Serve as a downloadable file named after the deployment.
        response["Content-Disposition"] = (
            f'attachment; filename="{instance.display_name}.json"'
        )
        return response

    @action(detail=True, methods=["get"], permission_classes=[IsOwner])
    def list_of_shared_users(self, request: Request, pk: str | None = None) -> Response:
        """List users who have access to this API deployment."""
        instance = self.get_object()
        serializer = SharedUserListSerializer(instance)
        return Response(serializer.data)

    def partial_update(self, request: Request, *args: Any, **kwargs: Any) -> Response:
        """Override partial_update to handle sharing notifications."""
        # Get current instance and shared users BEFORE the update so newly
        # added users can be diffed afterwards.
        instance = self.get_object()
        current_shared_users = set(instance.shared_users.all())
        # Perform the update
        response = super().partial_update(request, *args, **kwargs)
        # If successful and shared_users changed, send notifications
        if (
            response.status_code == 200
            and "shared_users" in request.data
            and notification_plugin
        ):
            try:
                instance.refresh_from_db()
                new_shared_users = set(instance.shared_users.all())
                # Notify only users who were newly granted access.
                newly_shared_users = new_shared_users - current_shared_users
                if newly_shared_users:
                    # Get notification service from plugin
                    service_class = notification_plugin["service_class"]
                    notification_service = service_class()
                    notification_service.send_sharing_notification(
                        resource_type=ResourceType.API_DEPLOYMENT.value,
                        resource_name=instance.display_name,
                        resource_id=str(instance.id),
                        shared_by=request.user,
                        shared_to=list(newly_shared_users),
                        resource_instance=instance,
                    )
            except Exception as e:
                # A notification failure must never fail the update itself.
                logger.exception(f"Failed to send sharing notification: {e}")
        return response