Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 103 additions & 1 deletion src/auth/auth_ldap.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,25 @@ def __init__(self, params_dict, temp_folder):
else:
self.username_template = None


username_attribute_pattern = strip(params_dict.get('username_attribute_pattern'))
if username_attribute_pattern:
self._username_attribute_pattern = username_attribute_pattern
else:
self._username_attribute_pattern = None

bind_dn = strip(params_dict.get('bind_dn'))
if bind_dn:
self._bind_dn = bind_dn
else:
self._bind_dn = None

bind_password = strip(params_dict.get('bind_password'))
if bind_password:
self._bind_passwd = bind_password
else:
self._bind_passwd = None

base_dn = params_dict.get('base_dn')
if base_dn:
self._base_dn = base_dn.strip()
Expand All @@ -110,12 +129,78 @@ def authenticate(self, request_handler):
username = request_handler.get_argument('username')
password = request_handler.get_argument('password')

return self._authenticate_internal(username, password)
# If the following conditions, depending on config settings, are all met,
# search for the dn by the given attribute.
if self._username_attribute_pattern and self._bind_dn and self._bind_passwd:
bind_connection = self._bind_user_connection(self._bind_dn, self._bind_passwd)
bind_dn = self._get_dn_from_user_attribute(bind_connection, username, self._username_attribute_pattern)
return self._authenticate_internal_by_bind(bind_dn, password, username)
else:
return self._authenticate_internal(username, password)

def perform_basic_auth(self, user, password):
self._authenticate_internal(user, password)
return True

def _bind_user_connection(self, username, password):
LOGGER.info('Logging in bind-user ' + username)

try:
connection = self._connect(username, password)

if connection.bound:
LOGGER.info('bind-user ' + username + ' logged in')
return connection

error = connection.last_error

except Exception as e:
error = str(e)

if error not in KNOWN_REJECTIONS:
LOGGER.exception('Error occurred while ldap authentication of user ' + username)

if error in KNOWN_REJECTIONS:
LOGGER.info('Invalid credentials for user ' + username)
raise auth_base.AuthRejectedError('Invalid credentials')

raise auth_base.AuthFailureError(error)

# This method uses an additional parameter
def _authenticate_internal_by_bind(self, username, password, username_searched):
LOGGER.info('Logging in user ' + username)

try:
connection = self._connect(username, password)

if connection.bound:
try:
user_dn, user_uid = self._get_user_ids(username, connection)
LOGGER.debug('user ids: ' + str((user_dn, user_uid)))

user_groups = self._fetch_user_groups(user_dn, user_uid, connection)
LOGGER.info('Loaded groups for ' + username + ': ' + str(user_groups))
self._set_user_groups(username_searched, user_groups)
except:
LOGGER.exception('Failed to load groups for the user ' + username)

connection.unbind()
return username_searched

error = connection.last_error

except Exception as e:
error = str(e)

if error not in KNOWN_REJECTIONS:
LOGGER.exception('Error occurred while ldap authentication of user ' + username)

if error in KNOWN_REJECTIONS:
LOGGER.info('Invalid credentials for user ' + username)
raise auth_base.AuthRejectedError('Invalid credentials')

raise auth_base.AuthFailureError(error)

def _authenticate_internal(self, username, password):
LOGGER.info('Logging in user ' + username)

Expand Down Expand Up @@ -167,6 +252,20 @@ def _connect(self, full_username, password):
connection.bind()
return connection

def _get_dn_from_user_attribute(self, connection, value, attribute):
_attr_template = "(" + attribute + "=%s)"
search_request = SearchRequest(_attr_template, value)
entries = _search(self._base_dn, search_request, ['*'], connection)
if not entries:
return None

if len(entries) > 1:
logging.warning('More than one user found by filter: ' + str(search_request))
return None

entry = entries[0]
return get_entry_dn(entry)

def _get_groups(self, user):
groups = self._user_groups.get(user)
if groups is not None:
Expand Down Expand Up @@ -236,6 +335,7 @@ def _set_user_groups(self, user, groups):

new_groups_content = json.dumps(self._user_groups, indent=2)
file_utils.write_file(self._groups_file, new_groups_content)
print('Gruppen', new_groups_content)


class SearchRequest:
Expand All @@ -248,3 +348,5 @@ def as_search_string(self):

def __str__(self) -> str:
return self.as_search_string()