Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions syncmaster/db/repositories/credentials_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@

from typing import TYPE_CHECKING, NoReturn

from sqlalchemy import ScalarResult, insert, select
from sqlalchemy.exc import DBAPIError, IntegrityError, NoResultFound
from sqlalchemy import ScalarResult, delete, insert, select
from sqlalchemy.exc import DBAPIError, IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from syncmaster.db.models import AuthData
from syncmaster.db.repositories.base import Repository
from syncmaster.db.repositories.utils import decrypt_auth_data, encrypt_auth_data
from syncmaster.exceptions import SyncmasterError
from syncmaster.exceptions.credentials import AuthDataNotFoundError

if TYPE_CHECKING:
from syncmaster.scheduler.settings import SchedulerAppSettings
Expand All @@ -33,13 +32,13 @@ def __init__(
async def read(
self,
connection_id: int,
) -> dict:
) -> dict | None:
query = select(AuthData).where(AuthData.connection_id == connection_id)
try:
result: ScalarResult[AuthData] = await self._session.scalars(query)
return decrypt_auth_data(result.one().value, settings=self._settings)
except NoResultFound as e:
raise AuthDataNotFoundError(f"Connection id = {connection_id}") from e
result: ScalarResult[AuthData] = await self._session.scalars(query)
result_row = result.one_or_none()
if not result_row:
return None
return decrypt_auth_data(result_row.value, settings=self._settings)

async def read_bulk(
self,
Expand Down Expand Up @@ -79,5 +78,17 @@ async def update(
except IntegrityError as e:
self._raise_error(e)

async def delete(
    self,
    connection_id: int,
) -> AuthData:
    """Delete the stored credentials of a connection and return the removed row.

    Executes a ``DELETE ... RETURNING`` statement filtered by ``connection_id``,
    flushes the session, and returns the deleted ``AuthData`` row. The row is
    fetched with ``.one()``, so exactly one matching row is expected.

    An ``IntegrityError`` raised while executing or flushing is passed to
    ``_raise_error``, which re-raises it as ``SyncmasterError``.
    """
    statement = delete(AuthData).where(AuthData.connection_id == connection_id).returning(AuthData)
    try:
        deleted_rows = await self._session.scalars(statement)
        await self._session.flush()
        return deleted_rows.one()
    except IntegrityError as e:
        self._raise_error(e)

def _raise_error(self, err: DBAPIError) -> NoReturn:
    """Re-raise a database error as ``SyncmasterError``, chaining ``err`` as its cause."""
    raise SyncmasterError from err
18 changes: 11 additions & 7 deletions syncmaster/db/repositories/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,17 @@ async def read_by_id(self, run_id: int) -> Run:
async def create(
self,
transfer_id: int,
source_creds: dict,
target_creds: dict,
source_auth_data: dict | None,
target_auth_data: dict | None,
type: RunType,
) -> Run:
run = Run()
run.transfer_id = transfer_id
run.transfer_dump = await self.read_full_serialized_transfer(transfer_id, source_creds, target_creds)
run.transfer_dump = await self.read_full_serialized_transfer(
transfer_id,
source_auth_data,
target_auth_data,
)
run.type = type
try:
self._session.add(run)
Expand All @@ -84,8 +88,8 @@ async def stop(self, run_id: int) -> Run:
async def read_full_serialized_transfer(
self,
transfer_id: int,
source_creds: dict,
target_creds: dict,
source_auth_data: dict | None,
target_auth_data: dict | None,
) -> dict[str, Any]:
transfer = await self._session.scalars(
select(Transfer)
Expand Down Expand Up @@ -116,15 +120,15 @@ async def read_full_serialized_transfer(
name=transfer.source_connection.name,
description=transfer.source_connection.description,
data=transfer.source_connection.data,
auth_data=source_creds["auth_data"],
auth_data=source_auth_data,
),
target_connection=dict(
id=transfer.target_connection.id,
group_id=transfer.target_connection.group_id,
name=transfer.target_connection.name,
description=transfer.target_connection.description,
data=transfer.target_connection.data,
auth_data=target_creds["auth_data"],
auth_data=target_auth_data,
),
)

Expand Down
4 changes: 2 additions & 2 deletions syncmaster/dto/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ class HiveConnectionDTO(ConnectionDTO):

@dataclass
class HDFSConnectionDTO(ConnectionDTO):
user: str
password: str
cluster: str
user: str | None = None
password: str | None = None
type: ClassVar[str] = "hdfs"


Expand Down
19 changes: 15 additions & 4 deletions syncmaster/scheduler/transfer_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,25 @@ async def send_job_to_celery(transfer_id: int) -> None: # noqa: WPS602, WPS217
except TransferNotFoundError:
return

credentials_source = await unit_of_work.credentials.read(transfer.source_connection_id)
credentials_target = await unit_of_work.credentials.read(transfer.target_connection_id)
source_auth_data: dict | None = None
source_credentials = await unit_of_work.credentials.read(transfer.source_connection_id)
if source_credentials:
# remove secrets from the dump
source_credentials_filtered = ReadAuthDataSchema.model_validate(source_credentials)
source_auth_data = source_credentials_filtered.auth_data.model_dump()

target_auth_data: dict | None = None
target_credentials = await unit_of_work.credentials.read(transfer.target_connection_id)
if target_credentials:
# remove secrets from the dump
target_credentials_filtered = ReadAuthDataSchema.model_validate(target_credentials)
target_auth_data = target_credentials_filtered.auth_data.model_dump()

async with unit_of_work:
run = await unit_of_work.run.create(
transfer_id=transfer_id,
source_creds=ReadAuthDataSchema(auth_data=credentials_source).model_dump(),
target_creds=ReadAuthDataSchema(auth_data=credentials_target).model_dump(),
source_auth_data=source_auth_data,
target_auth_data=target_auth_data,
type=RunType.SCHEDULED,
)

Expand Down
4 changes: 2 additions & 2 deletions syncmaster/schemas/v1/connections/hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class CreateHDFSConnectionSchema(CreateConnectionBaseSchema):
"Data required to connect to the HDFS cluster. These are the parameters that are specified in the URL request."
),
)
auth_data: CreateBasicAuthSchema = Field(
auth_data: CreateBasicAuthSchema | None = Field(
description="Credentials for authorization",
)

Expand All @@ -44,6 +44,6 @@ class ReadHDFSConnectionSchema(ReadConnectionBaseSchema):


class UpdateHDFSConnectionSchema(CreateHDFSConnectionSchema):
auth_data: UpdateBasicAuthSchema = Field(
auth_data: UpdateBasicAuthSchema | None = Field(
description="Credentials for authorization",
)
4 changes: 2 additions & 2 deletions syncmaster/schemas/v1/connections/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class CreateHiveConnectionSchema(CreateConnectionBaseSchema):
"Data required to connect to the database. These are the parameters that are specified in the URL request."
),
)
auth_data: CreateBasicAuthSchema = Field(
auth_data: CreateBasicAuthSchema | None = Field(
description="Credentials for authorization",
)

Expand All @@ -44,6 +44,6 @@ class ReadHiveConnectionSchema(ReadConnectionBaseSchema):


class UpdateHiveConnectionSchema(CreateHiveConnectionSchema):
auth_data: UpdateBasicAuthSchema = Field(
auth_data: UpdateBasicAuthSchema | None = Field(
description="Credentials for authorization",
)
60 changes: 37 additions & 23 deletions syncmaster/server/api/v1/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,11 @@ async def create_connection(
data=connection_data.data.model_dump(),
)

await unit_of_work.credentials.create(
connection_id=connection.id,
data=connection_data.auth_data.model_dump(),
)
if connection_data.auth_data:
await unit_of_work.credentials.create(
connection_id=connection.id,
data=connection_data.auth_data.model_dump(),
)

credentials = await unit_of_work.credentials.read(connection.id)
return TypeAdapter(ReadConnectionSchema).validate_python(
Expand Down Expand Up @@ -183,7 +184,7 @@ async def read_connection(


@router.put("/connections/{connection_id}")
async def update_connection( # noqa: WPS217, WPS238
async def update_connection( # noqa: WPS217, WPS238, WPS231
connection_id: int,
connection_data: UpdateConnectionSchema,
current_user: User = Depends(get_user(is_active=True)),
Expand All @@ -200,36 +201,49 @@ async def update_connection( # noqa: WPS217, WPS238
if resource_role < Permission.WRITE:
raise ActionNotAllowedError

async with unit_of_work:
existing_connection: Connection = await unit_of_work.connection.read_by_id(connection_id=connection_id)
if connection_data.type != existing_connection.type:
linked_transfers: Sequence[Transfer] = await unit_of_work.transfer.list_by_connection_id(connection_id)
if linked_transfers:
raise ConnectionTypeUpdateError

existing_credentials = await unit_of_work.credentials.read(connection_id=connection_id)
auth_data = connection_data.auth_data.model_dump()
existing_connection: Connection = await unit_of_work.connection.read_by_id(connection_id=connection_id)
if connection_data.type != existing_connection.type:
linked_transfers: Sequence[Transfer] = await unit_of_work.transfer.list_by_connection_id(connection_id)
if linked_transfers:
raise ConnectionTypeUpdateError

existing_credentials = await unit_of_work.credentials.read(connection_id=connection_id)
new_credentials: dict | None = None
if connection_data.auth_data:
new_credentials = connection_data.auth_data.model_dump()
secret_field = connection_data.auth_data.secret_field
if new_credentials[secret_field] is None:

if auth_data[secret_field] is None:
if existing_credentials["type"] != auth_data["type"]:
# We don't return secret_field to client, so default field value means using existing secret
if not existing_credentials or existing_credentials["type"] != new_credentials["type"]:
raise ConnectionAuthDataUpdateError

auth_data[secret_field] = existing_credentials[secret_field]
new_credentials[secret_field] = existing_credentials[secret_field]

async with unit_of_work:
connection = await unit_of_work.connection.update(
connection_id=connection_id,
name=connection_data.name,
type=connection_data.type,
description=connection_data.description,
data=connection_data.data.model_dump(),
)
await unit_of_work.credentials.update(
connection_id=connection_id,
data=auth_data,
)

credentials = await unit_of_work.credentials.read(connection_id)
if existing_credentials and new_credentials:
await unit_of_work.credentials.update(
connection_id=connection_id,
data=new_credentials,
)
elif new_credentials:
await unit_of_work.credentials.create(
connection_id=connection.id,
data=new_credentials,
)
elif existing_credentials:
await unit_of_work.credentials.delete(
connection_id=connection_id,
)

return TypeAdapter(ReadConnectionSchema).validate_python(
{
"id": connection.id,
Expand All @@ -238,7 +252,7 @@ async def update_connection( # noqa: WPS217, WPS238
"description": connection.description,
"type": connection.type,
"data": connection.data,
"auth_data": credentials,
"auth_data": new_credentials,
},
)

Expand Down
25 changes: 15 additions & 10 deletions syncmaster/server/api/v1/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,25 @@ async def start_run( # noqa: WPS217

# The credentials.read method is deliberately used here instead of credentials.read_bulk:
# it is more convenient to pass the credentials along at this point
credentials_source = await unit_of_work.credentials.read(
transfer.source_connection_id,
)
credentials_target = await unit_of_work.credentials.read(
transfer.target_connection_id,
)
source_auth_data: dict | None = None
source_credentials = await unit_of_work.credentials.read(transfer.source_connection_id)
if source_credentials:
# remove secrets from the dump
source_credentials_filtered = ReadAuthDataSchema.model_validate(source_credentials)
source_auth_data = source_credentials_filtered.auth_data.model_dump()

target_auth_data: dict | None = None
target_credentials = await unit_of_work.credentials.read(transfer.target_connection_id)
if target_credentials:
# remove secrets from the dump
target_credentials_filtered = ReadAuthDataSchema.model_validate(target_credentials)
target_auth_data = target_credentials_filtered.auth_data.model_dump()

async with unit_of_work:
run = await unit_of_work.run.create(
transfer_id=create_run_data.transfer_id,
# Since fields with credentials may have different names (for example, S3 and Postgres have different names)
# the work of checking fields and removing passwords is delegated to the ReadAuthDataSchema class
source_creds=ReadAuthDataSchema(auth_data=credentials_source).model_dump(),
target_creds=ReadAuthDataSchema(auth_data=credentials_target).model_dump(),
source_auth_data=source_auth_data,
target_auth_data=target_auth_data,
type=RunType.MANUAL,
)

Expand Down
8 changes: 4 additions & 4 deletions syncmaster/worker/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,9 @@ def __init__(
settings: WorkerAppSettings,
run: Run,
source_connection: Connection,
source_auth_data: dict,
source_auth_data: dict | None,
target_connection: Connection,
target_auth_data: dict,
target_auth_data: dict | None,
):
self.temp_dir = TemporaryDirectory(prefix=f"syncmaster_{run.id}_")

Expand Down Expand Up @@ -213,7 +213,7 @@ def perform_transfer(self) -> None:
def get_handler(
self,
connection_data: dict[str, Any],
connection_auth_data: dict,
connection_auth_data: dict | None,
run_data: dict[str, Any],
transfer_id: int,
transfer_params: dict[str, Any],
Expand All @@ -222,7 +222,7 @@ def get_handler(
transformations: list[dict],
temp_dir: TemporaryDirectory,
) -> Handler:
connection_data.update(connection_auth_data)
connection_data.update(connection_auth_data or {})
connection_data.pop("type")
handler_type = transfer_params.pop("type", None)

Expand Down
13 changes: 10 additions & 3 deletions syncmaster/worker/handlers/file/hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ def connect(self, spark: SparkSession):
spark=spark,
).check()

self.file_connection = HDFS(
cluster=self.connection_dto.cluster,
).check()
if self.connection_dto.user and self.connection_dto.password:
self.file_connection = HDFS(
cluster=self.connection_dto.cluster,
user=self.connection_dto.user,
password=self.connection_dto.password,
).check()
else:
self.file_connection = HDFS(
cluster=self.connection_dto.cluster,
).check()
6 changes: 4 additions & 2 deletions syncmaster/worker/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ def run_transfer(session: Session, run_id: int, settings: WorkerAppSettings):
q_source_auth_data = select(AuthData).where(AuthData.connection_id == run.transfer.source_connection.id)
q_target_auth_data = select(AuthData).where(AuthData.connection_id == run.transfer.target_connection.id)

target_auth_data = decrypt_auth_data(session.scalars(q_target_auth_data).one().value, settings)
source_auth_data = decrypt_auth_data(session.scalars(q_source_auth_data).one().value, settings)
source_auth_result = session.scalars(q_source_auth_data).one_or_none()
target_auth_result = session.scalars(q_target_auth_data).one_or_none()
source_auth_data = decrypt_auth_data(source_auth_result.value, settings) if source_auth_result else None
target_auth_data = decrypt_auth_data(target_auth_result.value, settings) if target_auth_result else None

try:
controller = TransferController(
Expand Down
8 changes: 2 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,16 +169,12 @@ def celery(worker_settings: WorkerAppSettings) -> Celery:

@pytest_asyncio.fixture
async def create_connection_data(request):
if hasattr(request, "param"):
return request.param
return None
return request.param


@pytest_asyncio.fixture
async def create_transfer_data(request):
if hasattr(request, "param"):
return request.param
return None
return request.param


@pytest_asyncio.fixture(
Expand Down
Loading
Loading