-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup_gateway.py
More file actions
264 lines (231 loc) · 8.48 KB
/
setup_gateway.py
File metadata and controls
264 lines (231 loc) · 8.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
"""Identity use cases.
Copyright (c) 2025 MultiFactor
License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE
"""
from ipaddress import IPv4Network
from itertools import chain
from loguru import logger
from sqlalchemy import exists, select
from sqlalchemy.ext.asyncio import AsyncSession
from entities import Attribute, Directory, Group, NetworkPolicy, User
from enums import SidPrefix
from ldap_protocol.ldap_schema.attribute_value_validator import (
AttributeValueValidator,
)
from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO
from ldap_protocol.rid_manager import ObjectSIDUseCase
from ldap_protocol.utils.async_cache import base_directories_cache
from ldap_protocol.utils.queries import get_domain_object_class
from password_utils import PasswordUtils
from repo.pg.tables import queryable_attr as qa
class SetupGateway:
"""Setup use case."""
def __init__(
self,
session: AsyncSession,
password_utils: PasswordUtils,
entity_type_dao: EntityTypeDAO,
attribute_value_validator: AttributeValueValidator,
object_sid_use_case: ObjectSIDUseCase,
) -> None:
"""Initialize Setup use case.
:param session: SQLAlchemy AsyncSession
return: None.
"""
self._session = session
self._password_utils = password_utils
self._entity_type_dao = entity_type_dao
self._attribute_value_validator = attribute_value_validator
self._object_sid_use_case = object_sid_use_case
async def is_setup(self) -> bool:
"""Check if setup is performed.
:return: bool (True if setup is performed, False otherwise)
"""
query = select(
exists(Directory)
.where(qa(Directory.parent_id).is_(None)),
) # fmt: skip
retval = await self._session.scalars(query)
return retval.one()
async def setup_enviroment(
self,
*,
data: list,
is_system: bool = True,
domain: Directory,
) -> None:
"""Create directories and users for enviroment."""
async with self._session.begin_nested():
self._session.add(domain)
self._session.add(
NetworkPolicy(
name="Default open policy",
netmasks=[IPv4Network("0.0.0.0/0")],
raw=["0.0.0.0/0"],
priority=1,
),
)
await self._session.flush()
await self._session.refresh(domain, ["id"])
self._session.add_all(list(get_domain_object_class(domain)))
await self._session.flush()
await self._session.refresh(
instance=domain,
attribute_names=["attributes"],
with_for_update=None,
)
await self._entity_type_dao.attach_entity_type_to_directory(
directory=domain,
is_system_entity_type=True,
)
if not self._attribute_value_validator.is_directory_valid(domain):
raise ValueError(
"Invalid directory attribute values during environment setup", # noqa: E501
)
await self._session.flush()
try:
for unit in data:
await self.create_dir(
unit,
is_system=is_system,
domain=domain,
parent=domain,
)
base_directories_cache.clear()
except Exception:
import traceback
logger.error(traceback.format_exc())
raise
async def is_base_domain_created(self) -> bool:
"""Check if base domain is created."""
cat_result = await self._session.execute(select(Directory))
if cat_result.scalar_one_or_none():
logger.warning("dev data already set up")
return True
return False
async def create_base_domain(
self,
dn: str = "multifactor.dev",
) -> Directory:
"""Create base domain."""
domain = Directory(name=dn, object_class="domain")
domain.is_system = True
domain.path = [f"dc={path}" for path in reversed(dn.split("."))]
domain.depth = len(domain.path)
domain.rdname = ""
self._session.add(domain)
await self._session.flush()
return domain
async def create_dir(
self,
data: dict,
is_system: bool,
domain: Directory,
parent: Directory | None = None,
) -> None:
"""Create data recursively."""
dir_ = Directory(
is_system=is_system,
object_class=data["object_class"],
name=data["name"],
)
dir_.groups = []
dir_.create_path(parent, dir_.get_dn_prefix())
self._session.add(dir_)
await self._session.flush()
dir_.parent_id = parent.id if parent else None
await self._session.refresh(dir_, ["id"])
self._session.add(
Attribute(
name=dir_.rdname,
value=dir_.name,
directory_id=dir_.id,
),
)
if "objectSid" in data:
await self._object_sid_use_case.add(
directory=dir_,
rid=int(data["objectSid"]),
sid_prefix=SidPrefix.BUILT_IN_DOMAIN,
)
if dir_.object_class == "group":
group = Group(directory_id=dir_.id)
self._session.add(group)
for group_name in data.get("groups", []):
parent_group = await self._get_group(group_name)
dir_.groups.append(parent_group)
await self._session.flush()
if "attributes" in data:
attrs = chain(
data["attributes"].items(),
[("objectClass", [dir_.object_class])],
)
for name, values in attrs:
for value in values:
self._session.add(
Attribute(
directory_id=dir_.id,
name=name,
value=value if isinstance(value, str) else None,
bvalue=value if isinstance(value, bytes) else None,
),
)
if "organizationalPerson" in data:
user_data = data["organizationalPerson"]
user = User(
directory_id=dir_.id,
sam_account_name=user_data["sam_account_name"],
user_principal_name=user_data["user_principal_name"],
display_name=user_data["display_name"],
mail=user_data["mail"],
password=self._password_utils.get_password_hash(
user_data["password"],
),
)
self._session.add(user)
await self._session.flush()
self._session.add(
Attribute(
directory_id=dir_.id,
name="homeDirectory",
value=f"/home/{user.uid}",
),
)
for group_name in user_data.get("groups", []):
parent_group = await self._get_group(group_name)
dir_.groups.append(parent_group)
await self._session.flush()
await self._session.refresh(
instance=dir_,
attribute_names=["attributes", "user"],
with_for_update=None,
)
await self._entity_type_dao.attach_entity_type_to_directory(
directory=dir_,
is_system_entity_type=True,
)
if not self._attribute_value_validator.is_directory_valid(dir_):
raise ValueError("Invalid directory attribute values")
await self._session.flush()
if "children" in data:
for n_data in data["children"]:
await self.create_dir(
n_data,
is_system=is_system,
domain=domain,
parent=dir_,
)
async def _get_group(self, name: str) -> Group:
"""Get group by name.
:param str name: group name
:return Group: group
"""
retval = await self._session.scalars(
select(Group)
.join(qa(Group.directory))
.filter(
qa(Directory.name) == name,
qa(Directory.object_class) == "group",
),
)
return retval.one()