Skip to content

Commit 58726a2

Browse files
author
m.shvets
committed
Refactor: Introduce ObjectSIDUseCase and related gateways, enhancing RID management functionality
1 parent 5383885 commit 58726a2

26 files changed

+1171
-804
lines changed

app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py

Lines changed: 149 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
77
"""
88

9+
import secrets
10+
911
import sqlalchemy as sa
1012
from alembic import op
1113
from dishka import AsyncContainer, Scope
@@ -16,19 +18,26 @@
1618
from enums import EntityTypeNames
1719
from ldap_protocol.ldap_schema.dto import EntityTypeDTO
1820
from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase
19-
from ldap_protocol.rid_manager.exceptions import RIDManagerNotFoundError
20-
from ldap_protocol.rid_manager.gateways import RIDManagerGateway
21-
from ldap_protocol.rid_manager.use_cases import (
22-
RID_AVAILABLE_MAX,
21+
from ldap_protocol.rid_manager import (
22+
RIDManagerGateway,
23+
RIDManagerSetupGateway,
2324
RIDManagerSetupUseCase,
25+
RIDManagerUseCase,
26+
RIDSetUseCase,
27+
)
28+
from ldap_protocol.rid_manager.dtos import RIDSetAllocationParamsDTO
29+
from ldap_protocol.rid_manager.exceptions import (
30+
RIDManagerNotFoundError,
31+
RIDManagerRidSetNotFoundError,
2432
)
25-
from ldap_protocol.rid_manager.utils import create_qword
33+
from ldap_protocol.rid_manager.rid_set_gateway import RIDSetGateway
34+
from ldap_protocol.rid_manager.utils import from_qword, to_qword
2635
from ldap_protocol.utils.queries import get_base_directories
2736
from repo.pg.tables import queryable_attr as qa
2837

2938
# revision identifiers, used by Alembic.
3039
revision: None | str = "552b4eafb1aa"
31-
down_revision: None | str = "19d86e660cf2"
40+
down_revision: None | str = "2dadf40c026a"
3241
branch_labels: None | list[str] = None
3342
depends_on: None | list[str] = None
3443

@@ -78,8 +87,8 @@ async def _migrate_object_sids(
7887
) -> None:
7988
"""Move Directory.objectSid values into Attributes table.
8089
81-
Additionally, for domain directories move the domain SID prefix part
82-
into the ``DomainIdentifier`` attribute.
90+
Additionally, for domain directories create the ``DomainIdentifier``
91+
attribute if it does not exist.
8392
"""
8493
async with container(scope=Scope.REQUEST) as cnt:
8594
session = await cnt.get(AsyncSession)
@@ -106,16 +115,46 @@ async def _migrate_object_sids(
106115
),
107116
)
108117

109-
if directory.name == "domain":
110-
identifier = directory.object_sid.split("-")[
111-
-1
112-
] # remove sid prefix
118+
base_dn_list = await get_base_directories(session)
119+
if base_dn_list:
120+
domain = base_dn_list[0]
121+
122+
existing_identifier = await session.scalar(
123+
select(Attribute).where(
124+
qa(Attribute.directory_id) == domain.id,
125+
qa(Attribute.name) == "DomainIdentifier",
126+
),
127+
)
128+
129+
if not (existing_identifier and existing_identifier.value):
130+
domain_object_sid = await session.scalar(
131+
select(Attribute).where(
132+
qa(Attribute.directory_id) == domain.id,
133+
qa(Attribute.name) == "objectSid",
134+
),
135+
)
136+
137+
identifier: str | None = None
138+
if domain_object_sid and domain_object_sid.value:
139+
parts = domain_object_sid.value.split("-")
140+
# "S-1-5-21-AAA-BBB-CCC" -> "AAA-BBB-CCC"
141+
if len(parts) >= 7 and domain_object_sid.value.startswith(
142+
"S-1-5-21-",
143+
):
144+
identifier = "-".join(parts[4:7])
145+
146+
if identifier is None:
147+
identifier = (
148+
f"{secrets.randbits(32)}-"
149+
f"{secrets.randbits(32)}-"
150+
f"{secrets.randbits(32)}"
151+
)
113152

114153
session.add(
115154
Attribute(
116155
name="DomainIdentifier",
117156
value=identifier,
118-
directory_id=directory.id,
157+
directory_id=domain.id,
119158
),
120159
)
121160

@@ -129,27 +168,35 @@ async def _init_rid_manager(
129168
"""Initialize RID Manager and RID Set for existing data."""
130169
async with container(scope=Scope.REQUEST) as cnt:
131170
session = await cnt.get(AsyncSession)
132-
rid_setup_use_case = await cnt.get(RIDManagerSetupUseCase)
133-
rid_gateway = await cnt.get(RIDManagerGateway)
171+
rid_setup_gateway = await cnt.get(RIDManagerSetupGateway)
172+
rid_gateway = await cnt.get(RIDManagerGateway)
173+
rid_manager_use_case = await cnt.get(RIDManagerUseCase)
174+
rid_set_gateway = await cnt.get(RIDSetGateway)
175+
rid_set_use_case = await cnt.get(RIDSetUseCase)
134176

135177
if not await get_base_directories(session):
136178
return
137179

138180
try:
139-
await rid_gateway.get_rid_manager()
181+
rid_manager_dir = await rid_gateway.get_rid_manager()
140182
except RIDManagerNotFoundError:
141-
await rid_setup_use_case.setup()
142-
await rid_gateway.get_rid_manager()
183+
rid_manager_dir = await rid_setup_gateway.set_rid_manager()
143184

144-
rid_set_dir = await rid_gateway.get_rid_set()
145-
if not rid_set_dir:
185+
base_dn_list = await get_base_directories(session)
186+
if not base_dn_list:
146187
return
188+
domain = base_dn_list[0]
147189

148-
base_domain = await rid_gateway.get_base_domain()
149-
domain_identifier = await rid_gateway.get_domain_identifier(
150-
base_domain,
190+
domain_identifier = await session.scalar(
191+
select(Attribute).where(
192+
qa(Attribute.directory_id) == domain.id,
193+
qa(Attribute.name) == "DomainIdentifier",
194+
),
151195
)
152-
sid_prefix = f"S-1-5-21-{domain_identifier}-"
196+
if not (domain_identifier and domain_identifier.value):
197+
return
198+
199+
sid_prefix = f"S-1-5-21-{domain_identifier.value}-"
153200

154201
sid_values = await session.scalars(
155202
select(Attribute).where(
@@ -172,25 +219,89 @@ async def _init_rid_manager(
172219

173220
start_rid = max(max_rid, RIDManagerSetupUseCase.RID_USER_MIN)
174221

175-
qword = create_qword(start_rid, RID_AVAILABLE_MAX)
176-
await rid_gateway.update_available_pool(qword)
222+
qword = to_qword(start_rid, RIDManagerSetupUseCase.RID_AVAILABLE_MAX)
223+
await rid_setup_gateway.set_rid_available_pool(rid_manager_dir, qword)
177224

178-
result = await session.execute(
179-
update(Attribute)
180-
.where(
225+
domain_controller = await rid_gateway.get_domain_controller()
226+
rid_set_dir: Directory | None = None
227+
try:
228+
rid_set_dir = await rid_set_gateway.get(domain_controller)
229+
except RIDManagerRidSetNotFoundError:
230+
rid_set_dir = None
231+
232+
if rid_set_dir is None:
233+
previous_allocation_pool = (
234+
await rid_manager_use_case.allocate_pool()
235+
)
236+
allocation_pool = await rid_manager_use_case.allocate_pool()
237+
lower, _ = from_qword(previous_allocation_pool)
238+
239+
await rid_set_use_case.add(
240+
domain_controller,
241+
RIDSetAllocationParamsDTO(
242+
next_rid=lower,
243+
allocation_pool=allocation_pool,
244+
previous_allocation_pool=previous_allocation_pool,
245+
),
246+
)
247+
await session.commit()
248+
return
249+
250+
existing_next_rid = await session.scalar(
251+
select(Attribute).where(
181252
qa(Attribute.directory_id) == rid_set_dir.id,
182253
qa(Attribute.name) == "rIDNextRID",
183-
)
184-
.values(value=str(start_rid)),
254+
),
185255
)
186-
if result.rowcount == 0:
187-
session.add(
188-
Attribute(
189-
directory_id=rid_set_dir.id,
190-
name="rIDNextRID",
191-
value=str(start_rid),
192-
),
256+
existing_prev_pool = await session.scalar(
257+
select(Attribute).where(
258+
qa(Attribute.directory_id) == rid_set_dir.id,
259+
qa(Attribute.name) == "rIDPreviousAllocationPool",
260+
),
261+
)
262+
existing_pool = await session.scalar(
263+
select(Attribute).where(
264+
qa(Attribute.directory_id) == rid_set_dir.id,
265+
qa(Attribute.name) == "rIDAllocationPool",
266+
),
267+
)
268+
269+
if (
270+
existing_next_rid
271+
and existing_next_rid.value
272+
and existing_prev_pool
273+
and existing_prev_pool.value
274+
and existing_pool
275+
and existing_pool.value
276+
):
277+
await session.commit()
278+
return
279+
280+
previous_allocation_pool = await rid_manager_use_case.allocate_pool()
281+
allocation_pool = await rid_manager_use_case.allocate_pool()
282+
lower, _ = from_qword(previous_allocation_pool)
283+
284+
for name, value in (
285+
("rIDNextRID", str(lower)),
286+
("rIDPreviousAllocationPool", str(previous_allocation_pool)),
287+
("rIDAllocationPool", str(allocation_pool)),
288+
):
289+
result = await session.execute(
290+
update(Attribute)
291+
.where(
292+
qa(Attribute.directory_id) == rid_set_dir.id,
293+
qa(Attribute.name) == name,
294+
)
295+
.values(value=value),
193296
)
297+
if result.rowcount == 0:
298+
session.add(
299+
Attribute(
300+
directory_id=rid_set_dir.id,
301+
name=name,
302+
value=value,
303+
),
304+
)
194305

195306
await session.commit()
196307

app/extra/scripts/add_domain_controller.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from enums import SamAccountTypeCodes, SecurityPrincipalRid
1515
from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO
1616
from ldap_protocol.objects import UserAccountControlFlag
17-
from ldap_protocol.rid_manager.use_cases import RIDManagerUseCase
17+
from ldap_protocol.rid_manager import ObjectSIDUseCase
1818
from ldap_protocol.roles.role_use_case import RoleUseCase
1919
from repo.pg.tables import queryable_attr as qa
2020

@@ -25,7 +25,7 @@ async def _add_domain_controller(
2525
entity_type_dao: EntityTypeDAO,
2626
settings: Settings,
2727
dc_ou_dir: Directory,
28-
rid_manager_use_case: RIDManagerUseCase,
28+
object_sid_use_case: ObjectSIDUseCase,
2929
) -> None:
3030
dc_directory = Directory(
3131
object_class="",
@@ -37,7 +37,7 @@ async def _add_domain_controller(
3737
await session.flush()
3838

3939
dc_directory.parent_id = dc_ou_dir.id
40-
await rid_manager_use_case.set_object_sid(
40+
await object_sid_use_case.add(
4141
directory=dc_directory,
4242
rid=SecurityPrincipalRid.DOMAIN_CONTROLLERS,
4343
)
@@ -103,7 +103,7 @@ async def add_domain_controller(
103103
settings: Settings,
104104
role_use_case: RoleUseCase,
105105
entity_type_dao: EntityTypeDAO,
106-
rid_manager_use_case: RIDManagerUseCase,
106+
object_sid_use_case: ObjectSIDUseCase,
107107
) -> None:
108108
logger.info("Adding domain controller.")
109109

@@ -137,7 +137,7 @@ async def add_domain_controller(
137137
entity_type_dao=entity_type_dao,
138138
settings=settings,
139139
dc_ou_dir=domain_controllers_ou,
140-
rid_manager_use_case=rid_manager_use_case,
140+
object_sid_use_case=object_sid_use_case,
141141
)
142142

143143
logger.debug("Domain controller added.")

app/ioc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,14 @@
154154
UserPasswordHistoryUseCases,
155155
)
156156
from ldap_protocol.rid_manager import (
157+
ObjectSIDGateway,
158+
ObjectSIDUseCase,
157159
RIDManagerGateway,
158160
RIDManagerSetupGateway,
159161
RIDManagerSetupUseCase,
160162
RIDManagerUseCase,
163+
RIDSetGateway,
164+
RIDSetUseCase,
161165
)
162166
from ldap_protocol.roles.access_manager import AccessManager
163167
from ldap_protocol.roles.ace_dao import AccessControlEntryDAO
@@ -580,6 +584,10 @@ def get_dhcp_mngr(
580584
RIDManagerSetupUseCase,
581585
scope=Scope.REQUEST,
582586
)
587+
object_sid_gateway = provide(ObjectSIDGateway, scope=Scope.REQUEST)
588+
object_sid_use_case = provide(ObjectSIDUseCase, scope=Scope.REQUEST)
589+
rid_set_gateway = provide(RIDSetGateway, scope=Scope.REQUEST)
590+
rid_set_use_case = provide(RIDSetUseCase, scope=Scope.REQUEST)
583591

584592

585593
class LDAPContextProvider(Provider):

app/ldap_protocol/auth/setup_gateway.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
AttributeValueValidator,
1818
)
1919
from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO
20-
from ldap_protocol.rid_manager.use_cases import RIDManagerUseCase
20+
from ldap_protocol.rid_manager import ObjectSIDUseCase
2121
from ldap_protocol.utils.async_cache import base_directories_cache
2222
from ldap_protocol.utils.queries import get_domain_object_class
2323
from password_utils import PasswordUtils
@@ -33,7 +33,7 @@ def __init__(
3333
password_utils: PasswordUtils,
3434
entity_type_dao: EntityTypeDAO,
3535
attribute_value_validator: AttributeValueValidator,
36-
rid_manager_use_case: RIDManagerUseCase,
36+
object_sid_use_case: ObjectSIDUseCase,
3737
) -> None:
3838
"""Initialize Setup use case.
3939
@@ -45,7 +45,7 @@ def __init__(
4545
self._password_utils = password_utils
4646
self._entity_type_dao = entity_type_dao
4747
self._attribute_value_validator = attribute_value_validator
48-
self._rid_manager_use_case = rid_manager_use_case
48+
self._object_sid_use_case = object_sid_use_case
4949

5050
async def is_setup(self) -> bool:
5151
"""Check if setup is performed.
@@ -165,7 +165,7 @@ async def create_dir(
165165
)
166166

167167
if "objectSid" in data:
168-
await self._rid_manager_use_case.set_object_sid(
168+
await self._object_sid_use_case.add(
169169
directory=dir_,
170170
rid=int(data["objectSid"]),
171171
sid_prefix=SidPrefix.BUILT_IN_DOMAIN,

app/ldap_protocol/auth/use_cases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ldap_protocol.objects import UserAccountControlFlag
2828
from ldap_protocol.policies.audit.audit_use_case import AuditUseCase
2929
from ldap_protocol.policies.password import PasswordPolicyUseCases
30-
from ldap_protocol.rid_manager.use_cases import RIDManagerSetupUseCase
30+
from ldap_protocol.rid_manager import RIDManagerSetupUseCase
3131
from ldap_protocol.roles.role_use_case import RoleUseCase
3232
from ldap_protocol.utils.helpers import create_integer_hash, ft_now
3333

app/ldap_protocol/ldap_requests/add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ async def handle( # noqa: C901
218218
ctx.session.add(new_dir)
219219

220220
await ctx.session.flush()
221-
await ctx.rid_manager_use_case.set_object_sid(
221+
await ctx.object_sid_use_case.add(
222222
directory=new_dir,
223223
)
224224
await ctx.session.flush()

0 commit comments

Comments
 (0)