diff --git a/src/auth/auth_ldap.py b/src/auth/auth_ldap.py index a0d42731..082cd6e3 100644 --- a/src/auth/auth_ldap.py +++ b/src/auth/auth_ldap.py @@ -85,6 +85,25 @@ def __init__(self, params_dict, temp_folder): else: self.username_template = None + + username_attribute_pattern = strip(params_dict.get('username_attribute_pattern')) + if username_attribute_pattern: + self._username_attribute_pattern = username_attribute_pattern + else: + self._username_attribute_pattern = None + + bind_dn = strip(params_dict.get('bind_dn')) + if bind_dn: + self._bind_dn = bind_dn + else: + self._bind_dn = None + + bind_password = strip(params_dict.get('bind_password')) + if bind_password: + self._bind_passwd = bind_password + else: + self._bind_passwd = None + base_dn = params_dict.get('base_dn') if base_dn: self._base_dn = base_dn.strip() @@ -110,12 +129,78 @@ def authenticate(self, request_handler): username = request_handler.get_argument('username') password = request_handler.get_argument('password') - return self._authenticate_internal(username, password) + # If the following conditions, depending on config settings, are all met, + # search for the dn by the given attribute. + if self._username_attribute_pattern and self._bind_dn and self._bind_passwd: + bind_connection = self._bind_user_connection(self._bind_dn, self._bind_passwd) + bind_dn = self._get_dn_from_user_attribute(bind_connection, username, self._username_attribute_pattern) + return self._authenticate_internal_by_bind(bind_dn, password, username) + else: + return self._authenticate_internal(username, password) def perform_basic_auth(self, user, password): self._authenticate_internal(user, password) return True + def _bind_user_connection(self, username, password): + LOGGER.info('Logging in bind-user ' + username) + + try: + connection = self._connect(username, password) + + if connection.bound: + LOGGER.info('bind-user ' + username + ' logged in') + return connection + + error = connection.last_error + + except Exception as e: + error = str(e) + + if error not in KNOWN_REJECTIONS: + LOGGER.exception('Error occurred while ldap authentication of user ' + username) + + if error in KNOWN_REJECTIONS: + LOGGER.info('Invalid credentials for user ' + username) + raise auth_base.AuthRejectedError('Invalid credentials') + + raise auth_base.AuthFailureError(error) + + # This method uses an additional parameter + def _authenticate_internal_by_bind(self, username, password, username_searched): + LOGGER.info('Logging in user ' + username) + + try: + connection = self._connect(username, password) + + if connection.bound: + try: + user_dn, user_uid = self._get_user_ids(username, connection) + LOGGER.debug('user ids: ' + str((user_dn, user_uid))) + + user_groups = self._fetch_user_groups(user_dn, user_uid, connection) + LOGGER.info('Loaded groups for ' + username + ': ' + str(user_groups)) + self._set_user_groups(username_searched, user_groups) + except: + LOGGER.exception('Failed to load groups for the user ' + username) + + connection.unbind() + return username_searched + + error = connection.last_error + + except Exception as e: + error = str(e) + + if error not in KNOWN_REJECTIONS: + LOGGER.exception('Error occurred while ldap authentication of user ' + username) + + if error in KNOWN_REJECTIONS: + LOGGER.info('Invalid credentials for user ' + username) + raise auth_base.AuthRejectedError('Invalid credentials') + + raise auth_base.AuthFailureError(error) + def _authenticate_internal(self, username, password): LOGGER.info('Logging in user ' + username) @@ -167,6 +252,20 @@ def _connect(self, full_username, password): connection.bind() return connection + def _get_dn_from_user_attribute(self, connection, value, attribute): + _attr_template = "(" + attribute + "=%s)" + search_request = SearchRequest(_attr_template, value) + entries = _search(self._base_dn, search_request, ['*'], connection) + if not entries: + return None + + if len(entries) > 1: + logging.warning('More than one user found by filter: ' + str(search_request)) + return None + + entry = entries[0] + return get_entry_dn(entry) + def _get_groups(self, user): groups = self._user_groups.get(user) if groups is not None: @@ -236,6 +335,7 @@ def _set_user_groups(self, user, groups): new_groups_content = json.dumps(self._user_groups, indent=2) file_utils.write_file(self._groups_file, new_groups_content) + print('Gruppen', new_groups_content) class SearchRequest: @@ -248,3 +348,5 @@ def as_search_string(self): def __str__(self) -> str: return self.as_search_string() + +