diff --git a/syncmaster/db/repositories/credentials_repository.py b/syncmaster/db/repositories/credentials_repository.py index 54f786b2..8fb18645 100644 --- a/syncmaster/db/repositories/credentials_repository.py +++ b/syncmaster/db/repositories/credentials_repository.py @@ -4,15 +4,14 @@ from typing import TYPE_CHECKING, NoReturn -from sqlalchemy import ScalarResult, insert, select -from sqlalchemy.exc import DBAPIError, IntegrityError, NoResultFound +from sqlalchemy import ScalarResult, delete, insert, select +from sqlalchemy.exc import DBAPIError, IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from syncmaster.db.models import AuthData from syncmaster.db.repositories.base import Repository from syncmaster.db.repositories.utils import decrypt_auth_data, encrypt_auth_data from syncmaster.exceptions import SyncmasterError -from syncmaster.exceptions.credentials import AuthDataNotFoundError if TYPE_CHECKING: from syncmaster.scheduler.settings import SchedulerAppSettings @@ -33,13 +32,13 @@ def __init__( async def read( self, connection_id: int, - ) -> dict: + ) -> dict | None: query = select(AuthData).where(AuthData.connection_id == connection_id) - try: - result: ScalarResult[AuthData] = await self._session.scalars(query) - return decrypt_auth_data(result.one().value, settings=self._settings) - except NoResultFound as e: - raise AuthDataNotFoundError(f"Connection id = {connection_id}") from e + result: ScalarResult[AuthData] = await self._session.scalars(query) + result_row = result.one_or_none() + if not result_row: + return None + return decrypt_auth_data(result_row.value, settings=self._settings) async def read_bulk( self, @@ -79,5 +78,17 @@ async def update( except IntegrityError as e: self._raise_error(e) + async def delete( + self, + connection_id: int, + ) -> AuthData: + try: + query = delete(AuthData).where(AuthData.connection_id == connection_id).returning(AuthData) + result = await self._session.scalars(query) + await self._session.flush() + return result.one() + except IntegrityError as e: + self._raise_error(e) + def _raise_error(self, err: DBAPIError) -> NoReturn: raise SyncmasterError from err diff --git a/syncmaster/db/repositories/run.py b/syncmaster/db/repositories/run.py index 6f031505..29dd4fb7 100644 --- a/syncmaster/db/repositories/run.py +++ b/syncmaster/db/repositories/run.py @@ -52,13 +52,17 @@ async def read_by_id(self, run_id: int) -> Run: async def create( self, transfer_id: int, - source_creds: dict, - target_creds: dict, + source_auth_data: dict | None, + target_auth_data: dict | None, type: RunType, ) -> Run: run = Run() run.transfer_id = transfer_id - run.transfer_dump = await self.read_full_serialized_transfer(transfer_id, source_creds, target_creds) + run.transfer_dump = await self.read_full_serialized_transfer( + transfer_id, + source_auth_data, + target_auth_data, + ) run.type = type try: self._session.add(run) @@ -84,8 +88,8 @@ async def stop(self, run_id: int) -> Run: async def read_full_serialized_transfer( self, transfer_id: int, - source_creds: dict, - target_creds: dict, + source_auth_data: dict | None, + target_auth_data: dict | None, ) -> dict[str, Any]: transfer = await self._session.scalars( select(Transfer) @@ -116,7 +120,7 @@ async def read_full_serialized_transfer( name=transfer.source_connection.name, description=transfer.source_connection.description, data=transfer.source_connection.data, - auth_data=source_creds["auth_data"], + auth_data=source_auth_data, ), target_connection=dict( id=transfer.target_connection.id, @@ -124,7 +128,7 @@ async def read_full_serialized_transfer( name=transfer.target_connection.name, description=transfer.target_connection.description, data=transfer.target_connection.data, - auth_data=target_creds["auth_data"], + auth_data=target_auth_data, ), ) diff --git a/syncmaster/dto/connections.py b/syncmaster/dto/connections.py index 28cd115c..010cb14b 100644 --- a/syncmaster/dto/connections.py +++ b/syncmaster/dto/connections.py @@ -75,9 +75,9 @@ class HiveConnectionDTO(ConnectionDTO): @dataclass class HDFSConnectionDTO(ConnectionDTO): - user: str - password: str cluster: str + user: str | None = None + password: str | None = None type: ClassVar[str] = "hdfs" diff --git a/syncmaster/scheduler/transfer_job_manager.py b/syncmaster/scheduler/transfer_job_manager.py index feee8ad8..f602b3d2 100644 --- a/syncmaster/scheduler/transfer_job_manager.py +++ b/syncmaster/scheduler/transfer_job_manager.py @@ -83,14 +83,25 @@ async def send_job_to_celery(transfer_id: int) -> None: # noqa: WPS602, WPS217 except TransferNotFoundError: return - credentials_source = await unit_of_work.credentials.read(transfer.source_connection_id) - credentials_target = await unit_of_work.credentials.read(transfer.target_connection_id) + source_auth_data: dict | None = None + source_credentials = await unit_of_work.credentials.read(transfer.source_connection_id) + if source_credentials: + # remove secrets from the dump + source_credentials_filtered = ReadAuthDataSchema.model_validate(source_credentials) + source_auth_data = source_credentials_filtered.auth_data.model_dump() + + target_auth_data: dict | None = None + target_credentials = await unit_of_work.credentials.read(transfer.target_connection_id) + if target_credentials: + # remove secrets from the dump + target_credentials_filtered = ReadAuthDataSchema.model_validate(target_credentials) + target_auth_data = target_credentials_filtered.auth_data.model_dump() async with unit_of_work: run = await unit_of_work.run.create( transfer_id=transfer_id, - source_creds=ReadAuthDataSchema(auth_data=credentials_source).model_dump(), - target_creds=ReadAuthDataSchema(auth_data=credentials_target).model_dump(), + source_auth_data=source_auth_data, + target_auth_data=target_auth_data, type=RunType.SCHEDULED, ) diff --git a/syncmaster/schemas/v1/connections/hdfs.py b/syncmaster/schemas/v1/connections/hdfs.py index 2a3e06ae..069baf68 100644 --- a/syncmaster/schemas/v1/connections/hdfs.py +++ b/syncmaster/schemas/v1/connections/hdfs.py @@ -32,7 +32,7 @@ class CreateHDFSConnectionSchema(CreateConnectionBaseSchema): "Data required to connect to the HDFS cluster. These are the parameters that are specified in the URL request." ), ) - auth_data: CreateBasicAuthSchema = Field( + auth_data: CreateBasicAuthSchema | None = Field( description="Credentials for authorization", ) @@ -44,6 +44,6 @@ class ReadHDFSConnectionSchema(ReadConnectionBaseSchema): class UpdateHDFSConnectionSchema(CreateHDFSConnectionSchema): - auth_data: UpdateBasicAuthSchema = Field( + auth_data: UpdateBasicAuthSchema | None = Field( description="Credentials for authorization", ) diff --git a/syncmaster/schemas/v1/connections/hive.py b/syncmaster/schemas/v1/connections/hive.py index 094eea9e..1ce4c2c4 100644 --- a/syncmaster/schemas/v1/connections/hive.py +++ b/syncmaster/schemas/v1/connections/hive.py @@ -32,7 +32,7 @@ class CreateHiveConnectionSchema(CreateConnectionBaseSchema): "Data required to connect to the database. These are the parameters that are specified in the URL request." ), ) - auth_data: CreateBasicAuthSchema = Field( + auth_data: CreateBasicAuthSchema | None = Field( description="Credentials for authorization", ) @@ -44,6 +44,6 @@ class ReadHiveConnectionSchema(ReadConnectionBaseSchema): class UpdateHiveConnectionSchema(CreateHiveConnectionSchema): - auth_data: UpdateBasicAuthSchema = Field( + auth_data: UpdateBasicAuthSchema | None = Field( description="Credentials for authorization", ) diff --git a/syncmaster/server/api/v1/connections.py b/syncmaster/server/api/v1/connections.py index 6a698cdb..c525d463 100644 --- a/syncmaster/server/api/v1/connections.py +++ b/syncmaster/server/api/v1/connections.py @@ -125,10 +125,11 @@ async def create_connection( data=connection_data.data.model_dump(), ) - await unit_of_work.credentials.create( - connection_id=connection.id, - data=connection_data.auth_data.model_dump(), - ) + if connection_data.auth_data: + await unit_of_work.credentials.create( + connection_id=connection.id, + data=connection_data.auth_data.model_dump(), + ) credentials = await unit_of_work.credentials.read(connection.id) return TypeAdapter(ReadConnectionSchema).validate_python( @@ -183,7 +184,7 @@ async def read_connection( @router.put("/connections/{connection_id}") -async def update_connection( # noqa: WPS217, WPS238 +async def update_connection( # noqa: WPS217, WPS238, WPS231 connection_id: int, connection_data: UpdateConnectionSchema, current_user: User = Depends(get_user(is_active=True)), @@ -200,23 +201,26 @@ async def update_connection( # noqa: WPS217, WPS238 if resource_role < Permission.WRITE: raise ActionNotAllowedError - async with unit_of_work: - existing_connection: Connection = await unit_of_work.connection.read_by_id(connection_id=connection_id) - if connection_data.type != existing_connection.type: - linked_transfers: Sequence[Transfer] = await unit_of_work.transfer.list_by_connection_id(connection_id) - if linked_transfers: - raise ConnectionTypeUpdateError - - existing_credentials = await unit_of_work.credentials.read(connection_id=connection_id) - auth_data = connection_data.auth_data.model_dump() + existing_connection: Connection = await unit_of_work.connection.read_by_id(connection_id=connection_id) + if connection_data.type != existing_connection.type: + linked_transfers: Sequence[Transfer] = await unit_of_work.transfer.list_by_connection_id(connection_id) + if linked_transfers: + raise ConnectionTypeUpdateError + + existing_credentials = await unit_of_work.credentials.read(connection_id=connection_id) + new_credentials: dict | None = None + if connection_data.auth_data: + new_credentials = connection_data.auth_data.model_dump() secret_field = connection_data.auth_data.secret_field + if new_credentials[secret_field] is None: - if auth_data[secret_field] is None: - if existing_credentials["type"] != auth_data["type"]: + # We don't return secret_field to client, so default field value means using existing secret + if not existing_credentials or existing_credentials["type"] != new_credentials["type"]: raise ConnectionAuthDataUpdateError - auth_data[secret_field] = existing_credentials[secret_field] + new_credentials[secret_field] = existing_credentials[secret_field] + async with unit_of_work: connection = await unit_of_work.connection.update( connection_id=connection_id, name=connection_data.name, @@ -224,12 +228,22 @@ async def update_connection( # noqa: WPS217, WPS238 description=connection_data.description, data=connection_data.data.model_dump(), ) - await unit_of_work.credentials.update( - connection_id=connection_id, - data=auth_data, - ) - credentials = await unit_of_work.credentials.read(connection_id) + if existing_credentials and new_credentials: + await unit_of_work.credentials.update( + connection_id=connection_id, + data=new_credentials, + ) + elif new_credentials: + await unit_of_work.credentials.create( + connection_id=connection.id, + data=new_credentials, + ) + elif existing_credentials: + await unit_of_work.credentials.delete( + connection_id=connection_id, + ) + return TypeAdapter(ReadConnectionSchema).validate_python( { "id": connection.id, @@ -238,7 +252,7 @@ async def update_connection( # noqa: WPS217, WPS238 "description": connection.description, "type": connection.type, "data": connection.data, - "auth_data": credentials, + "auth_data": new_credentials, }, ) diff --git a/syncmaster/server/api/v1/runs.py b/syncmaster/server/api/v1/runs.py index 97ad4d94..1e74a5f6 100644 --- a/syncmaster/server/api/v1/runs.py +++ b/syncmaster/server/api/v1/runs.py @@ -103,20 +103,25 @@ async def start_run( # noqa: WPS217 # The credentials.read method is used rather than credentials.read_bulk deliberately # it’s more convenient to transfer credits in this place - credentials_source = await unit_of_work.credentials.read( - transfer.source_connection_id, - ) - credentials_target = await unit_of_work.credentials.read( - transfer.target_connection_id, - ) + source_auth_data: dict | None = None + source_credentials = await unit_of_work.credentials.read(transfer.source_connection_id) + if source_credentials: + # remove secrets from the dump + source_credentials_filtered = ReadAuthDataSchema.model_validate(source_credentials) + source_auth_data = source_credentials_filtered.auth_data.model_dump() + + target_auth_data: dict | None = None + target_credentials = await unit_of_work.credentials.read(transfer.target_connection_id) + if target_credentials: + # remove secrets from the dump + target_credentials_filtered = ReadAuthDataSchema.model_validate(target_credentials) + target_auth_data = target_credentials_filtered.auth_data.model_dump() async with unit_of_work: run = await unit_of_work.run.create( transfer_id=create_run_data.transfer_id, - # Since fields with credentials may have different names (for example, S3 and Postgres have different names) - # the work of checking fields and removing passwords is delegated to the ReadAuthDataSchema class - source_creds=ReadAuthDataSchema(auth_data=credentials_source).model_dump(), - target_creds=ReadAuthDataSchema(auth_data=credentials_target).model_dump(), + source_auth_data=source_auth_data, + target_auth_data=target_auth_data, type=RunType.MANUAL, ) diff --git a/syncmaster/worker/controller.py b/syncmaster/worker/controller.py index 61ec1403..a3bf14c1 100644 --- a/syncmaster/worker/controller.py +++ b/syncmaster/worker/controller.py @@ -156,9 +156,9 @@ def __init__( settings: WorkerAppSettings, run: Run, source_connection: Connection, - source_auth_data: dict, + source_auth_data: dict | None, target_connection: Connection, - target_auth_data: dict, + target_auth_data: dict | None, ): self.temp_dir = TemporaryDirectory(prefix=f"syncmaster_{run.id}_") @@ -213,7 +213,7 @@ def perform_transfer(self) -> None: def get_handler( self, connection_data: dict[str, Any], - connection_auth_data: dict, + connection_auth_data: dict | None, run_data: dict[str, Any], transfer_id: int, transfer_params: dict[str, Any], @@ -222,7 +222,7 @@ def get_handler( transformations: list[dict], temp_dir: TemporaryDirectory, ) -> Handler: - connection_data.update(connection_auth_data) + connection_data.update(connection_auth_data or {}) connection_data.pop("type") handler_type = transfer_params.pop("type", None) diff --git a/syncmaster/worker/handlers/file/hdfs.py b/syncmaster/worker/handlers/file/hdfs.py index 860231c0..aa603b98 100644 --- a/syncmaster/worker/handlers/file/hdfs.py +++ b/syncmaster/worker/handlers/file/hdfs.py @@ -23,6 +23,13 @@ def connect(self, spark: SparkSession): spark=spark, ).check() - self.file_connection = HDFS( - cluster=self.connection_dto.cluster, - ).check() + if self.connection_dto.user and self.connection_dto.password: + self.file_connection = HDFS( + cluster=self.connection_dto.cluster, + user=self.connection_dto.user, + password=self.connection_dto.password, + ).check() + else: + self.file_connection = HDFS( + cluster=self.connection_dto.cluster, + ).check() diff --git a/syncmaster/worker/transfer.py b/syncmaster/worker/transfer.py index 1eb934e5..88245b29 100644 --- a/syncmaster/worker/transfer.py +++ b/syncmaster/worker/transfer.py @@ -57,8 +57,10 @@ def run_transfer(session: Session, run_id: int, settings: WorkerAppSettings): q_source_auth_data = select(AuthData).where(AuthData.connection_id == run.transfer.source_connection.id) q_target_auth_data = select(AuthData).where(AuthData.connection_id == run.transfer.target_connection.id) - target_auth_data = decrypt_auth_data(session.scalars(q_target_auth_data).one().value, settings) - source_auth_data = decrypt_auth_data(session.scalars(q_source_auth_data).one().value, settings) + source_auth_result = session.scalars(q_source_auth_data).one_or_none() + target_auth_result = session.scalars(q_target_auth_data).one_or_none() + source_auth_data = decrypt_auth_data(source_auth_result.value, settings) if source_auth_result else None + target_auth_data = decrypt_auth_data(target_auth_result.value, settings) if target_auth_result else None try: controller = TransferController( diff --git a/tests/conftest.py b/tests/conftest.py index b197015a..ec6ba807 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -169,16 +169,12 @@ def celery(worker_settings: WorkerAppSettings) -> Celery: @pytest_asyncio.fixture async def create_connection_data(request): - if hasattr(request, "param"): - return request.param - return None + return request.param @pytest_asyncio.fixture async def create_transfer_data(request): - if hasattr(request, "param"): - return request.param - return None + return request.param @pytest_asyncio.fixture( diff --git a/tests/test_integration/test_run_transfer/connection_fixtures/hdfs_fixtures.py b/tests/test_integration/test_run_transfer/connection_fixtures/hdfs_fixtures.py index 77e62e2a..3103162c 100644 --- a/tests/test_integration/test_run_transfer/connection_fixtures/hdfs_fixtures.py +++ b/tests/test_integration/test_run_transfer/connection_fixtures/hdfs_fixtures.py @@ -10,9 +10,8 @@ from syncmaster.db.models import Group from syncmaster.dto.connections import HDFSConnectionDTO -from syncmaster.server.settings import ServerAppSettings as Settings from tests.settings import TestSettings -from tests.test_unit.utils import create_connection, create_credentials, upload_files +from tests.test_unit.utils import create_connection, upload_files logger = logging.getLogger(__name__) @@ -24,8 +23,6 @@ def hdfs(test_settings: TestSettings) -> HDFSConnectionDTO: return HDFSConnectionDTO( cluster=test_settings.TEST_HIVE_CLUSTER, - user=test_settings.TEST_HIVE_USER, - password=test_settings.TEST_HIVE_PASSWORD, ) @@ -99,7 +96,6 @@ def prepare_hdfs( @pytest_asyncio.fixture async def hdfs_connection( hdfs: HDFSConnectionDTO, - settings: Settings, session: AsyncSession, group: Group, ): @@ -113,16 +109,7 @@ async def hdfs_connection( group_id=group.id, ) - await create_credentials( - session=session, - settings=settings, - connection_id=result.id, - auth_data=dict( - type="basic", - user=hdfs.user, - password=hdfs.password, - ), - ) + # no credentials for test purpose yield result await session.delete(result) diff --git a/tests/test_integration/test_run_transfer/connection_fixtures/hive_fixtures.py b/tests/test_integration/test_run_transfer/connection_fixtures/hive_fixtures.py index d76953a3..4faed9e9 100644 --- a/tests/test_integration/test_run_transfer/connection_fixtures/hive_fixtures.py +++ b/tests/test_integration/test_run_transfer/connection_fixtures/hive_fixtures.py @@ -10,9 +10,8 @@ from syncmaster.db.models import Group from syncmaster.dto.connections import HiveConnectionDTO -from syncmaster.server.settings import ServerAppSettings as Settings from tests.settings import TestSettings -from tests.test_unit.utils import create_connection, create_credentials +from tests.test_unit.utils import create_connection logger = logging.getLogger(__name__) @@ -61,7 +60,6 @@ def fill_with_data(df: DataFrame): @pytest_asyncio.fixture async def hive_connection( hive: HiveConnectionDTO, - settings: Settings, session: AsyncSession, group: Group, ): @@ -75,16 +73,7 @@ async def hive_connection( group_id=group.id, ) - await create_credentials( - session=session, - settings=settings, - connection_id=result.id, - auth_data=dict( - type="basic", - user=hive.user, - password=hive.password, - ), - ) + # no credentials for test purpose yield result await session.delete(result) diff --git a/tests/test_unit/conftest.py b/tests/test_unit/conftest.py index 2e74ef94..9edd77ab 100644 --- a/tests/test_unit/conftest.py +++ b/tests/test_unit/conftest.py @@ -230,11 +230,9 @@ async def mock_queue( await session.commit() -@pytest_asyncio.fixture +@pytest_asyncio.fixture(params=[None]) async def create_connection_data(request) -> dict | None: - if hasattr(request, "param"): - return request.param - return None + return request.param @pytest_asyncio.fixture @@ -294,18 +292,22 @@ async def two_group_connections( await session.commit() -@pytest_asyncio.fixture +@pytest_asyncio.fixture( + params=[ + { + "type": "basic", + "user": "user", + "password": "password", + }, + ], +) async def create_connection_auth_data(request) -> dict | None: - if hasattr(request, "param"): - return request.param - return None + return request.param -@pytest_asyncio.fixture +@pytest_asyncio.fixture(params=[None]) async def connection_type(request) -> str | None: - if hasattr(request, "param"): - return request.param - return None + return request.param @pytest_asyncio.fixture diff --git a/tests/test_unit/test_connections/connection_fixtures/group_connection_fixture.py b/tests/test_unit/test_connections/connection_fixtures/group_connection_fixture.py index 9b325ec4..815c7184 100644 --- a/tests/test_unit/test_connections/connection_fixtures/group_connection_fixture.py +++ b/tests/test_unit/test_connections/connection_fixtures/group_connection_fixture.py @@ -65,17 +65,26 @@ async def group_connection( data=create_connection_data, ) - credentials = await create_credentials( - session=session, - settings=settings, - connection_id=connection.id, - auth_data=create_connection_auth_data, - ) + auth_data: dict | None = None + credentials = None + if create_connection_auth_data: + credentials = await create_credentials( + session=session, + settings=settings, + connection_id=connection.id, + auth_data=create_connection_auth_data, + ) + auth_data = decrypt_auth_data(credentials.value, settings=settings) + token = access_token_factory(group_owner.id) yield MockConnection( - credentials=MockCredentials( - value=decrypt_auth_data(credentials.value, settings=settings), - connection_id=connection.id, + credentials=( + MockCredentials( + value=auth_data, + connection_id=connection.id, + ) + if auth_data + else None ), connection=connection, owner_group=MockGroup( @@ -88,7 +97,8 @@ async def group_connection( members=members, ), ) - await session.delete(credentials) + if credentials: + await session.delete(credentials) await session.delete(connection) await session.delete(group_owner) await session.delete(group) diff --git a/tests/test_unit/test_connections/test_db_connection/test_create_hive_connection.py b/tests/test_unit/test_connections/test_db_connection/test_create_hive_connection.py index e6642979..1e859f65 100644 --- a/tests/test_unit/test_connections/test_db_connection/test_create_hive_connection.py +++ b/tests/test_unit/test_connections/test_db_connection/test_create_hive_connection.py @@ -70,3 +70,56 @@ async def test_developer_plus_can_create_hive_connection( "user": decrypted["user"], }, } + + +async def test_developer_plus_can_create_hive_connection_without_credentials( + client: AsyncClient, + group: MockGroup, + session: AsyncSession, + role_developer_plus: UserTestRoles, +): + user = group.get_member_of_role(role_developer_plus) + + result = await client.post( + "v1/connections", + headers={"Authorization": f"Bearer {user.token}"}, + json={ + "group_id": group.id, + "name": "New connection", + "description": "", + "type": "hive", + "connection_data": { + "cluster": "cluster", + }, + "auth_data": None, + }, + ) + connection = ( + await session.scalars( + select(Connection).filter_by( + name="New connection", + ), + ) + ).first() + + assert result.status_code == 200, result.json() + assert result.json() == { + "id": connection.id, + "group_id": connection.group_id, + "name": connection.name, + "description": connection.description, + "type": connection.type, + "connection_data": { + "cluster": connection.data["cluster"], + }, + "auth_data": None, + } + + creds = ( + await session.scalars( + select(AuthData).filter_by( + connection_id=connection.id, + ), + ) + ).one_or_none() + assert not creds diff --git a/tests/test_unit/test_connections/test_db_connection/test_update_clickhouse_connection.py b/tests/test_unit/test_connections/test_db_connection/test_update_clickhouse_connection.py index 61d36afd..08d2c2df 100644 --- a/tests/test_unit/test_connections/test_db_connection/test_update_clickhouse_connection.py +++ b/tests/test_unit/test_connections/test_db_connection/test_update_clickhouse_connection.py @@ -66,7 +66,7 @@ async def test_developer_plus_can_update_clickhouse_connection( "additional_params": {}, }, "auth_data": { - "type": group_connection.credentials.value["type"], + "type": "basic", "user": "new_user", }, } diff --git a/tests/test_unit/test_connections/test_db_connection/test_update_hive_connection.py b/tests/test_unit/test_connections/test_db_connection/test_update_hive_connection.py new file mode 100644 index 00000000..f30c9af3 --- /dev/null +++ b/tests/test_unit/test_connections/test_db_connection/test_update_hive_connection.py @@ -0,0 +1,241 @@ +import pytest +from httpx import AsyncClient +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from syncmaster.db.models import AuthData +from syncmaster.db.repositories.utils import decrypt_auth_data +from syncmaster.server.settings import ServerAppSettings as Settings +from tests.mocks import MockConnection, UserTestRoles +from tests.test_unit.utils import fetch_connection_json + +pytestmark = [pytest.mark.asyncio, pytest.mark.server, pytest.mark.hive] + + +@pytest.mark.parametrize( + "connection_type,create_connection_data,create_connection_auth_data", + [ + ( + "hive", + { + "cluster": "cluster", + }, + None, + ), + ], + indirect=["create_connection_data", "create_connection_auth_data"], +) +async def test_developer_plus_can_update_hive_connection_no_credentials( + client: AsyncClient, + group_connection: MockConnection, + role_developer_plus: UserTestRoles, +): + user = group_connection.owner_group.get_member_of_role(role_developer_plus) + connection_json = await fetch_connection_json(client, user.token, group_connection) + + result = await client.put( + f"v1/connections/{group_connection.id}", + headers={"Authorization": f"Bearer {user.token}"}, + json={**connection_json, "connection_data": {"cluster": "new_cluster"}}, + ) + + assert result.status_code == 200, result.json() + assert result.json() == { + "id": group_connection.id, + "name": group_connection.connection.name, + "description": group_connection.description, + "type": group_connection.type, + "group_id": group_connection.group_id, + "connection_data": {"cluster": "new_cluster"}, + "auth_data": None, + } + + +@pytest.mark.parametrize( + "connection_type,create_connection_data,create_connection_auth_data", + [ + ( + "hive", + { + "cluster": "cluster", + }, + None, + ), + ( + "hive", + { + "cluster": "cluster", + }, + { + "type": "basic", + "user": "user", + "password": "password", + }, + ), + ], + indirect=["create_connection_data", "create_connection_auth_data"], +) +@pytest.mark.parametrize( + "new_auth_data", + [ + None, + { + "type": "basic", + "user": "user", + "password": "password", + }, + ], +) +async def test_developer_plus_can_update_hive_connection_replace_credentials_full( + client: AsyncClient, + session: AsyncSession, + settings: Settings, + group_connection: MockConnection, + role_developer_plus: UserTestRoles, + new_auth_data: dict | None, +): + user = group_connection.owner_group.get_member_of_role(role_developer_plus) + connection_json = await fetch_connection_json(client, user.token, group_connection) + + result = await client.put( + f"v1/connections/{group_connection.id}", + headers={"Authorization": f"Bearer {user.token}"}, + json={**connection_json, "auth_data": new_auth_data}, + ) + + expected_auth_data = ( + { + "type": new_auth_data["type"], + "user": new_auth_data["user"], + } + if new_auth_data + else None + ) + + assert result.status_code == 200, result.json() + assert result.json() == { + "id": group_connection.id, + "name": group_connection.connection.name, + "description": group_connection.description, + "type": group_connection.type, + "group_id": group_connection.group_id, + "connection_data": group_connection.connection.data, + "auth_data": expected_auth_data, + } + + creds = ( + await session.scalars( + select(AuthData).filter_by( + connection_id=group_connection.id, + ), + ) + ).one_or_none() + if new_auth_data: + decrypted = decrypt_auth_data(creds.value, settings=settings) + assert decrypted == new_auth_data + else: + assert not creds + + +@pytest.mark.parametrize( + "connection_type,create_connection_data,create_connection_auth_data", + [ + ( + "hive", + { + "cluster": "cluster", + }, + { + "type": "basic", + "user": "user", + "password": "password", + }, + ), + ], + indirect=["create_connection_data", "create_connection_auth_data"], +) +async def test_developer_plus_can_update_hive_connection_replace_credentials_partial( + client: AsyncClient, + session: AsyncSession, + settings: Settings, + group_connection: MockConnection, + role_developer_plus: UserTestRoles, +): + user = group_connection.owner_group.get_member_of_role(role_developer_plus) + new_auth_data = { + "type": "basic", + "user": "user", + # no password + } + connection_json = await fetch_connection_json(client, user.token, group_connection) + + result = await client.put( + f"v1/connections/{group_connection.id}", + headers={"Authorization": f"Bearer {user.token}"}, + json={**connection_json, "auth_data": new_auth_data}, + ) + + assert result.status_code == 200, result.json() + assert result.json() == { + "id": group_connection.id, + "name": group_connection.connection.name, + "description": group_connection.description, + "type": group_connection.type, + "group_id": group_connection.group_id, + "connection_data": group_connection.connection.data, + "auth_data": { + "type": group_connection.credentials.value["type"], + "user": group_connection.credentials.value["user"], + }, + } + + creds = ( + await session.scalars( + select(AuthData).filter_by( + connection_id=group_connection.id, + ), + ) + ).one() + decrypted = decrypt_auth_data(creds.value, settings=settings) + assert decrypted == {"type": "basic", "user": "user", "password": "password"} + + +@pytest.mark.parametrize( + "connection_type,create_connection_data,create_connection_auth_data", + [ + ( + "hive", + { + "cluster": "cluster", + }, + None, + ), + ], + indirect=["create_connection_data", "create_connection_auth_data"], +) +async def test_developer_plus_can_update_hive_connection_missing_credentials( + client: AsyncClient, + group_connection: MockConnection, + role_developer_plus: UserTestRoles, +): + user = group_connection.owner_group.get_member_of_role(role_developer_plus) + new_auth_data = { + "type": "basic", + "user": "user", + } + connection_json = await fetch_connection_json(client, user.token, group_connection) + + result = await client.put( + f"v1/connections/{group_connection.id}", + headers={"Authorization": f"Bearer {user.token}"}, + json={**connection_json, "auth_data": new_auth_data}, + ) + + assert result.status_code == 409, result.json() + assert result.json() == { + "error": { + "code": "conflict", + "message": "You cannot update the connection auth type without providing a new secret value.", + "details": None, + }, + } diff --git a/tests/test_unit/test_connections/test_db_connection/test_update_mssql_connection.py b/tests/test_unit/test_connections/test_db_connection/test_update_mssql_connection.py index 6c858825..02202f9b 100644 --- a/tests/test_unit/test_connections/test_db_connection/test_update_mssql_connection.py +++ b/tests/test_unit/test_connections/test_db_connection/test_update_mssql_connection.py @@ -67,7 +67,7 @@ async def test_developer_plus_can_update_mssql_connection( "additional_params": {}, }, "auth_data": { - "type": group_connection.credentials.value["type"], + "type": "basic", "user": "new_user", }, } diff --git a/tests/test_unit/test_connections/test_db_connection/test_update_mysql_connection.py b/tests/test_unit/test_connections/test_db_connection/test_update_mysql_connection.py index 35f8549d..8f432747 100644 --- a/tests/test_unit/test_connections/test_db_connection/test_update_mysql_connection.py +++ b/tests/test_unit/test_connections/test_db_connection/test_update_mysql_connection.py @@ -67,7 +67,7 @@ async def test_developer_plus_can_update_mysql_connection( "additional_params": {}, }, "auth_data": { - "type": group_connection.credentials.value["type"], + "type": "basic", "user": "new_user", }, } diff --git a/tests/test_unit/test_connections/test_file_connection/test_create_hdfs_connection.py b/tests/test_unit/test_connections/test_file_connection/test_create_hdfs_connection.py index f77edf1e..d1cdd4e9 100644 --- a/tests/test_unit/test_connections/test_file_connection/test_create_hdfs_connection.py +++ b/tests/test_unit/test_connections/test_file_connection/test_create_hdfs_connection.py @@ -70,3 +70,56 @@ async def test_developer_plus_can_create_hdfs_connection( "user": decrypted["user"], }, } + + +async def test_developer_plus_can_create_hdfs_connection_without_credentials( + client: AsyncClient, + group: MockGroup, + session: AsyncSession, + role_developer_plus: UserTestRoles, +): + user = group.get_member_of_role(role_developer_plus) + + result = await client.post( + "v1/connections", + headers={"Authorization": f"Bearer {user.token}"}, + json={ + "group_id": group.id, + "name": "New connection", + "description": "", + "type": "hdfs", + "connection_data": { + "cluster": "cluster", + }, + "auth_data": None, + }, + ) + connection = ( + await session.scalars( + select(Connection).filter_by( + name="New connection", + ), + ) + ).first() + + assert result.status_code == 200, result.json() + assert result.json() == { + "id": connection.id, + "group_id": connection.group_id, + "name": connection.name, + "description": connection.description, + "type": connection.type, + "connection_data": { + "cluster": connection.data["cluster"], + }, + "auth_data": None, + } + + creds = ( + await session.scalars( + select(AuthData).filter_by( + connection_id=connection.id, + ), + ) + ).one_or_none() + assert not creds diff --git a/tests/test_unit/test_connections/test_file_connection/test_update_ftp_connection.py b/tests/test_unit/test_connections/test_file_connection/test_update_ftp_connection.py index 2caedc0e..8cf7fd2f 100644 --- a/tests/test_unit/test_connections/test_file_connection/test_update_ftp_connection.py +++ b/tests/test_unit/test_connections/test_file_connection/test_update_ftp_connection.py @@ -29,8 +29,8 @@ async def test_developer_plus_can_update_ftp_connection( client: AsyncClient, group_connection: MockConnection, role_developer_plus: UserTestRoles, - create_connection_data: dict, - create_connection_auth_data: dict, + create_connection_data: dict, # don't remove + create_connection_auth_data: dict, # don't remove ): user = group_connection.owner_group.get_member_of_role(role_developer_plus) new_connection_data = {"host": "new_host", "port": 81} diff --git a/tests/test_unit/test_connections/test_file_connection/test_update_ftps_connection.py b/tests/test_unit/test_connections/test_file_connection/test_update_ftps_connection.py index 3b9a1f96..167f608b 100644 --- a/tests/test_unit/test_connections/test_file_connection/test_update_ftps_connection.py +++ b/tests/test_unit/test_connections/test_file_connection/test_update_ftps_connection.py @@ -29,8 +29,8 @@ async def test_developer_plus_can_update_ftps_connection( client: AsyncClient, group_connection: MockConnection, role_developer_plus: UserTestRoles, - create_connection_data: dict, - create_connection_auth_data: dict, + create_connection_data: dict, # don't remove + create_connection_auth_data: dict, # don't remove ): user = group_connection.owner_group.get_member_of_role(role_developer_plus) new_connection_data = {"host": "new_host", "port": 81} diff --git a/tests/test_unit/test_connections/test_file_connection/test_update_hdfs_connection.py b/tests/test_unit/test_connections/test_file_connection/test_update_hdfs_connection.py index 30183693..afc74550 100644 --- a/tests/test_unit/test_connections/test_file_connection/test_update_hdfs_connection.py +++ b/tests/test_unit/test_connections/test_file_connection/test_update_hdfs_connection.py @@ -1,6 +1,11 @@ import pytest from httpx import AsyncClient +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from syncmaster.db.models import AuthData +from syncmaster.db.repositories.utils import decrypt_auth_data +from syncmaster.server.settings import ServerAppSettings as Settings from tests.mocks import MockConnection, UserTestRoles from tests.test_unit.utils import fetch_connection_json @@ -10,6 +15,52 @@ @pytest.mark.parametrize( "connection_type,create_connection_data,create_connection_auth_data", [ + ( + "hdfs", + { + "cluster": "cluster", + }, + None, + ), + ], + indirect=["create_connection_data", "create_connection_auth_data"], +) +async def test_developer_plus_can_update_hdfs_connection_no_credentials( + client: AsyncClient, + group_connection: MockConnection, + role_developer_plus: UserTestRoles, +): + user = group_connection.owner_group.get_member_of_role(role_developer_plus) + connection_json = await fetch_connection_json(client, user.token, group_connection) + + result = await client.put( + f"v1/connections/{group_connection.id}", + headers={"Authorization": f"Bearer {user.token}"}, + json={**connection_json, "connection_data": {"cluster": "new_cluster"}}, + ) + + assert result.status_code == 200, result.json() + assert result.json() == { + "id": group_connection.id, + "name": group_connection.connection.name, + "description": group_connection.description, + "type": group_connection.type, + "group_id": group_connection.group_id, + "connection_data": {"cluster": "new_cluster"}, + "auth_data": None, + } + + +@pytest.mark.parametrize( + "connection_type,create_connection_data,create_connection_auth_data", + [ + ( + "hdfs", + { + "cluster": "cluster", + }, + None, + ), ( "hdfs", { @@ -24,21 +75,41 @@ ], indirect=["create_connection_data", "create_connection_auth_data"], ) -async def test_developer_plus_can_update_hdfs_connection( +@pytest.mark.parametrize( + "new_auth_data", + [ + None, + { + "type": "basic", + "user": "user", + "password": "password", + }, + ], +) +async def test_developer_plus_can_update_hdfs_connection_replace_credentials_full( client: AsyncClient, + session: AsyncSession, + settings: Settings, group_connection: MockConnection, role_developer_plus: UserTestRoles, - create_connection_data: dict, - create_connection_auth_data: dict, + new_auth_data: dict | None, ): user = group_connection.owner_group.get_member_of_role(role_developer_plus) - new_connection_data = {"cluster": "new_cluster"} connection_json = await fetch_connection_json(client, user.token, group_connection) result = await client.put( f"v1/connections/{group_connection.id}", headers={"Authorization": f"Bearer {user.token}"}, - json={**connection_json, "type": "hdfs", "connection_data": new_connection_data}, + json={**connection_json, "auth_data": new_auth_data}, + ) + + expected_auth_data = ( + { + "type": new_auth_data["type"], + "user": new_auth_data["user"], + } + if new_auth_data + else None ) assert result.status_code == 200, result.json() @@ -48,9 +119,123 @@ async def test_developer_plus_can_update_hdfs_connection( "description": group_connection.description, "type": group_connection.type, "group_id": group_connection.group_id, - "connection_data": new_connection_data, + "connection_data": group_connection.connection.data, + "auth_data": expected_auth_data, + } + + creds = ( + await session.scalars( + select(AuthData).filter_by( + connection_id=group_connection.id, + ), + ) + ).one_or_none() + if new_auth_data: + decrypted = decrypt_auth_data(creds.value, settings=settings) + assert decrypted == new_auth_data + else: + assert not creds + + +@pytest.mark.parametrize( + "connection_type,create_connection_data,create_connection_auth_data", + [ + ( + "hdfs", + { + "cluster": "cluster", + }, + { + "type": "basic", + "user": "user", + "password": "password", + }, + ), + ], + indirect=["create_connection_data", "create_connection_auth_data"], +) +async def test_developer_plus_can_update_hdfs_connection_replace_credentials_partial( + client: AsyncClient, + session: AsyncSession, + settings: Settings, + group_connection: MockConnection, + role_developer_plus: UserTestRoles, +): + user = group_connection.owner_group.get_member_of_role(role_developer_plus) + new_auth_data = { + "type": "basic", + "user": "user", + # no password + } + connection_json = await fetch_connection_json(client, user.token, group_connection) + + result = await client.put( + f"v1/connections/{group_connection.id}", + headers={"Authorization": f"Bearer {user.token}"}, + json={**connection_json, "auth_data": new_auth_data}, + ) + + assert result.status_code == 200, result.json() + assert result.json() == { + "id": group_connection.id, + "name": group_connection.connection.name, + "description": group_connection.description, + "type": group_connection.type, + "group_id": group_connection.group_id, + "connection_data": group_connection.connection.data, "auth_data": { "type": group_connection.credentials.value["type"], "user": group_connection.credentials.value["user"], }, } + + creds = ( + await session.scalars( + select(AuthData).filter_by( + connection_id=group_connection.id, + ), + ) + ).one() + decrypted = decrypt_auth_data(creds.value, settings=settings) + assert decrypted == {"type": "basic", "user": "user", "password": "password"} + + +@pytest.mark.parametrize( + "connection_type,create_connection_data,create_connection_auth_data", + [ + ( + "hdfs", + { + "cluster": "cluster", + }, + None, + ), + ], + indirect=["create_connection_data", "create_connection_auth_data"], +) +async def test_developer_plus_can_update_hdfs_connection_missing_credentials( + client: AsyncClient, + group_connection: MockConnection, + role_developer_plus: UserTestRoles, +): + user = group_connection.owner_group.get_member_of_role(role_developer_plus) + new_auth_data = { + "type": "basic", + "user": "user", + } + connection_json = await fetch_connection_json(client, user.token, group_connection) + + result = await client.put( + f"v1/connections/{group_connection.id}", + headers={"Authorization": f"Bearer {user.token}"}, + json={**connection_json, "auth_data": new_auth_data}, + ) + + assert result.status_code == 409, result.json() + assert result.json() == { + "error": { + "code": "conflict", + "message": "You cannot update the connection auth type without providing a new secret value.", + "details": None, + }, + } diff --git a/tests/test_unit/test_connections/test_update_connection.py b/tests/test_unit/test_connections/test_update_connection.py index d695a494..a2ca5df5 100644 --- a/tests/test_unit/test_connections/test_update_connection.py +++ b/tests/test_unit/test_connections/test_update_connection.py @@ -345,6 +345,7 @@ async def test_superuser_cannot_update_connection_auth_data_type_without_secret( }, ) + assert result.status_code == 409, result.json() assert result.json() == { "error": { "code": "conflict", @@ -352,7 +353,6 @@ async def test_superuser_cannot_update_connection_auth_data_type_without_secret( "details": None, }, } - assert result.status_code == 409, result.json() async def test_unauthorized_user_cannot_update_connection( diff --git a/tests/test_unit/utils.py b/tests/test_unit/utils.py index 1489c7fd..b3004360 100644 --- a/tests/test_unit/utils.py +++ b/tests/test_unit/utils.py @@ -271,10 +271,11 @@ async def fetch_connection_json(client: AsyncClient, user_token: str, mock_conne connection_json = connection.json() auth_data = connection_json["auth_data"] - if auth_data["type"] in ("basic", "samba"): - auth_data["password"] = mock_connection.credentials.value["password"] - elif auth_data["type"] == "s3": - auth_data["secret_key"] = mock_connection.credentials.value["secret_key"] + if auth_data: + if auth_data["type"] in ("basic", "samba"): + auth_data["password"] = mock_connection.credentials.value["password"] + elif auth_data["type"] == "s3": + auth_data["secret_key"] = mock_connection.credentials.value["secret_key"] return connection_json