Skip to content

Commit 05290f0

Browse files
committed
Make filesystem a dependency everywhere
1 parent b19a6fc commit 05290f0

10 files changed

Lines changed: 129 additions & 110 deletions

File tree

api/core/filesystem.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import shutil
66
import zipfile
77
from pathlib import Path, PurePosixPath
8-
from typing import Any, BinaryIO, Generator, cast
8+
from typing import Any, BinaryIO, Callable, Generator
99

1010
import boto3
1111
import humanize
@@ -17,7 +17,7 @@
1717
from mypy_boto3_s3 import S3Client
1818
from mypy_boto3_s3.type_defs import ObjectIdentifierTypeDef
1919

20-
from api import models, settings
20+
from api import models
2121
from api.schemas.file import FileHTTPRequest, FileInfo, FileTypes
2222

2323

@@ -387,29 +387,43 @@ def download_url(
387387
)
388388

389389

390-
def get_filesystem_with_root(root_path: str) -> FileSystem:
390+
def get_filesystem_with_root(
391+
root_path: str,
392+
filesystem: str,
393+
s3_region: str,
394+
s3_bucket: str | None,
395+
) -> FileSystem:
391396
"""Get the filesystem to use."""
392397
predef_dirs = [e.value for e in models.UploadFileTypes] + [
393398
e.value for e in models.OutputEndpoints
394399
]
395-
if settings.filesystem == "s3":
400+
if filesystem == "s3":
401+
assert s3_bucket is not None, "S3 bucket must be provided for S3 filesystem"
396402
s3_client = boto3.client(
397403
"s3",
398-
region_name=settings.s3_region,
399-
endpoint_url=f"https://s3.{settings.s3_region}.amazonaws.com",
404+
region_name=s3_region,
405+
endpoint_url=f"https://s3.{s3_region}.amazonaws.com",
400406
config=Config(signature_version="v4", s3={"addressing_style": "path"}),
401407
)
402408
# this and config=... required to avoid DNS problems with new buckets
403409
s3_client.meta.events.unregister("before-sign.s3", fix_s3_host)
404-
return S3Filesystem(
405-
root_path, s3_client, cast(str, settings.s3_bucket), predef_dirs=predef_dirs
406-
)
407-
elif settings.filesystem == "local":
410+
return S3Filesystem(root_path, s3_client, s3_bucket, predef_dirs=predef_dirs)
411+
elif filesystem == "local":
408412
return LocalFilesystem(root_path, predef_dirs=predef_dirs)
409413
else:
410414
raise ValueError("Invalid filesystem setting")
411415

412416

413-
def get_user_filesystem(user_id: str) -> FileSystem:
417+
def user_filesystem_getter(
418+
user_data_root_path: str,
419+
filesystem: str,
420+
s3_region: str,
421+
s3_bucket: str | None,
422+
) -> Callable[[str], FileSystem]:
414423
"""Return a function that builds the filesystem for a given user ID."""
415-
return get_filesystem_with_root(str(Path(settings.user_data_root_path) / user_id))
424+
return lambda user_id: get_filesystem_with_root(
425+
str(Path(user_data_root_path) / user_id),
426+
filesystem=filesystem,
427+
s3_region=s3_region,
428+
s3_bucket=s3_bucket,
429+
)

api/crud/job.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
from sqlalchemy.orm import Session
77

88
from api import models, settings
9-
from api.core.filesystem import FileSystem, get_user_filesystem
9+
from api.core.filesystem import FileSystem
1010
from api.schemas import job as schemas
1111

1212

1313
def enqueue_job(
14-
job: models.Job, enqueueing_func: Callable[[schemas.QueueJob], None]
14+
job: models.Job,
15+
filesystem: FileSystem,
16+
enqueueing_func: Callable[[schemas.QueueJob], None],
1517
) -> None:
16-
user_fs = get_user_filesystem(user_id=job.user_id)
17-
1818
app = job.application
1919
job_config = settings.application_config.config[app["application"]][app["version"]][
2020
app["entrypoint"]
@@ -47,14 +47,14 @@ def prepare_files(root_in: str, root_out: str, fs: FileSystem) -> dict[str, str]
4747
f"artifact/{artifact_id}"
4848
for artifact_id in job.attributes["files_down"]["artifact_ids"]
4949
]
50-
_validate_files(user_fs, [config_path] + data_paths + artifact_paths)
50+
_validate_files(filesystem, [config_path] + data_paths + artifact_paths)
5151
roots_down = handler_config["files_down"]
52-
files_down = prepare_files(config_path, roots_down["config_id"], user_fs)
52+
files_down = prepare_files(config_path, roots_down["config_id"], filesystem)
5353
for data_path in data_paths:
54-
files_down.update(prepare_files(data_path, roots_down["data_ids"], user_fs))
54+
files_down.update(prepare_files(data_path, roots_down["data_ids"], filesystem))
5555
for artifact_path in artifact_paths:
5656
files_down.update(
57-
prepare_files(artifact_path, roots_down["artifact_ids"], user_fs)
57+
prepare_files(artifact_path, roots_down["artifact_ids"], filesystem)
5858
)
5959

6060
app_specs = schemas.AppSpecs(
@@ -76,9 +76,9 @@ def prepare_files(root_in: str, root_out: str, fs: FileSystem) -> dict[str, str]
7676
)
7777

7878
paths_upload = {
79-
"output": user_fs.full_path_uri(job.paths_out["output"]),
80-
"log": user_fs.full_path_uri(job.paths_out["log"]),
81-
"artifact": user_fs.full_path_uri(job.paths_out["artifact"]),
79+
"output": filesystem.full_path_uri(job.paths_out["output"]),
80+
"log": filesystem.full_path_uri(job.paths_out["log"]),
81+
"artifact": filesystem.full_path_uri(job.paths_out["artifact"]),
8282
}
8383

8484
queue_item = schemas.QueueJob(
@@ -117,6 +117,7 @@ def _validate_files(filesystem: FileSystem, paths: list[str]) -> None:
117117

118118
def create_job(
119119
db: Session,
120+
filesystem: FileSystem,
120121
enqueueing_func: Callable[[schemas.QueueJob], None],
121122
job: schemas.JobCreate,
122123
user_id: int,
@@ -146,18 +147,17 @@ def create_job(
146147
status_code=status.HTTP_400_BAD_REQUEST,
147148
detail=ve,
148149
)
149-
enqueue_job(db_job, enqueueing_func)
150+
enqueue_job(db_job, filesystem, enqueueing_func)
150151
db.commit()
151152
db.refresh(db_job)
152153
return db_job
153154

154155

155-
def delete_job(db: Session, db_job: models.Job) -> models.Job:
156+
def delete_job(db: Session, filesystem: FileSystem, db_job: models.Job) -> models.Job:
156157
db.delete(db_job)
157-
user_fs = get_user_filesystem(user_id=db_job.user_id)
158158
for path in db_job.paths_out.values():
159159
if path[-1] != "/":
160160
path += "/"
161-
user_fs.delete(path)
161+
filesystem.delete(path)
162162
db.commit()
163163
return db_job

api/dependencies.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from api import settings
1111
from api.core import notifications
12-
from api.core.filesystem import FileSystem, get_user_filesystem
12+
from api.core.filesystem import FileSystem, user_filesystem_getter
1313
from api.schemas.job import QueueJob
1414

1515

@@ -52,11 +52,22 @@ async def current_user_global_dep(
5252
return current_user
5353

5454

55-
async def filesystem_dep(
55+
async def filesystem_getter_dep() -> Callable[[str], FileSystem]:
56+
"""Get a function that builds the filesystem for a given username."""
57+
return user_filesystem_getter(
58+
user_data_root_path=settings.user_data_root_path,
59+
filesystem=settings.filesystem,
60+
s3_region=settings.s3_region,
61+
s3_bucket=settings.s3_bucket,
62+
)
63+
64+
65+
async def user_filesystem_dep(
66+
filesystem_getter: Callable[[str], FileSystem] = Depends(filesystem_getter_dep),
5667
current_user: CognitoClaims = Depends(current_user_dep),
5768
) -> FileSystem:
5869
"""Get the user's filesystem."""
59-
return get_user_filesystem(current_user.username)
70+
return filesystem_getter(current_user.username)
6071

6172

6273
class APIKeyDependency:

api/endpoints/auth.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
as the authentication is handled by the Cognito service.
55
"""
66

7+
from typing import Callable
8+
79
import boto3
810
from fastapi import APIRouter, Depends, HTTPException, status
911
from fastapi.responses import JSONResponse
1012
from fastapi.security import OAuth2PasswordRequestForm
1113

1214
from api.core.aws import calculate_secret_hash
13-
from api.core.filesystem import get_user_filesystem
15+
from api.core.filesystem import FileSystem
1416
from api.schemas.token import TokenResponse
1517
from api.schemas.user import User, UserGroups
1618
from api.settings import cognito_client_id, cognito_secret, cognito_user_pool_id
@@ -25,7 +27,9 @@
2527
description="Register a new user",
2628
)
2729
def register_user(
28-
user: OAuth2PasswordRequestForm = Depends(), groups: list[UserGroups] | None = None
30+
user: OAuth2PasswordRequestForm = Depends(),
31+
filesystem_getter_dep: Callable[[str], FileSystem] = Depends(),
32+
groups: list[UserGroups] | None = None,
2933
) -> User:
3034
client = boto3.client("cognito-idp")
3135
try:
@@ -52,7 +56,7 @@ def register_user(
5256
Password=user.password,
5357
Permanent=True,
5458
)
55-
filesystem = get_user_filesystem(response["User"]["Username"])
59+
filesystem = filesystem_getter_dep(response["User"]["Username"])
5660
filesystem.init()
5761
except client.exceptions.ClientError as e:
5862
if e.response["Error"]["Code"] == "UsernameExistsException":

api/endpoints/files.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from api import models
88
from api.core.filesystem import FileSystem
9-
from api.dependencies import filesystem_dep
9+
from api.dependencies import user_filesystem_dep
1010
from api.schemas import file as file_schemas
1111

1212
router = APIRouter()
@@ -19,7 +19,7 @@
1919
description="Download a file",
2020
)
2121
def download_file(
22-
file_path: str, filesystem: FileSystem = Depends(filesystem_dep)
22+
file_path: str, filesystem: FileSystem = Depends(user_filesystem_dep)
2323
) -> FileResponse | StreamingResponse:
2424
try:
2525
return filesystem.download(file_path)
@@ -33,7 +33,9 @@ def download_file(
3333
description="Get request parameters (pre-signed URL) to download a file",
3434
)
3535
def get_download_presigned_url(
36-
file_path: str, request: Request, filesystem: FileSystem = Depends(filesystem_dep)
36+
file_path: str,
37+
request: Request,
38+
filesystem: FileSystem = Depends(user_filesystem_dep),
3739
) -> file_schemas.FileHTTPRequest:
3840
try:
3941
return filesystem.download_url(
@@ -52,7 +54,7 @@ def list_files(
5254
base_path: str = "",
5355
show_dirs: bool = True,
5456
recursive: bool = False,
55-
filesystem: FileSystem = Depends(filesystem_dep),
57+
filesystem: FileSystem = Depends(user_filesystem_dep),
5658
) -> list[file_schemas.FileInfo]:
5759
try:
5860
return sorted(
@@ -73,7 +75,7 @@ def upload_file(
7375
f_type: models.UploadFileTypes,
7476
base_path: str,
7577
file: UploadFile,
76-
filesystem: FileSystem = Depends(filesystem_dep),
78+
filesystem: FileSystem = Depends(user_filesystem_dep),
7779
) -> file_schemas.FileInfo:
7880
base_path = f"{f_type.value}/" + base_path
7981
file_path = os.path.join(base_path, file.filename or "unnamed")
@@ -91,7 +93,7 @@ def get_upload_presigned_url(
9193
f_type: models.UploadFileTypes,
9294
base_path: str,
9395
request: Request,
94-
filesystem: FileSystem = Depends(filesystem_dep),
96+
filesystem: FileSystem = Depends(user_filesystem_dep),
9597
) -> file_schemas.FileHTTPRequest:
9698
base_path = f"{f_type.value}/" + base_path
9799
return filesystem.create_file_url(
@@ -107,7 +109,7 @@ def get_upload_presigned_url(
107109
def create_directory(
108110
f_type: models.UploadFileTypes,
109111
base_path: str,
110-
filesystem: FileSystem = Depends(filesystem_dep),
112+
filesystem: FileSystem = Depends(user_filesystem_dep),
111113
) -> None:
112114
return filesystem.create_directory(f"{f_type.value}/{base_path}/")
113115

@@ -120,7 +122,7 @@ def create_directory(
120122
def rename_file(
121123
file_path: str,
122124
file: file_schemas.FileUpdate,
123-
filesystem: FileSystem = Depends(filesystem_dep),
125+
filesystem: FileSystem = Depends(user_filesystem_dep),
124126
) -> file_schemas.FileInfo:
125127
try:
126128
filesystem.rename(file_path, file.path)
@@ -141,6 +143,6 @@ def rename_file(
141143
description="Delete a file or directory",
142144
)
143145
def delete_file(
144-
file_path: str, filesystem: FileSystem = Depends(filesystem_dep)
146+
file_path: str, filesystem: FileSystem = Depends(user_filesystem_dep)
145147
) -> None:
146148
filesystem.delete(file_path)

api/endpoints/jobs.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from sqlalchemy.orm import Session
55

66
import api.database as database
7+
from api.core.filesystem import FileSystem
78
from api.crud import job as crud
8-
from api.dependencies import enqueueing_function_dep
9+
from api.dependencies import enqueueing_function_dep, user_filesystem_dep
910
from api.schemas.job import Job, JobCreate, QueueJob
1011
from api.settings import application_config
1112

@@ -60,11 +61,13 @@ def start_job(
6061
request: Request,
6162
job: JobCreate,
6263
db: Session = Depends(database.get_db),
64+
filesystem: FileSystem = Depends(user_filesystem_dep),
6365
enqueueing_func: Callable[[QueueJob], None] = Depends(enqueueing_function_dep),
6466
) -> Job:
6567
try:
6668
return crud.create_job(
6769
db,
70+
filesystem,
6871
enqueueing_func,
6972
job,
7073
user_id=request.state.current_user.username,
@@ -81,11 +84,14 @@ def start_job(
8184
description="Delete a job",
8285
)
8386
def delete_job(
84-
request: Request, job_id: int, db: Session = Depends(database.get_db)
87+
request: Request,
88+
job_id: int,
89+
db: Session = Depends(database.get_db),
90+
filesystem: FileSystem = Depends(user_filesystem_dep),
8591
) -> None:
8692
db_job = crud.get_job(db, job_id)
8793
if db_job is None or db_job.user_id != request.state.current_user.username:
8894
raise HTTPException(
8995
status_code=status.HTTP_404_NOT_FOUND, detail="Job not found"
9096
)
91-
crud.delete_job(db, db_job)
97+
crud.delete_job(db, filesystem, db_job)

api/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def _load_possibly_aws_secret(name: str) -> str | None:
2727
if os.environ.get("DATABASE_SECRET"): # set and not None
2828
database_secret = _load_possibly_aws_secret("DATABASE_SECRET")
2929
database_url = database_url.format(database_secret)
30-
filesystem = os.environ.get("FILESYSTEM")
30+
filesystem = os.environ.get("FILESYSTEM", "local")
3131
s3_bucket = os.environ.get("S3_BUCKET")
3232
s3_region = os.environ.get("S3_REGION", "eu-central-1")
3333
user_data_root_path = os.environ.get("USER_DATA_ROOT_PATH", "/data")

tests/conftest.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import datetime
22
import secrets
33
import time
4-
from typing import Any, Generator
4+
from typing import Any
55
from unittest.mock import MagicMock
66

77
import boto3
@@ -18,16 +18,10 @@
1818
REGION_NAME: BucketLocationConstraintType = "eu-central-1"
1919

2020

21-
@pytest.fixture(scope="session")
22-
def monkeypatch_module() -> Generator[pytest.MonkeyPatch, Any, None]:
23-
with pytest.MonkeyPatch.context() as mp:
24-
yield mp
25-
26-
27-
@pytest.fixture(autouse=True, scope="function")
28-
def enqueueing_func(monkeypatch_module: pytest.MonkeyPatch) -> MagicMock:
21+
@pytest.fixture(autouse=True)
22+
def enqueueing_func(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
2923
mock_enqueueing_function = MagicMock()
30-
monkeypatch_module.setitem(
24+
monkeypatch.setitem(
3125
app.dependency_overrides,
3226
enqueueing_function_dep, # type: ignore
3327
lambda: mock_enqueueing_function,

0 commit comments

Comments
 (0)