Skip to content

Commit 99d45e9

Browse files
Refactor: RenameRequest entry (copilot fixes) (#923)
1 parent 2df633a commit 99d45e9

1 file changed

Lines changed: 26 additions & 10 deletions

File tree

app/ldap_protocol/ldap_requests/modify.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import AsyncGenerator, ClassVar
99

1010
from loguru import logger
11+
from pydantic import Field
1112
from sqlalchemy import Select, and_, delete, func, or_, select, update
1213
from sqlalchemy.exc import IntegrityError
1314
from sqlalchemy.ext.asyncio import AsyncSession
@@ -110,7 +111,7 @@ class ModifyRequest(BaseRequest):
110111
# NOTE: If the old value was changed (for example, in _delete)
111112
# in one method, then you need to have access to the old value
112113
# from other methods (for example, from _add)
113-
_old_vals: dict[str, str | None] = {}
114+
old_vals: dict[str, str | None] = Field(default_factory=dict)
114115

115116
@classmethod
116117
def from_data(cls, data: list[ASN1Row]) -> "ModifyRequest":
@@ -143,7 +144,7 @@ async def _update_password_expiration(
143144
return
144145

145146
if not (
146-
change.modification.type == "krbpasswordexpiration"
147+
change.l_type == "krbpasswordexpiration"
147148
and change.modification.vals[0] == "19700101000000Z"
148149
):
149150
return
@@ -284,10 +285,10 @@ async def handle(
284285

285286
except MODIFY_EXCEPTION_STACK as err:
286287
await ctx.session.rollback()
287-
result_code, message = self._match_bad_response(err)
288+
result_code, error_message = self._match_bad_response(err)
288289
yield ModifyResponse(
289290
result_code=result_code,
290-
message=message,
291+
error_message=error_message,
291292
)
292293
return
293294

@@ -333,6 +334,9 @@ def _match_bad_response(self, err: BaseException) -> tuple[LDAPCodes, str]:
333334
case ModifyForbiddenError():
334335
return LDAPCodes.OPERATIONS_ERROR, str(err)
335336

337+
case KRBAPIRenamePrincipalError():
338+
return LDAPCodes.UNAVAILABLE, "Kerberos error"
339+
336340
case KRBAPIPrincipalNotFoundError():
337341
return LDAPCodes.UNAVAILABLE, "Kerberos error"
338342

@@ -632,8 +636,8 @@ def _need_to_cache_samaccountname_old_value(
632636
return bool(
633637
directory.entity_type
634638
and directory.entity_type.name == EntityTypeNames.COMPUTER
635-
and change.modification.type == "sAMAccountName"
636-
and not self._old_vals.get(change.modification.type),
639+
and change.l_type == "samaccountname"
640+
and not self.old_vals.get(change.modification.type),
637641
)
638642

639643
async def _delete(
@@ -689,7 +693,7 @@ async def _delete(
689693
if self._need_to_cache_samaccountname_old_value(change, directory):
690694
vals = directory.attributes_dict.get(change.modification.type)
691695
if vals:
692-
self._old_vals[change.modification.type] = vals[0]
696+
self.old_vals[change.modification.type] = vals[0]
693697

694698
if attrs:
695699
del_query = (
@@ -826,14 +830,13 @@ async def _add( # noqa: C901
826830
password_use_cases: PasswordPolicyUseCases,
827831
password_utils: PasswordUtils,
828832
) -> None:
833+
base_dir = None
829834
attrs = []
830835

831836
if change.l_type in ("memberof", "member", "primarygroupid"):
832837
await self._add_group_attrs(change, directory, session)
833838
return
834839

835-
base_dir = await self._get_base_dir(directory, session)
836-
837840
for value in change.modification.vals:
838841
if change.l_type == "useraccountcontrol":
839842
uac_val = int(value)
@@ -923,6 +926,12 @@ async def _add( # noqa: C901
923926
new_user_principal_name = str(new_value)
924927
new_sam_account_name = new_user_principal_name.split("@")[0] # noqa: E501 # fmt: skip
925928
elif change.l_type == "samaccountname":
929+
if not base_dir:
930+
base_dir = await self._get_base_dir(
931+
directory,
932+
session,
933+
)
934+
926935
new_sam_account_name = str(new_value)
927936
new_user_principal_name = f"{new_sam_account_name}@{base_dir.name}" # noqa: E501 # fmt: skip
928937

@@ -946,12 +955,19 @@ async def _add( # noqa: C901
946955
and directory.entity_type
947956
and directory.entity_type.name == EntityTypeNames.COMPUTER
948957
):
958+
if not base_dir:
959+
base_dir = await self._get_base_dir(
960+
directory,
961+
session,
962+
)
963+
949964
await self._modify_computer_samaccountname(
950965
change,
951966
kadmin,
952967
base_dir,
953968
value,
954969
)
970+
955971
attrs.append(
956972
Attribute(
957973
name=change.modification.type,
@@ -1019,7 +1035,7 @@ async def _modify_computer_samaccountname(
10191035
base_dir: Directory,
10201036
new_sam_account_name: bytes | str,
10211037
) -> None:
1022-
old_sam_account_name = self._old_vals.get(change.modification.type)
1038+
old_sam_account_name = self.old_vals.get(change.modification.type)
10231039
new_sam_account_name = str(new_sam_account_name)
10241040

10251041
if not old_sam_account_name:

0 commit comments

Comments
 (0)