diff --git a/app/api/audit/router.py b/app/api/audit/router.py index 4a328e2ef..6209d0740 100644 --- a/app/api/audit/router.py +++ b/app/api/audit/router.py @@ -15,6 +15,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.policies.audit.exception import ( AuditAlreadyExistsError, @@ -59,7 +60,11 @@ async def get_audit_policies( return await audit_adapter.get_policies() -@audit_router.put("/policy/{policy_id}", error_map=error_map) +@audit_router.put( + "/policy/{policy_id}", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def update_audit_policy( policy_id: int, policy_data: AuditPolicySchemaRequest, @@ -81,6 +86,7 @@ async def get_audit_destinations( "/destination", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def create_audit_destination( destination_data: AuditDestinationSchemaRequest, @@ -90,7 +96,11 @@ async def create_audit_destination( return await audit_adapter.create_destination(destination_data) -@audit_router.delete("/destination/{destination_id}", error_map=error_map) +@audit_router.delete( + "/destination/{destination_id}", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def delete_audit_destination( destination_id: int, audit_adapter: FromDishka[AuditPoliciesAdapter], @@ -99,7 +109,11 @@ async def delete_audit_destination( await audit_adapter.delete_destination(destination_id) -@audit_router.put("/destination/{destination_id}", error_map=error_map) +@audit_router.put( + "/destination/{destination_id}", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def update_audit_destination( destination_id: int, destination_data: AuditDestinationSchemaRequest, diff --git a/app/api/auth/router_auth.py b/app/api/auth/router_auth.py index ae8df7bfd..ceb4d1ded 100644 --- a/app/api/auth/router_auth.py +++ b/app/api/auth/router_auth.py @@ -19,6 +19,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.auth.exceptions.mfa import ( MFAAPIError, @@ -186,7 +187,7 @@ async def logout( @auth_router.patch( "/user/password", status_code=200, - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def password_reset( @@ -229,6 +230,7 @@ async def check_setup( status_code=status.HTTP_200_OK, responses={423: {"detail": "Locked"}}, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def first_setup( request: SetupRequest, diff --git a/app/api/auth/router_mfa.py b/app/api/auth/router_mfa.py index 18424c8ca..944852a89 100644 --- a/app/api/auth/router_mfa.py +++ b/app/api/auth/router_mfa.py @@ -24,6 +24,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.auth.exceptions.mfa import ( ForbiddenError, @@ -81,7 +82,7 @@ @mfa_router.post( "/setup", status_code=status.HTTP_201_CREATED, - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def setup_mfa( @@ -100,7 +101,7 @@ async def setup_mfa( @mfa_router.delete( "/keys", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def remove_mfa( @@ -113,7 +114,7 @@ async def remove_mfa( @mfa_router.post( "/get", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def get_mfa( diff --git a/app/api/ldap_schema/attribute_type_router.py b/app/api/ldap_schema/attribute_type_router.py index 5a2f1f368..a75a1826a 100644 --- a/app/api/ldap_schema/attribute_type_router.py +++ b/app/api/ldap_schema/attribute_type_router.py @@ -7,7 +7,7 @@ from typing import Annotated from dishka.integrations.fastapi import FromDishka -from fastapi import Query, status +from fastapi import Depends, Query, status from api.ldap_schema import LimitedListType, error_map, ldap_schema_router from api.ldap_schema.adapters.attribute_type import AttributeTypeFastAPIAdapter @@ -16,6 +16,7 @@ AttributeTypeSchema, AttributeTypeUpdateSchema, ) +from api.utils import require_master_db from ldap_protocol.utils.pagination import PaginationParams @@ -23,6 +24,7 @@ "/attribute_type", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def create_one_attribute_type( request_data: AttributeTypeSchema[None], @@ -59,6 +61,7 @@ async def get_list_attribute_types_with_pagination( @ldap_schema_router.patch( "/attribute_type/{attribute_type_name}", error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def modify_one_attribute_type( attribute_type_name: str, @@ -72,6 +75,7 @@ async def modify_one_attribute_type( @ldap_schema_router.post( "/attribute_types/delete", error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def delete_bulk_attribute_types( attribute_types_names: LimitedListType, diff --git a/app/api/ldap_schema/entity_type_router.py b/app/api/ldap_schema/entity_type_router.py index 31de91616..129230b8e 100644 --- a/app/api/ldap_schema/entity_type_router.py +++ b/app/api/ldap_schema/entity_type_router.py @@ -7,7 +7,7 @@ from typing import Annotated from dishka.integrations.fastapi import FromDishka -from fastapi import Query, status +from fastapi import Depends, Query, status from api.ldap_schema import LimitedListType, error_map from api.ldap_schema.adapters.entity_type import LDAPEntityTypeFastAPIAdapter @@ -17,6 +17,7 @@ EntityTypeSchema, EntityTypeUpdateSchema, ) +from api.utils import require_master_db from ldap_protocol.utils.pagination import PaginationParams @@ -24,6 +25,7 @@ "/entity_type", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def create_one_entity_type( request_data: EntityTypeSchema[None], @@ -66,6 +68,7 @@ async def get_entity_type_attributes( @ldap_schema_router.patch( "/entity_type/{entity_type_name}", error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def modify_one_entity_type( entity_type_name: str, @@ -76,7 +79,11 @@ async def modify_one_entity_type( await adapter.update(name=entity_type_name, data=request_data) -@ldap_schema_router.post("/entity_type/delete", error_map=error_map) +@ldap_schema_router.post( + "/entity_type/delete", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def delete_bulk_entity_types( entity_type_names: LimitedListType, adapter: FromDishka[LDAPEntityTypeFastAPIAdapter], diff --git a/app/api/ldap_schema/object_class_router.py b/app/api/ldap_schema/object_class_router.py index a351f3b33..a6baced69 100644 --- a/app/api/ldap_schema/object_class_router.py +++ b/app/api/ldap_schema/object_class_router.py @@ -7,7 +7,7 @@ from typing import Annotated from dishka.integrations.fastapi import FromDishka -from fastapi import Query, status +from fastapi import Depends, Query, status from api.ldap_schema import LimitedListType, error_map from api.ldap_schema.adapters.object_class import ObjectClassFastAPIAdapter @@ -17,6 +17,7 @@ ObjectClassSchema, ObjectClassUpdateSchema, ) +from api.utils import require_master_db from ldap_protocol.utils.pagination import PaginationParams @@ -24,6 +25,7 @@ "/object_class", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def create_one_object_class( request_data: ObjectClassSchema[None], @@ -57,6 +59,7 @@ async def get_list_object_classes_with_pagination( @ldap_schema_router.patch( "/object_class/{object_class_name}", error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def modify_one_object_class( object_class_name: str, @@ -67,7 +70,11 @@ async def modify_one_object_class( await adapter.update(object_class_name, request_data) -@ldap_schema_router.post("/object_class/delete", error_map=error_map) +@ldap_schema_router.post( + "/object_class/delete", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def delete_bulk_object_classes( object_classes_names: LimitedListType, adapter: FromDishka[ObjectClassFastAPIAdapter], diff --git a/app/api/main/dns_router.py b/app/api/main/dns_router.py index d93382512..5f797d1e5 100644 --- a/app/api/main/dns_router.py +++ b/app/api/main/dns_router.py @@ -29,6 +29,7 @@ DNSServiceZoneDeleteRequest, DNSServiceZoneUpdateRequest, ) +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.dns import ( DNSForwardServerStatus, @@ -139,7 +140,11 @@ async def get_dns_status( return await adapter.get_dns_status() -@dns_router.post("/setup", error_map=error_map) +@dns_router.post( + "/setup", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def setup_dns( data: DNSServiceSetupRequest, adapter: FromDishka[DNSFastAPIAdapter], diff --git a/app/api/main/krb5_router.py b/app/api/main/krb5_router.py index 91f64a5b6..9ed36515c 100644 --- a/app/api/main/krb5_router.py +++ b/app/api/main/krb5_router.py @@ -24,6 +24,7 @@ ) from api.main.adapters.kerberos import KerberosFastAPIAdapter from api.main.schema import KerberosSetupRequest +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.dialogue import LDAPSession from ldap_protocol.kerberos import KerberosState @@ -82,7 +83,7 @@ "/setup/tree", response_class=Response, error_map=error_map, - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], ) async def setup_krb_catalogue( mail: Annotated[EmailStr, Body()], @@ -106,7 +107,12 @@ async def setup_krb_catalogue( ) -@krb5_router.post("/setup", response_class=Response, error_map=error_map) +@krb5_router.post( + "/setup", + response_class=Response, + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def setup_kdc( data: KerberosSetupRequest, identity_adapter: FromDishka[AuthFastAPIAdapter], @@ -173,7 +179,7 @@ async def get_krb_status( @krb5_router.post( "/principal/add", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def add_principal( @@ -193,7 +199,7 @@ async def add_principal( @krb5_router.patch( "/principal/rename", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def rename_principal( @@ -217,7 +223,7 @@ async def rename_principal( @krb5_router.patch( "/principal/reset", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def reset_principal_pw( @@ -238,7 +244,7 @@ async def reset_principal_pw( @krb5_router.delete( "/principal/delete", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], error_map=error_map, ) async def delete_principal( diff --git a/app/api/main/router.py b/app/api/main/router.py index f26881b38..22ba397d5 100644 --- a/app/api/main/router.py +++ b/app/api/main/router.py @@ -16,6 +16,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.identity.exceptions import UnauthorizedError from ldap_protocol.ldap_requests import ( @@ -72,7 +73,11 @@ async def search( ) -@entry_router.post("/add", error_map=error_map) +@entry_router.post( + "/add", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def add( request: AddRequest, req: Request, @@ -81,7 +86,11 @@ async def add( return await request.handle_api(req.state.dishka_container) -@entry_router.patch("/update", error_map=error_map) +@entry_router.patch( + "/update", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def modify( request: ModifyRequest, req: Request, @@ -90,7 +99,11 @@ async def modify( return await request.handle_api(req.state.dishka_container) -@entry_router.patch("/update_many", error_map=error_map) +@entry_router.patch( + "/update_many", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def modify_many( requests: list[ModifyRequest], req: Request, @@ -102,7 +115,11 @@ async def modify_many( return results -@entry_router.put("/update/dn", error_map=error_map) +@entry_router.put( + "/update/dn", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def modify_dn( request: ModifyDNRequest, req: Request, @@ -111,7 +128,11 @@ async def modify_dn( return await request.handle_api(req.state.dishka_container) -@entry_router.post("/update_many/dn", error_map=error_map) +@entry_router.post( + "/update_many/dn", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def modify_dn_many( requests: list[ModifyDNRequest], req: Request, @@ -123,7 +144,11 @@ async def modify_dn_many( return results -@entry_router.delete("/delete", error_map=error_map) +@entry_router.delete( + "/delete", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def delete( request: DeleteRequest, req: Request, @@ -132,7 +157,11 @@ async def delete( return await request.handle_api(req.state.dishka_container) -@entry_router.post("/delete_many", error_map=error_map) +@entry_router.post( + "/delete_many", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def delete_many( requests: list[DeleteRequest], req: Request, @@ -144,7 +173,10 @@ async def delete_many( return results -@entry_router.post("/set_primary_group") +@entry_router.post( + "/set_primary_group", + dependencies=[Depends(require_master_db)], +) async def set_primary_group( request: PrimaryGroupRequest, session: FromDishka[AsyncSession], diff --git a/app/api/network/router.py b/app/api/network/router.py index bc65ed858..c261aa73d 100644 --- a/app/api/network/router.py +++ b/app/api/network/router.py @@ -18,6 +18,7 @@ DomainErrorTranslator, ) from api.network.adapters.network import NetworkPolicyFastAPIAdapter +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.policies.network.exceptions import ( LastActivePolicyError, @@ -64,6 +65,7 @@ "", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def add_network_policy( policy: Policy, @@ -97,6 +99,7 @@ async def get_list_network_policies( response_class=RedirectResponse, status_code=status.HTTP_303_SEE_OTHER, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def delete_network_policy( policy_id: int, @@ -114,7 +117,11 @@ async def delete_network_policy( return await adapter.delete(request, policy_id) # type: ignore -@network_router.patch("/{policy_id}", error_map=error_map) +@network_router.patch( + "/{policy_id}", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def switch_network_policy( policy_id: int, adapter: FromDishka[NetworkPolicyFastAPIAdapter], @@ -133,7 +140,11 @@ async def switch_network_policy( return await adapter.switch_network_policy(policy_id) -@network_router.put("", error_map=error_map) +@network_router.put( + "", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def update_network_policy( request: PolicyUpdate, adapter: FromDishka[NetworkPolicyFastAPIAdapter], @@ -150,7 +161,11 @@ async def update_network_policy( return await adapter.update(request) -@network_router.post("/swap", error_map=error_map) +@network_router.post( + "/swap", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def swap_network_policy( swap: SwapRequest, adapter: FromDishka[NetworkPolicyFastAPIAdapter], diff --git a/app/api/password_policy/password_ban_word_router.py b/app/api/password_policy/password_ban_word_router.py index a0c06a04e..5185124dc 100644 --- a/app/api/password_policy/password_ban_word_router.py +++ b/app/api/password_policy/password_ban_word_router.py @@ -13,6 +13,7 @@ from api.error_routing import DishkaErrorAwareRoute from api.password_policy.adapter import PasswordBanWordsFastAPIAdapter from api.password_policy.error_utils import error_map +from api.utils import require_master_db password_ban_word_router = ErrorAwareRouter( prefix="/password_ban_word", @@ -26,6 +27,7 @@ "/upload_txt", status_code=status.HTTP_201_CREATED, error_map=error_map, + dependencies=[Depends(require_master_db)], ) async def upload_ban_words_txt( file: UploadFile, diff --git a/app/api/password_policy/password_policy_router.py b/app/api/password_policy/password_policy_router.py index 812777ecd..36bd206c3 100644 --- a/app/api/password_policy/password_policy_router.py +++ b/app/api/password_policy/password_policy_router.py @@ -13,6 +13,7 @@ from api.password_policy.adapter import PasswordPolicyFastAPIAdapter from api.password_policy.error_utils import error_map from api.password_policy.schemas import PasswordPolicySchema +from api.utils import require_master_db from ldap_protocol.utils.const import GRANT_DN_STRING from .schemas import PriorityT @@ -51,7 +52,11 @@ async def get_password_policy_by_dir_path_dn( return await adapter.get_password_policy_by_dir_path_dn(path_dn) -@password_policy_router.put("/{id_}", error_map=error_map) +@password_policy_router.put( + "/{id_}", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def update( id_: int, policy: PasswordPolicySchema[PriorityT], @@ -61,7 +66,11 @@ async def update( await adapter.update(id_, policy) -@password_policy_router.put("/reset/domain_policy", error_map=error_map) +@password_policy_router.put( + "/reset/domain_policy", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def reset_domain_policy_to_default_config( adapter: FromDishka[PasswordPolicyFastAPIAdapter], ) -> None: diff --git a/app/api/password_policy/user_password_history_router.py b/app/api/password_policy/user_password_history_router.py index 2285c3cdd..9af233c12 100644 --- a/app/api/password_policy/user_password_history_router.py +++ b/app/api/password_policy/user_password_history_router.py @@ -18,6 +18,7 @@ DomainErrorTranslator, ) from api.password_policy.adapter import UserPasswordHistoryResetFastAPIAdapter +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.identity.exceptions import ( AuthorizationError, @@ -39,7 +40,7 @@ user_password_history_router = ErrorAwareRouter( prefix="/user/password_history", - dependencies=[Depends(verify_auth)], + dependencies=[Depends(verify_auth), Depends(require_master_db)], tags=["User Password history"], route_class=DishkaErrorAwareRoute, ) diff --git a/app/api/shadow/router.py b/app/api/shadow/router.py index ee8938a18..f45093220 100644 --- a/app/api/shadow/router.py +++ b/app/api/shadow/router.py @@ -8,7 +8,7 @@ from typing import Annotated from dishka import FromDishka -from fastapi import Body, status +from fastapi import Body, Depends, status from fastapi_error_map.routing import ErrorAwareRouter from fastapi_error_map.rules import rule @@ -17,6 +17,7 @@ DishkaErrorAwareRoute, DomainErrorTranslator, ) +from api.utils import require_master_db from enums import DomainCodes from ldap_protocol.auth.exceptions.mfa import ( AuthenticationError, @@ -67,7 +68,11 @@ async def proxy_request( return await adapter.proxy_request(principal, ip) -@shadow_router.post("/sync/password", error_map=error_map) +@shadow_router.post( + "/sync/password", + error_map=error_map, + dependencies=[Depends(require_master_db)], +) async def change_password( principal: Annotated[str, Body(embed=True)], new_password: Annotated[str, Body(embed=True)], diff --git a/app/api/utils.py b/app/api/utils.py new file mode 100644 index 000000000..aa3b2e289 --- /dev/null +++ b/app/api/utils.py @@ -0,0 +1,36 @@ +"""Utils with master database check. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from dishka import FromDishka +from dishka.integrations.fastapi import inject +from fastapi import HTTPException, status +from loguru import logger +from sqlalchemy import text +from sqlalchemy.exc import OperationalError +from sqlalchemy.ext.asyncio import AsyncSession + +from config import Settings + + +@inject +async def require_master_db( + session: FromDishka[AsyncSession], + settings: FromDishka[Settings], +) -> None: + if settings.POSTGRES_RW_MODE == "single": + return + + try: + session.sync_session.set_force_master(True) # type: ignore + await session.execute(text("SELECT 1")) + except OperationalError as e: + logger.error(f"Master DB check failed: {e}") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Master DB is not available", + ) + else: + session.sync_session.set_force_master(False) # type: ignore diff --git a/app/config.py b/app/config.py index 423eb2bf8..d813dd16a 100644 --- a/app/config.py +++ b/app/config.py @@ -49,12 +49,20 @@ class Settings(BaseModel): TCP_PACKET_SIZE: int = 1024 COROUTINES_NUM_PER_CLIENT: int = 3 + POSTGRES_RW_MODE: Literal["single", "replication"] = "single" POSTGRES_SCHEMA: ClassVar[str] = "postgresql+psycopg" - POSTGRES_DB: str = "postgres" + POSTGRES_REPLICA_DB: str = "" + POSTGRES_REPLICA_HOST: str = "" + POSTGRES_REPLICA_USER: str = "" + POSTGRES_REPLICA_PASSWORD: str = "" + POSTGRES_REPLICA_CONNECT_TIMEOUT: int = 4 + + POSTGRES_DB: str = "postgres" POSTGRES_HOST: str = "postgres" POSTGRES_USER: str POSTGRES_PASSWORD: str + POSTGRES_CONNECT_TIMEOUT: int = 4 SESSION_STORAGE_URL: RedisDsn = RedisDsn("redis://dragonfly:6379/1") SESSION_KEY_LENGTH: int = 16 @@ -99,6 +107,54 @@ def POSTGRES_URI(self) -> PostgresDsn: # noqa f"{self.POSTGRES_DB}", ) + @computed_field # type: ignore + @cached_property + def REPLICA_POSTGRES_URI(self) -> PostgresDsn: # noqa + """Build replica postgres DSN.""" + return PostgresDsn( + f"{self.POSTGRES_SCHEMA}://" + f"{self.POSTGRES_REPLICA_USER}:" + f"{self.POSTGRES_REPLICA_PASSWORD}@" + f"{self.POSTGRES_REPLICA_HOST}/" + f"{self.POSTGRES_REPLICA_DB}", + ) + + @cached_property + def engine(self) -> AsyncEngine: + """Get engine.""" + return create_async_engine( + str(self.POSTGRES_URI), + pool_size=self.INSTANCE_DB_POOL_SIZE, + max_overflow=self.INSTANCE_DB_POOL_OVERFLOW, + pool_timeout=self.INSTANCE_DB_POOL_TIMEOUT, + pool_recycle=self.INSTANCE_DB_POOL_RECYCLE, + pool_pre_ping=False, + future=True, + echo=False, + logging_name="master", + connect_args={"connect_timeout": self.POSTGRES_CONNECT_TIMEOUT}, + ) + + @cached_property + def replica_engine(self) -> AsyncEngine | None: + if self.POSTGRES_RW_MODE != "replication": + return None + + return create_async_engine( + str(self.REPLICA_POSTGRES_URI), + pool_size=self.INSTANCE_DB_POOL_SIZE, + max_overflow=self.INSTANCE_DB_POOL_OVERFLOW, + pool_timeout=self.INSTANCE_DB_POOL_TIMEOUT, + pool_recycle=self.INSTANCE_DB_POOL_RECYCLE, + pool_pre_ping=False, + future=True, + echo=False, + logging_name="replica", + connect_args={ + "connect_timeout": self.POSTGRES_REPLICA_CONNECT_TIMEOUT, + }, + ) + VENDOR_NAME: ClassVar[str] = "MultiFactor" VENDOR_VERSION: str = Field( default_factory=_get_vendor_version, @@ -220,20 +276,6 @@ def check_certs_exist(self) -> bool: """Check if certs exist.""" return os.path.exists(self.SSL_CERT) and os.path.exists(self.SSL_KEY) - @cached_property - def engine(self) -> AsyncEngine: - """Get engine.""" - return create_async_engine( - str(self.POSTGRES_URI), - pool_size=self.INSTANCE_DB_POOL_SIZE, - max_overflow=self.INSTANCE_DB_POOL_OVERFLOW, - pool_timeout=self.INSTANCE_DB_POOL_TIMEOUT, - pool_recycle=self.INSTANCE_DB_POOL_RECYCLE, - pool_pre_ping=False, - future=True, - echo=False, - ) - @classmethod def from_os(cls) -> "Settings": """Get cls from environ.""" diff --git a/app/db_routing.py b/app/db_routing.py new file mode 100644 index 000000000..1e305b361 --- /dev/null +++ b/app/db_routing.py @@ -0,0 +1,75 @@ +"""Engine registry and routing session. + +Copyright (c) 2026 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from typing import Any, Sequence + +from sqlalchemy import Delete, Insert, Update, exc as sa_exc +from sqlalchemy.engine import Engine +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.orm import Session + + +class EngineRegistry: + _master_engine: AsyncEngine + _replica_engine: AsyncEngine | None + + def __init__( + self, + master_engine: AsyncEngine, + replica_engine: AsyncEngine | None, + ) -> None: + self._master_engine = master_engine + self._replica_engine = replica_engine + + def get_master_engine(self) -> AsyncEngine: + return self._master_engine + + def get_replica_engine(self) -> AsyncEngine: + if self._replica_engine is None: + raise RuntimeError("Replica engine is not configured") + return self._replica_engine + + def get_sync_master_engine(self) -> Engine: + return self._master_engine.sync_engine + + def get_sync_replica_engine(self) -> Engine: + if self._replica_engine is None: + raise RuntimeError("Replica engine is not configured") + return self._replica_engine.sync_engine + + +class RoutingSession(Session): + _force_master: bool = False + + @property + def engine_registry(self) -> EngineRegistry: + return self.info["engine_registry"] + + def set_force_master(self, value: bool) -> None: + self._force_master = value + + def get_bind(self, mapper=None, clause=None) -> Engine: # type: ignore # noqa: ARG002 + if isinstance(clause, Update | Insert | Delete): + self._force_master = True + return self.engine_registry.get_sync_master_engine() + + if self._force_master or self._flushing: + return self.engine_registry.get_sync_master_engine() + else: + return self.engine_registry.get_sync_replica_engine() + + def flush(self, objects: Sequence[Any] | None = None) -> None: + if self._flushing: + raise sa_exc.InvalidRequestError("Session is already flushing") + + if self._is_clean(): + return + try: + self._flushing = True + self._flush(objects) + finally: + self._flushing = False + self._force_master = True diff --git a/app/ioc.py b/app/ioc.py index d6489f842..b1e4bd31a 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -8,12 +8,12 @@ import httpx import redis.asyncio as redis +from db_routing import EngineRegistry, RoutingSession from dishka import Provider, Scope, from_context, provide from fastapi import Request from loguru import logger from sqlalchemy.ext.asyncio import ( AsyncConnection, - AsyncEngine, AsyncSession, async_sessionmaker, ) @@ -163,17 +163,30 @@ class MainProvider(Provider): settings = from_context(provides=Settings, scope=Scope.APP) @provide(scope=Scope.APP) - def get_engine(self, settings: Settings) -> AsyncEngine: - """Get async engine.""" - return settings.engine + def get_engine_registry(self, settings: Settings) -> EngineRegistry: + return EngineRegistry( + master_engine=settings.engine, + replica_engine=settings.replica_engine, + ) @provide(scope=Scope.APP) def get_session_factory( self, - engine: AsyncEngine, + settings: Settings, + engine_registry: EngineRegistry, ) -> async_sessionmaker[AsyncSession]: """Create session factory.""" - return async_sessionmaker(engine, expire_on_commit=False) + if settings.POSTGRES_RW_MODE == "single": + return async_sessionmaker( + bind=engine_registry.get_master_engine(), + expire_on_commit=False, + ) + + return async_sessionmaker( + sync_session_class=RoutingSession, + expire_on_commit=False, + info={"engine_registry": engine_registry}, + ) @provide(scope=Scope.REQUEST) async def create_session( @@ -895,8 +908,9 @@ def get_session_factory( @provide(scope=Scope.APP) async def get_conn_factory( self, - engine: AsyncEngine, + engine_registry: EngineRegistry, ) -> AsyncIterator[AsyncConnection]: """Create session factory.""" + engine = engine_registry.get_master_engine() async with engine.connect() as connection: yield connection diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 75be3f6fc..80effc400 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -64,6 +64,7 @@ class AddRequest(BaseRequest): ``` """ + RESPONSE_TYPE: ClassVar[type] = AddResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.ADD CONTEXT_TYPE: ClassVar[type] = LDAPAddRequestContext diff --git a/app/ldap_protocol/ldap_requests/base.py b/app/ldap_protocol/ldap_requests/base.py index 445ce3bae..63667f034 100644 --- a/app/ldap_protocol/ldap_requests/base.py +++ b/app/ldap_protocol/ldap_requests/base.py @@ -18,11 +18,13 @@ from dishka import AsyncContainer from loguru import logger from pydantic import BaseModel +from sqlalchemy.exc import OperationalError from config import Settings from entities import Directory from ldap_protocol.dependency import resolve_deps from ldap_protocol.dialogue import LDAPSession +from ldap_protocol.ldap_codes import LDAPCodes from ldap_protocol.ldap_responses import BaseResponse, LDAPResult from ldap_protocol.objects import ProtocolRequests from ldap_protocol.policies.audit.audit_use_case import AuditUseCase @@ -63,6 +65,7 @@ class _APIProtocol: ... class BaseRequest(ABC, _APIProtocol, BaseModel): """Base request builder.""" + RESPONSE_TYPE: ClassVar[type] CONTEXT_TYPE: ClassVar[type] handle: ClassVar[handler] from_data: ClassVar[serializer] @@ -118,9 +121,17 @@ async def handle_tcp( ctx = await container.get(self.CONTEXT_TYPE) # type: ignore responses = [] - async for response in self.handle(ctx=ctx): - responses.append(response) - yield response + try: + async for response in self.handle(ctx=ctx): + responses.append(response) + yield response + except OperationalError: + if self.PROTOCOL_OP != ProtocolRequests.ABANDON: + yield self.RESPONSE_TYPE( + result_code=LDAPCodes.UNAVAILABLE, + errorMessage="Master DB is not available", + ) + return if self.PROTOCOL_OP != ProtocolRequests.SEARCH: ldap_session = await container.get(LDAPSession) @@ -172,7 +183,17 @@ async def _handle_api( else: log_api.info(f"{get_class_name(self)}[{un}]") - responses = [response async for response in self.handle(ctx=ctx)] + try: + responses = [response async for response in self.handle(ctx=ctx)] + except OperationalError: + responses = [] + if self.PROTOCOL_OP != ProtocolRequests.ABANDON: + responses.append( + self.RESPONSE_TYPE( + result_code=LDAPCodes.UNAVAILABLE, + errorMessage="Master DB is not available", + ), + ) if settings.DEBUG: for response in responses: diff --git a/app/ldap_protocol/ldap_requests/bind.py b/app/ldap_protocol/ldap_requests/bind.py index 445b2f25c..bac4bb998 100644 --- a/app/ldap_protocol/ldap_requests/bind.py +++ b/app/ldap_protocol/ldap_requests/bind.py @@ -8,6 +8,7 @@ from typing import AsyncGenerator, ClassVar from pydantic import Field +from sqlalchemy.exc import OperationalError from entities import NetworkPolicy from enums import MFAFlags @@ -42,6 +43,7 @@ class BindRequest(BaseRequest): """Bind request fields mapping.""" + RESPONSE_TYPE: ClassVar[type] = BindResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.BIND CONTEXT_TYPE: ClassVar[type] = LDAPBindRequestContext @@ -215,7 +217,12 @@ async def handle( ) await ctx.ldap_session.set_user(user) - await set_user_logon_attrs(user, ctx.session, ctx.settings.TIMEZONE) + with contextlib.suppress(OperationalError): + await set_user_logon_attrs( + user, + ctx.session, + ctx.settings.TIMEZONE, + ) server_sasl_creds = None if isinstance(self.authentication_choice, SaslSPNEGOAuthentication): diff --git a/app/ldap_protocol/ldap_requests/delete.py b/app/ldap_protocol/ldap_requests/delete.py index e2b127331..b69d981cb 100644 --- a/app/ldap_protocol/ldap_requests/delete.py +++ b/app/ldap_protocol/ldap_requests/delete.py @@ -42,6 +42,7 @@ class DeleteRequest(BaseRequest): DelRequest ::= [APPLICATION 10] LDAPDN """ + RESPONSE_TYPE: ClassVar[type] = DeleteResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.DELETE CONTEXT_TYPE: ClassVar[type] = LDAPDeleteRequestContext diff --git a/app/ldap_protocol/ldap_requests/extended.py b/app/ldap_protocol/ldap_requests/extended.py index c3967889e..b72cb7265 100644 --- a/app/ldap_protocol/ldap_requests/extended.py +++ b/app/ldap_protocol/ldap_requests/extended.py @@ -307,6 +307,7 @@ class ExtendedRequest(BaseRequest): requestValue [1] OCTET STRING OPTIONAL } """ + RESPONSE_TYPE: ClassVar[type] = ExtendedResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.EXTENDED CONTEXT_TYPE: ClassVar[type] = LDAPExtendedRequestContext request_name: LDAPOID diff --git a/app/ldap_protocol/ldap_requests/modify.py b/app/ldap_protocol/ldap_requests/modify.py index 676550e3e..4ebaa3da6 100644 --- a/app/ldap_protocol/ldap_requests/modify.py +++ b/app/ldap_protocol/ldap_requests/modify.py @@ -94,6 +94,7 @@ class ModifyRequest(BaseRequest): ``` """ + RESPONSE_TYPE: ClassVar[type] = ModifyResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.MODIFY CONTEXT_TYPE: ClassVar[type] = LDAPModifyRequestContext diff --git a/app/ldap_protocol/ldap_requests/modify_dn.py b/app/ldap_protocol/ldap_requests/modify_dn.py index cdf03ab7b..dc4421d49 100644 --- a/app/ldap_protocol/ldap_requests/modify_dn.py +++ b/app/ldap_protocol/ldap_requests/modify_dn.py @@ -67,6 +67,7 @@ class ModifyDNRequest(BaseRequest): >>> cn = main2, cn = Users, dc = multifactor, dc = dev """ + RESPONSE_TYPE: ClassVar[type] = ModifyDNResponse PROTOCOL_OP: ClassVar[int] = ProtocolRequests.MODIFY_DN CONTEXT_TYPE: ClassVar[type] = LDAPModifyDNRequestContext diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index c6505322a..d5ff4e679 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -104,6 +104,7 @@ class SearchRequest(BaseRequest): ``` """ + RESPONSE_TYPE: ClassVar[type] = SearchResultDone PROTOCOL_OP: ClassVar[int] = ProtocolRequests.SEARCH CONTEXT_TYPE: ClassVar[type] = LDAPSearchRequestContext diff --git a/app/ldap_protocol/session_storage/repository.py b/app/ldap_protocol/session_storage/repository.py index 84366faee..2e73dbc2d 100644 --- a/app/ldap_protocol/session_storage/repository.py +++ b/app/ldap_protocol/session_storage/repository.py @@ -1,9 +1,11 @@ """Enterprise Session Repository.""" +import contextlib from dataclasses import dataclass from ipaddress import IPv4Address, IPv6Address from typing import ClassVar, Literal +from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import AsyncSession from abstract_service import AbstractService @@ -87,8 +89,13 @@ async def create_session_key( }, ttl=ttl, ) + with contextlib.suppress(OperationalError): + await set_user_logon_attrs( + user, + self.session, + self.settings.TIMEZONE, + ) - await set_user_logon_attrs(user, self.session, self.settings.TIMEZONE) return key async def get_user_sessions(