66
77"""
88
9+ import secrets
10+
911import sqlalchemy as sa
1012from alembic import op
1113from dishka import AsyncContainer , Scope
1618from enums import EntityTypeNames
1719from ldap_protocol .ldap_schema .dto import EntityTypeDTO
1820from ldap_protocol .ldap_schema .entity_type_use_case import EntityTypeUseCase
19- from ldap_protocol .rid_manager .exceptions import RIDManagerNotFoundError
20- from ldap_protocol .rid_manager .gateways import RIDManagerGateway
21- from ldap_protocol .rid_manager .use_cases import (
22- RID_AVAILABLE_MAX ,
21+ from ldap_protocol .rid_manager import (
22+ RIDManagerGateway ,
23+ RIDManagerSetupGateway ,
2324 RIDManagerSetupUseCase ,
25+ RIDManagerUseCase ,
26+ RIDSetUseCase ,
27+ )
28+ from ldap_protocol .rid_manager .dtos import RIDSetAllocationParamsDTO
29+ from ldap_protocol .rid_manager .exceptions import (
30+ RIDManagerNotFoundError ,
31+ RIDManagerRidSetNotFoundError ,
2432)
25- from ldap_protocol .rid_manager .utils import create_qword
33+ from ldap_protocol .rid_manager .rid_set_gateway import RIDSetGateway
34+ from ldap_protocol .rid_manager .utils import from_qword , to_qword
2635from ldap_protocol .utils .queries import get_base_directories
2736from repo .pg .tables import queryable_attr as qa
2837
2938# revision identifiers, used by Alembic.
3039revision : None | str = "552b4eafb1aa"
31- down_revision : None | str = "19d86e660cf2 "
40+ down_revision : None | str = "2dadf40c026a "
3241branch_labels : None | list [str ] = None
3342depends_on : None | list [str ] = None
3443
@@ -78,8 +87,8 @@ async def _migrate_object_sids(
7887 ) -> None :
7988 """Move Directory.objectSid values into Attributes table.
8089
81- Additionally, for domain directories move the domain SID prefix part
82- into the ``DomainIdentifier`` attribute .
90+ Additionally, for domain directories create the ``DomainIdentifier``
91+ attribute if it does not exist .
8392 """
8493 async with container (scope = Scope .REQUEST ) as cnt :
8594 session = await cnt .get (AsyncSession )
@@ -106,16 +115,46 @@ async def _migrate_object_sids(
106115 ),
107116 )
108117
109- if directory .name == "domain" :
110- identifier = directory .object_sid .split ("-" )[
111- - 1
112- ] # remove sid prefix
118+ base_dn_list = await get_base_directories (session )
119+ if base_dn_list :
120+ domain = base_dn_list [0 ]
121+
122+ existing_identifier = await session .scalar (
123+ select (Attribute ).where (
124+ qa (Attribute .directory_id ) == domain .id ,
125+ qa (Attribute .name ) == "DomainIdentifier" ,
126+ ),
127+ )
128+
129+ if not (existing_identifier and existing_identifier .value ):
130+ domain_object_sid = await session .scalar (
131+ select (Attribute ).where (
132+ qa (Attribute .directory_id ) == domain .id ,
133+ qa (Attribute .name ) == "objectSid" ,
134+ ),
135+ )
136+
137+ identifier : str | None = None
138+ if domain_object_sid and domain_object_sid .value :
139+ parts = domain_object_sid .value .split ("-" )
140+ # "S-1-5-21-AAA-BBB-CCC" -> "AAA-BBB-CCC"
141+ if len (parts ) >= 7 and domain_object_sid .value .startswith (
142+ "S-1-5-21-" ,
143+ ):
144+ identifier = "-" .join (parts [4 :7 ])
145+
146+ if identifier is None :
147+ identifier = (
148+ f"{ secrets .randbits (32 )} -"
149+ f"{ secrets .randbits (32 )} -"
150+ f"{ secrets .randbits (32 )} "
151+ )
113152
114153 session .add (
115154 Attribute (
116155 name = "DomainIdentifier" ,
117156 value = identifier ,
118- directory_id = directory .id ,
157+ directory_id = domain .id ,
119158 ),
120159 )
121160
@@ -129,27 +168,35 @@ async def _init_rid_manager(
129168 """Initialize RID Manager and RID Set for existing data."""
130169 async with container (scope = Scope .REQUEST ) as cnt :
131170 session = await cnt .get (AsyncSession )
132- rid_setup_use_case = await cnt .get (RIDManagerSetupUseCase )
133- rid_gateway = await cnt .get (RIDManagerGateway )
171+ rid_setup_gateway = await cnt .get (RIDManagerSetupGateway )
172+ rid_gateway = await cnt .get (RIDManagerGateway )
173+ rid_manager_use_case = await cnt .get (RIDManagerUseCase )
174+ rid_set_gateway = await cnt .get (RIDSetGateway )
175+ rid_set_use_case = await cnt .get (RIDSetUseCase )
134176
135177 if not await get_base_directories (session ):
136178 return
137179
138180 try :
139- await rid_gateway .get_rid_manager ()
181+ rid_manager_dir = await rid_gateway .get_rid_manager ()
140182 except RIDManagerNotFoundError :
141- await rid_setup_use_case .setup ()
142- await rid_gateway .get_rid_manager ()
183+ rid_manager_dir = await rid_setup_gateway .set_rid_manager ()
143184
144- rid_set_dir = await rid_gateway . get_rid_set ( )
145- if not rid_set_dir :
185+ base_dn_list = await get_base_directories ( session )
186+ if not base_dn_list :
146187 return
188+ domain = base_dn_list [0 ]
147189
148- base_domain = await rid_gateway .get_base_domain ()
149- domain_identifier = await rid_gateway .get_domain_identifier (
150- base_domain ,
190+ domain_identifier = await session .scalar (
191+ select (Attribute ).where (
192+ qa (Attribute .directory_id ) == domain .id ,
193+ qa (Attribute .name ) == "DomainIdentifier" ,
194+ ),
151195 )
152- sid_prefix = f"S-1-5-21-{ domain_identifier } -"
196+ if not (domain_identifier and domain_identifier .value ):
197+ return
198+
199+ sid_prefix = f"S-1-5-21-{ domain_identifier .value } -"
153200
154201 sid_values = await session .scalars (
155202 select (Attribute ).where (
@@ -172,25 +219,89 @@ async def _init_rid_manager(
172219
173220 start_rid = max (max_rid , RIDManagerSetupUseCase .RID_USER_MIN )
174221
175- qword = create_qword (start_rid , RID_AVAILABLE_MAX )
176- await rid_gateway . update_available_pool ( qword )
222+ qword = to_qword (start_rid , RIDManagerSetupUseCase . RID_AVAILABLE_MAX )
223+ await rid_setup_gateway . set_rid_available_pool ( rid_manager_dir , qword )
177224
178- result = await session .execute (
179- update (Attribute )
180- .where (
225+ domain_controller = await rid_gateway .get_domain_controller ()
226+ rid_set_dir : Directory | None = None
227+ try :
228+ rid_set_dir = await rid_set_gateway .get (domain_controller )
229+ except RIDManagerRidSetNotFoundError :
230+ rid_set_dir = None
231+
232+ if rid_set_dir is None :
233+ previous_allocation_pool = (
234+ await rid_manager_use_case .allocate_pool ()
235+ )
236+ allocation_pool = await rid_manager_use_case .allocate_pool ()
237+ lower , _ = from_qword (previous_allocation_pool )
238+
239+ await rid_set_use_case .add (
240+ domain_controller ,
241+ RIDSetAllocationParamsDTO (
242+ next_rid = lower ,
243+ allocation_pool = allocation_pool ,
244+ previous_allocation_pool = previous_allocation_pool ,
245+ ),
246+ )
247+ await session .commit ()
248+ return
249+
250+ existing_next_rid = await session .scalar (
251+ select (Attribute ).where (
181252 qa (Attribute .directory_id ) == rid_set_dir .id ,
182253 qa (Attribute .name ) == "rIDNextRID" ,
183- )
184- .values (value = str (start_rid )),
254+ ),
185255 )
186- if result .rowcount == 0 :
187- session .add (
188- Attribute (
189- directory_id = rid_set_dir .id ,
190- name = "rIDNextRID" ,
191- value = str (start_rid ),
192- ),
256+ existing_prev_pool = await session .scalar (
257+ select (Attribute ).where (
258+ qa (Attribute .directory_id ) == rid_set_dir .id ,
259+ qa (Attribute .name ) == "rIDPreviousAllocationPool" ,
260+ ),
261+ )
262+ existing_pool = await session .scalar (
263+ select (Attribute ).where (
264+ qa (Attribute .directory_id ) == rid_set_dir .id ,
265+ qa (Attribute .name ) == "rIDAllocationPool" ,
266+ ),
267+ )
268+
269+ if (
270+ existing_next_rid
271+ and existing_next_rid .value
272+ and existing_prev_pool
273+ and existing_prev_pool .value
274+ and existing_pool
275+ and existing_pool .value
276+ ):
277+ await session .commit ()
278+ return
279+
280+ previous_allocation_pool = await rid_manager_use_case .allocate_pool ()
281+ allocation_pool = await rid_manager_use_case .allocate_pool ()
282+ lower , _ = from_qword (previous_allocation_pool )
283+
284+ for name , value in (
285+ ("rIDNextRID" , str (lower )),
286+ ("rIDPreviousAllocationPool" , str (previous_allocation_pool )),
287+ ("rIDAllocationPool" , str (allocation_pool )),
288+ ):
289+ result = await session .execute (
290+ update (Attribute )
291+ .where (
292+ qa (Attribute .directory_id ) == rid_set_dir .id ,
293+ qa (Attribute .name ) == name ,
294+ )
295+ .values (value = value ),
193296 )
297+ if result .rowcount == 0 :
298+ session .add (
299+ Attribute (
300+ directory_id = rid_set_dir .id ,
301+ name = name ,
302+ value = value ,
303+ ),
304+ )
194305
195306 await session .commit ()
196307
0 commit comments