-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathmiddleware.py
More file actions
186 lines (158 loc) · 7.33 KB
/
middleware.py
File metadata and controls
186 lines (158 loc) · 7.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
from __future__ import annotations
import logging
import time
from starlette.exceptions import HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from starlette import status
from src.platform.isolationEngine.session import SessionManager
from src.platform.isolationEngine.core import CoreIsolationEngine
from src.platform.api.auth import get_principal_id, is_dev_mode
from src.platform.db.schema import RunTimeEnvironment
logger = logging.getLogger(__name__)
class PlatformMiddleware(BaseHTTPMiddleware):
    """Authenticate platform API requests and attach a meta DB session.

    Every request except the health probe must carry an API key unless
    dev mode is enabled. The resolved principal and a meta session are
    exposed to downstream handlers through ``request.state``.
    """

    def __init__(self, app, *, session_manager: SessionManager):
        super().__init__(app)
        self.session_manager = session_manager

    @staticmethod
    def _detail_response(detail: str, status_code: int) -> JSONResponse:
        """Build the ``{"detail": ...}`` JSON body used by every error reply."""
        return JSONResponse({"detail": detail}, status_code=status_code)

    async def dispatch(self, request: Request, call_next) -> Response:
        """Resolve the caller, then forward the request inside a meta session.

        Returns the downstream response on success. Otherwise replies with
        401 (missing key or ``PermissionError``), 503 (``RuntimeError``) or
        500 (any other exception). ``HTTPException`` is re-raised unchanged.
        """
        route_path = request.scope.get("path", "")

        # The health probe is forwarded without any key check or session.
        if route_path == "/api/platform/health":
            return await call_next(request)

        # Prefer the dedicated header; fall back to Authorization.
        headers = request.headers
        credential = headers.get("X-API-Key")
        if not credential:
            credential = headers.get("Authorization")

        if not (credential or is_dev_mode()):
            return self._detail_response(
                "missing api key", status.HTTP_401_UNAUTHORIZED
            )

        # Environment creation is reported under its own action name so it
        # can be rate limited separately from ordinary API calls.
        creates_environment = (
            route_path == "/api/platform/initEnv" and request.method == "POST"
        )
        rate_limit_action = (
            "environment_created" if creates_environment else "api_request"
        )

        try:
            principal = await get_principal_id(credential, action=rate_limit_action)
            with self.session_manager.with_meta_session() as meta_db:
                state = request.state
                state.principal_id = principal
                state.db_session = meta_db
                # Await inside the ``with`` so the session stays open while
                # the downstream handler runs.
                return await call_next(request)
        except PermissionError as denied:
            return self._detail_response(str(denied), status.HTTP_401_UNAUTHORIZED)
        except RuntimeError as failure:
            logger.error(f"Control plane error: {failure}")
            return self._detail_response(
                str(failure), status.HTTP_503_SERVICE_UNAVAILABLE
            )
        except HTTPException:
            # Let Starlette handle route-level HTTP errors (e.g. 404).
            raise
        except Exception:
            logger.exception("Unhandled exception in PlatformMiddleware")
            return self._detail_response(
                "internal server error", status.HTTP_500_INTERNAL_SERVER_ERROR
            )
class IsolationMiddleware(BaseHTTPMiddleware):
    """Bind per-environment state to requests under ``/api/env/{env_id}/``.

    For matching requests the middleware resolves the caller, copies any
    impersonation data stored on the environment row onto ``request.state``,
    opens a session scoped to the environment, forwards the request and logs
    a timing breakdown. All other paths are forwarded untouched.
    """

    # Path prefix that marks a request as environment-scoped.
    _ENV_PREFIX = "/api/env/"

    def __init__(
        self,
        app,
        *,
        session_manager: SessionManager,
        core_isolation_engine: CoreIsolationEngine,
    ):
        """Store the collaborators.

        Args:
            app: The wrapped ASGI application.
            session_manager: Provides the meta session and the
                per-environment session used while dispatching.
            core_isolation_engine: Kept on the instance; ``dispatch`` itself
                does not use it.
        """
        super().__init__(app)
        self.session_manager = session_manager
        self.core_isolation_engine = core_isolation_engine

    async def _resolve_principal(
        self, api_key_hdr: str | None, is_download_redirect: bool
    ) -> str:
        """Return the principal id for the request.

        A supplied key is always validated through ``get_principal_id``;
        without one, the caller is either a download redirect or a dev-mode
        user (the 401 check in ``dispatch`` has already excluded the rest).
        """
        if api_key_hdr:
            return await get_principal_id(api_key_hdr, action="api_request")
        if is_download_redirect:
            return "download-redirect"
        return "dev-user"

    def _load_impersonation(self, request: Request, env_id: str) -> None:
        """Copy impersonation fields from the environment row onto the request.

        Best effort: a missing row leaves ``request.state`` unchanged, and an
        id that cannot be converted (``ValueError``/``TypeError``) is only
        logged at debug level.
        """
        with self.session_manager.with_meta_session() as meta_session:
            try:
                env_uuid = self.session_manager._to_uuid(env_id)
                env = (
                    meta_session.query(RunTimeEnvironment)
                    .filter(RunTimeEnvironment.id == env_uuid)
                    .one_or_none()
                )
                if env is not None:
                    request.state.impersonate_user_id = env.impersonate_user_id
                    request.state.impersonate_email = env.impersonate_email
            except (ValueError, TypeError) as e:
                logger.debug(
                    "Could not load impersonation data for env %s: %s", env_id, e
                )

    async def dispatch(self, request: Request, call_next) -> Response:
        """Forward an environment-scoped request with its state attached.

        Returns the downstream response on success. Otherwise replies with
        400 (no environment id in the path), 401 (no API key outside the
        download/dev-mode exceptions, or ``PermissionError``), 503
        (``RuntimeError``) or 500 (any other exception).
        """
        path = request.scope.get("path", "")
        # Expected: /api/env/{env_id}/services/{service}/...
        if not path.startswith(self._ENV_PREFIX):
            return await call_next(request)

        t_total_start = time.perf_counter()
        try:
            # Split once: parts[0] is the environment id and parts[2] (when
            # present) is the service name used in the timing log.
            parts = path[len(self._ENV_PREFIX):].split("/")
            env_id = parts[0]
            if not env_id:
                return JSONResponse(
                    {"ok": False, "error": "invalid_environment_path"},
                    status_code=status.HTTP_400_BAD_REQUEST,
                )

            # Box /download paths are reached via 302 redirect from the
            # authenticated /content endpoint. HTTP clients strip auth
            # headers on redirect (per RFC 9110), so we skip the API-key
            # check here — mirroring how real Box returns a pre-signed CDN
            # URL that needs no Authorization header.
            # NOTE(review): this is a substring test on the whole path, so
            # any path segment that merely starts with "download" also skips
            # the API-key check — confirm that is acceptable or tighten it to
            # the exact download route.
            is_download_redirect = "/download" in path
            api_key_hdr = request.headers.get("X-API-Key") or request.headers.get(
                "Authorization"
            )
            if not api_key_hdr and not is_download_redirect and not is_dev_mode():
                return JSONResponse(
                    {"ok": False, "error": "not_authed"},
                    status_code=status.HTTP_401_UNAUTHORIZED,
                )

            t_auth_start = time.perf_counter()
            principal_id = await self._resolve_principal(
                api_key_hdr, is_download_redirect
            )
            t_auth_ms = (time.perf_counter() - t_auth_start) * 1000

            t_meta_start = time.perf_counter()
            request.state.principal_id = principal_id
            self._load_impersonation(request, env_id)
            t_meta_ms = (time.perf_counter() - t_meta_start) * 1000

            t_handler_start = time.perf_counter()
            with self.session_manager.with_session_for_environment(env_id) as session:
                request.state.db_session = session
                request.state.environment_id = env_id
                # Await inside the ``with`` so the session stays open while
                # the downstream handler runs.
                response = await call_next(request)
            t_handler_ms = (time.perf_counter() - t_handler_start) * 1000
            t_total_ms = (time.perf_counter() - t_total_start) * 1000

            # Extract service from path for easier log filtering
            service_name = parts[2] if len(parts) > 2 else "unknown"
            # Lazy %-formatting: this line runs on every request, so the
            # message is only built when INFO is actually enabled.
            logger.info(
                "[PERF] %s %s | service=%s "
                "total=%.0fms auth=%.0fms "
                "meta_db=%.0fms handler=%.0fms "
                "status=%s",
                request.method,
                path,
                service_name,
                t_total_ms,
                t_auth_ms,
                t_meta_ms,
                t_handler_ms,
                response.status_code,
            )
            return response
        except PermissionError as exc:
            return JSONResponse(
                {"ok": False, "error": str(exc)},
                status_code=status.HTTP_401_UNAUTHORIZED,
            )
        except RuntimeError as exc:
            logger.error("Control plane error: %s", exc)
            return JSONResponse(
                {"ok": False, "error": str(exc)},
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            )
        except Exception:
            logger.exception("Unhandled exception in IsolationMiddleware")
            return JSONResponse(
                {"ok": False, "error": "internal_error"},
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )