3
3
import os
4
4
from string import Template
5
5
6
- from ldap3 import Connection , SIMPLE
6
+ from ldap3 import Connection , SIMPLE , Server
7
7
from ldap3 .core .exceptions import LDAPAttributeError
8
8
from ldap3 .utils .conv import escape_filter_chars
9
9
@@ -39,7 +39,7 @@ def _resolve_base_dn(full_username):
39
39
return ''
40
40
41
41
42
- def _search (dn , search_request , attributes , connection ):
42
+ def _ldap_search (dn , search_request , attributes , connection ):
43
43
search_string = search_request .as_search_string ()
44
44
45
45
success = connection .search (dn , search_string , attributes = attributes )
@@ -53,7 +53,7 @@ def _search(dn, search_request, attributes, connection):
53
53
54
54
55
55
def _load_multiple_entries_values (dn , search_request , attribute_name , connection ):
56
- entries = _search (dn , search_request , [attribute_name ], connection )
56
+ entries = _ldap_search (dn , search_request , [attribute_name ], connection )
57
57
if entries is None :
58
58
return []
59
59
@@ -77,32 +77,25 @@ class LdapAuthenticator(auth_base.Authenticator):
77
77
def __init__ (self , params_dict , temp_folder ):
78
78
super ().__init__ ()
79
79
80
- self .url = model_helper .read_obligatory (params_dict , 'url' , ' for LDAP auth' )
80
+ self ._ldap_connector = LdapConnector (
81
+ model_helper .read_obligatory (params_dict , 'url' , ' for LDAP auth' ),
82
+ params_dict .get ('version' )
83
+ )
81
84
82
- username_pattern = strip (params_dict .get ('username_pattern' ))
83
- if username_pattern :
84
- self .username_template = Template (username_pattern )
85
- else :
86
- self .username_template = None
85
+ self ._ldap_user_resolver = LdapUserResolver (
86
+ params_dict .get ('ldap_user_resolver' ),
87
+ self ._ldap_connector )
87
88
88
89
base_dn = params_dict .get ('base_dn' )
89
90
if base_dn :
90
91
self ._base_dn = base_dn .strip ()
91
92
else :
92
- resolved_base_dn = _resolve_base_dn (username_pattern )
93
-
94
- if resolved_base_dn :
95
- LOGGER .info ('Resolved base dn: ' + resolved_base_dn )
96
- self ._base_dn = resolved_base_dn
97
- else :
93
+ self ._base_dn = self ._ldap_user_resolver .auto_resolve_base_dn ()
94
+ if not self ._base_dn :
98
95
LOGGER .warning (
99
96
'Cannot resolve LDAP base dn, so using empty. Please specify it using "base_dn" attribute' )
100
97
self ._base_dn = ''
101
98
102
- self .version = params_dict .get ("version" )
103
- if not self .version :
104
- self .version = 3
105
-
106
99
self ._groups_file = os .path .join (temp_folder , 'ldap_groups.json' )
107
100
self ._user_groups = self ._load_groups (self ._groups_file )
108
101
@@ -119,13 +112,10 @@ def perform_basic_auth(self, user, password):
119
112
def _authenticate_internal (self , username , password ):
120
113
LOGGER .info ('Logging in user ' + username )
121
114
122
- if self .username_template :
123
- full_username = self .username_template .substitute (username = username )
124
- else :
125
- full_username = username
115
+ full_username = self ._ldap_user_resolver .resolve_ldap_username (username , self ._base_dn )
126
116
127
117
try :
128
- connection = self ._connect (full_username , password )
118
+ connection = self ._ldap_connector . connect (full_username , password )
129
119
130
120
if connection .bound :
131
121
try :
@@ -155,18 +145,6 @@ def _authenticate_internal(self, username, password):
155
145
156
146
raise auth_base .AuthFailureError (error )
157
147
158
- def _connect (self , full_username , password ):
159
- connection = Connection (
160
- self .url ,
161
- user = full_username ,
162
- password = password ,
163
- authentication = SIMPLE ,
164
- read_only = True ,
165
- version = self .version
166
- )
167
- connection .bind ()
168
- return connection
169
-
170
148
def _get_groups (self , user ):
171
149
groups = self ._user_groups .get (user )
172
150
if groups is not None :
@@ -213,15 +191,8 @@ def _get_user_ids(self, full_username, connection):
213
191
LOGGER .warning ('Unsupported username pattern for ' + full_username )
214
192
return full_username , None
215
193
216
- entries = _search (base_dn , search_request , ['uid' ], connection )
217
- if not entries :
218
- return full_username , None
194
+ entry = LdapConnector .find_user (base_dn , search_request , connection )
219
195
220
- if len (entries ) > 1 :
221
- LOGGER .warning ('More than one user found by filter: ' + str (search_request ))
222
- return full_username , None
223
-
224
- entry = entries [0 ]
225
196
return get_entry_dn (entry ), entry .uid .value
226
197
227
198
def _load_groups (self , groups_file ):
@@ -248,3 +219,123 @@ def as_search_string(self):
248
219
249
220
def __str__ (self ) -> str :
250
221
return self .as_search_string ()
222
+
223
+
224
+ class LdapConnector :
225
+ def __init__ (self , url , version ):
226
+ self .url = url
227
+ self .version = version
228
+ if not self .version :
229
+ self .version = 3
230
+
231
+ def connect (self , full_username , password ):
232
+ server = Server (self .url , connect_timeout = 10 )
233
+ connection = Connection (
234
+ server ,
235
+ user = full_username ,
236
+ password = password ,
237
+ authentication = SIMPLE ,
238
+ read_only = True ,
239
+ version = self .version ,
240
+ )
241
+ connection .bind ()
242
+ return connection
243
+
244
+ @staticmethod
245
+ def find_user (base_dn , search_request , connection , attributes = None ):
246
+ if attributes is None :
247
+ attributes = ['uid' ]
248
+
249
+ entries = _ldap_search (base_dn , search_request , attributes , connection )
250
+ if not entries :
251
+ return None
252
+
253
+ if len (entries ) > 1 :
254
+ LOGGER .warning ('More than one user found by filter: ' + str (search_request ))
255
+ return None
256
+
257
+ return entries [0 ]
258
+
259
+
260
+ class LdapUserResolver :
261
+ def __init__ (self , config , ldap_connector : LdapConnector ) -> None :
262
+ self .username_template = None
263
+ self .username_pattern = None
264
+ self .search_by_attribute = None
265
+ self .admin_user = None
266
+ self .admin_password = None
267
+ self .ldap_connector = ldap_connector
268
+
269
+ if config :
270
+ username_pattern = strip (config .get ('username_pattern' ))
271
+ search_by_attribute = strip (config .get ('search_by_attribute' ))
272
+
273
+ # Validate that either username_pattern or search_by_attribute is specified
274
+ if not username_pattern and not search_by_attribute :
275
+ raise ValueError (
276
+ 'Either username_pattern or search_by_attribute must be specified in ldap_user_resolver.' )
277
+
278
+ if username_pattern and search_by_attribute :
279
+ raise ValueError (
280
+ 'Cannot specify both username_pattern and search_by_attribute in ldap_user_resolver. Choose one method.' )
281
+
282
+ if username_pattern :
283
+ self .username_template = Template (username_pattern )
284
+ self .username_pattern = username_pattern
285
+
286
+ if search_by_attribute :
287
+ self .search_by_attribute = search_by_attribute
288
+ self .admin_user = model_helper .read_obligatory (
289
+ config ,
290
+ 'admin_user' ,
291
+ ' for ldap_user_resolver with search_by_attribute'
292
+ )
293
+ self .admin_password = model_helper .read_obligatory (
294
+ config ,
295
+ 'admin_password' ,
296
+ ' for ldap_user_resolver with search_by_attribute'
297
+ )
298
+
299
+ def resolve_ldap_username (self , username , base_dn ):
300
+ if self .username_template :
301
+ return self .username_template .substitute (username = username )
302
+ elif self .search_by_attribute :
303
+ resolved_dn = self ._find_user_dn_by_attribute (username , base_dn )
304
+ return resolved_dn
305
+ else :
306
+ return username
307
+
308
+ def auto_resolve_base_dn (self ):
309
+ if self .username_pattern :
310
+ resolved_base_dn = _resolve_base_dn (self .username_pattern )
311
+ if resolved_base_dn :
312
+ LOGGER .info ('Resolved base dn: ' + resolved_base_dn )
313
+ return resolved_base_dn
314
+
315
+ if self .search_by_attribute :
316
+ resolved_base_dn = _resolve_base_dn (self .admin_user )
317
+ if not resolved_base_dn :
318
+ raise Exception ('"base_dn" is required for search_by_attribute user resolution' )
319
+ return resolved_base_dn
320
+
321
+ return None
322
+
323
+ def _find_user_dn_by_attribute (self , username , base_dn ):
324
+ admin_connection = self .ldap_connector .connect (self .admin_user , self .admin_password )
325
+
326
+ try :
327
+ if not admin_connection .bound :
328
+ error_msg = f'Failed to bind with admin LDAP user: { admin_connection .last_error } '
329
+ LOGGER .error (error_msg )
330
+ raise auth_base .AuthFailureError (error_msg )
331
+
332
+ search_request = SearchRequest (f'({ self .search_by_attribute } =%s)' , username )
333
+
334
+ user = self .ldap_connector .find_user (base_dn , search_request , admin_connection )
335
+ if user is None :
336
+ raise auth_base .AuthRejectedError ('Invalid credentials' )
337
+
338
+ return get_entry_dn (user )
339
+
340
+ finally :
341
+ admin_connection .unbind ()
0 commit comments