diff --git a/diaspora_auth_provider.py b/diaspora_auth_provider.py index 067ae1acf9f179479009ff9d1332345eb230177a..6690a80028d97aa183620ba5cc5ec622230bb7e4 100644 --- a/diaspora_auth_provider.py +++ b/diaspora_auth_provider.py @@ -40,9 +40,11 @@ class DiasporaAuthProvider: self.auth_handler = self.account_handler._auth_handler if self.config.engine == "mysql": import pymysql + self.module = pymysql - elif self.config.engine == 'postgres': + elif self.config.engine == "postgres": import psycopg2 + self.module = psycopg2 @defer.inlineCallbacks @@ -52,17 +54,17 @@ class DiasporaAuthProvider: user=self.config.db_username, password=self.config.db_password, host=self.config.db_host, - port=self.config.db_port + port=self.config.db_port, ) try: with self.connection: with self.connection.cursor() as cursor: - yield threads.deferToThread( # Don't think this is needed, but w/e - cursor.execute, query, args, - ) - results = yield threads.deferToThread( - cursor.fetchall + yield threads.deferToThread( # Don't think this is needed, but w/e + cursor.execute, + query, + args, ) + results = yield threads.deferToThread(cursor.fetchall) cursor.close() defer.returnValue(results) except self.module.Error as e: @@ -71,32 +73,41 @@ class DiasporaAuthProvider: finally: self.connection.close() - - @defer.inlineCallbacks def check_password(self, user_id, password): if not password: defer.returnValue(False) - local_part = user_id.split(':', 1)[0][1:] + local_part = user_id.split(":", 1)[0][1:] users = yield self.exec_query( - "SELECT username, encrypted_password, email FROM users WHERE username=%s", local_part - ) + "SELECT username, encrypted_password, email FROM users WHERE username=%s", + local_part, + ) # user_id is @localpart:hs_bare. we only need the localpart. logger.info("Checking if user {} exists.".format(local_part)) - + # check if the user exists. if not users: - logger.info("User {} does not exist. Rejecting auth request".format(local_part)) + logger.info( + "User {} does not exist. Rejecting auth request".format(local_part) + ) defer.returnValue(False) user = users[0] logger.debug("User {} exists. Checking password".format(local_part)) # user exists, check if the password is correct. encrypted_password = user[1] email = user[2] - peppered_pass = u"{}{}".format(password, self.config.pepper) - if not (bcrypt.hashpw(peppered_pass.encode('utf8'), encrypted_password.encode('utf8')) - == encrypted_password.encode('utf8')): - logger.info("Password given for {} is wrong. Rejecting auth request.".format(local_part)) + peppered_pass = "{}{}".format(password, self.config.pepper) + if not ( + bcrypt.hashpw( + peppered_pass.encode("utf8"), encrypted_password.encode("utf8") + ) + == encrypted_password.encode("utf8") + ): + logger.info( + "Password given for {} is wrong. Rejecting auth request.".format( + local_part + ) + ) defer.returnValue(False) self.register_user(local_part, email) logger.info("Confirming authentication request.") @@ -108,7 +119,9 @@ class DiasporaAuthProvider: if medium != "email": defer.returnValue(None) logger.debug("Searching for email {} in diaspora db.".format(address)) - users = yield self.exec_query("SELECT username FROM users WHERE email=%s", address) + users = yield self.exec_query( + "SELECT username FROM users WHERE email=%s", address + ) if not users: defer.returnValue(None) username = users[0][0] @@ -124,8 +137,8 @@ class DiasporaAuthProvider: yield self.sync_email(local_part, email) defer.returnValue(local_part) else: - user_id = ( - yield self.account_handler.register_user(localpart=local_part, emails=[email]) + user_id = yield self.account_handler.register_user( + localpart=local_part, emails=[email] ) defer.returnValue(user_id) @@ -133,19 +146,23 @@ class DiasporaAuthProvider: def sync_email(self, user_id, email): logger.info("Syncing emails of {}".format(user_id)) email = email.lower() - store = self.account_handler._store # Need access to datastore + store = self.account_handler._store # Need access to datastore threepids = yield store.user_get_threepids(user_id) if not threepids: logger.info("No 3pids found.") yield self.add_email(user_id, email) for threepid in threepids: - if not threepid['medium'] == 'email': + if not threepid["medium"] == "email": logger.debug("Not an email: {}".format(str(threepid))) pass - address = threepid['address'] + address = threepid["address"] if address != email: - logger.info("Existing 3pid doesn't match {} != {}. Deleting".format(address, email)) - yield self.auth_handler.delete_threepid(user_id, 'email', address) + logger.info( + "Existing 3pid doesn't match {} != {}. Deleting".format( + address, email + ) + ) + yield self.auth_handler.delete_threepid(user_id, "email", address) yield self.add_email(user_id, email) break logger.info("Sync completed.") @@ -154,19 +171,23 @@ class DiasporaAuthProvider: def add_email(self, user_id, email): logger.info("Adding 3pid: {} for {}".format(email, user_id)) validated_at = self.account_handler._hs.get_clock().time_msec() - yield self.auth_handler.add_threepid(user_id, 'email', email, validated_at) + yield self.auth_handler.add_threepid(user_id, "email", email, validated_at) @staticmethod def parse_config(config): class _Conf: pass + Conf = _Conf() - Conf.engine = config['database']['engine'] - Conf.db_name = "diaspora_production" if not config['database']['name'] else config['database']['name'] - Conf.db_host = config['database']['host'] - Conf.db_port = config['database']['port'] - Conf.db_username = config['database']['username'] - Conf.db_password = config['database']['password'] - Conf.pepper = config['pepper'] + Conf.engine = config["database"]["engine"] + Conf.db_name = ( + "diaspora_production" + if not config["database"]["name"] + else config["database"]["name"] + ) + Conf.db_host = config["database"]["host"] + Conf.db_port = config["database"]["port"] + Conf.db_username = config["database"]["username"] + Conf.db_password = config["database"]["password"] + Conf.pepper = config["pepper"] return Conf -