From a28ab2d5bfeba126a8483b19db6d08003f92c286 Mon Sep 17 00:00:00 2001 From: Shamil K <nessessery129@gmail.com> Date: Wed, 24 Jun 2020 21:42:02 +0530 Subject: [PATCH] Update the provider --- diaspora_auth_provider.py | 118 +++++++++++++++++++++++--------------- 1 file changed, 71 insertions(+), 47 deletions(-) diff --git a/diaspora_auth_provider.py b/diaspora_auth_provider.py index 3811939..4f17281 100644 --- a/diaspora_auth_provider.py +++ b/diaspora_auth_provider.py @@ -17,18 +17,22 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. """ from twisted.internet import defer, threads +import synapse import bcrypt import logging -__VERSION__ = "0.1.1" +from pkg_resources import parse_version + + +__VERSION__ = "0.2.1" logger = logging.getLogger(__name__) class DiasporaAuthProvider: - __version__ = "0.1.1" + __version__ = __VERSION__ def __init__(self, config, account_handler): self.account_handler = account_handler @@ -42,9 +46,7 @@ class DiasporaAuthProvider: self.module = psycopg2 @defer.inlineCallbacks - def check_password(self, user_id, password): - if not password: - defer.returnValue(False) + def exec_query(self, query, *args): self.connection = self.module.connect( database=self.config.db_name, user=self.config.db_username, @@ -52,58 +54,80 @@ class DiasporaAuthProvider: host=self.config.db_host, port=self.config.db_port ) - - # user_id is @localpart:hs_bare. we only need the localpart. - local_part = user_id.split(':', 1)[0][1:] - logger.info("Checking if user {} exists.".format(local_part)) try: with self.connection: with self.connection.cursor() as cursor: yield threads.deferToThread( # Don't think this is needed, but w/e - cursor.execute, - "SELECT username, encrypted_password, email FROM users WHERE username=%s", - (local_part,) + cursor.execute, query, args, ) - user = yield threads.deferToThread( - cursor.fetchone + results = yield threads.deferToThread( + cursor.fetchall ) cursor.close() - # check if the user exists. - if not user: - logger.info("User {} does not exist. Rejecting auth request".format(local_part)) - defer.returnValue(False) - logger.debug("User {} exists. Checking password".format(local_part)) - # user exists, check if the password is correct. - encrypted_password = user[1] - email = user[2] - peppered_pass = u"{}{}".format(password, self.config.pepper) - if not (bcrypt.hashpw(peppered_pass.encode('utf8'), encrypted_password.encode('utf8')) - == encrypted_password.encode('utf8')): - logger.info("Password given for {} is wrong. Rejecting auth request.".format(local_part)) - defer.returnValue(False) - # Ok, user's password is correct. check if the user exists in the homeserver db. - # and create it if doesn't exist. - if (yield self.account_handler.check_user_exists(user_id)): - logger.info("User {} does exist in synapse db. Authentication complete".format(local_part)) - yield self.sync_email(user_id, email) - defer.returnValue(True) - # User not in synapse db. need to create it. - logger.info("User {} does not exist in synapse db. creating it.".format(local_part)) - user_id, access_token = ( - yield self.account_handler.register(localpart=local_part) - ) - logger.info( - "Registration based on diaspora complete. UserID: {}.".format(user_id) - ) - logger.info("Confirming authentication request.") - yield self.sync_email(user_id, email) - defer.returnValue(True) + defer.returnValue(results) except self.module.Error as e: - logger.warning("Error during diaspora authentication: {}".format(e)) - defer.returnValue(False) + logger.warning("Error during execution of query: {}: {}".format(query, e)) + defer.returnValue(None) finally: self.connection.close() + + + @defer.inlineCallbacks + def check_password(self, user_id, password): + if not password: + defer.returnValue(False) + local_part = user_id.split(':', 1)[0][1:] + users = yield self.exec_query( + "SELECT username, encrypted_password, email FROM users WHERE username=%s", local_part + ) + # user_id is @localpart:hs_bare. we only need the localpart. + logger.info("Checking if user {} exists.".format(local_part)) + + # check if the user exists. + if not users: + logger.info("User {} does not exist. Rejecting auth request".format(local_part)) + defer.returnValue(False) + user = users[0] + logger.debug("User {} exists. Checking password".format(local_part)) + # user exists, check if the password is correct. + encrypted_password = user[1] + email = user[2] + peppered_pass = u"{}{}".format(password, self.config.pepper) + if not (bcrypt.hashpw(peppered_pass.encode('utf8'), encrypted_password.encode('utf8')) + == encrypted_password.encode('utf8')): + logger.info("Password given for {} is wrong. Rejecting auth request.".format(local_part)) + defer.returnValue(False) + yield self.register_user(local_part, email) + logger.info("Confirming authentication request.") + defer.returnValue(True) + + @defer.inlineCallbacks + def check_3pid_auth(self, medium, address, password): + logger.info(medium, address, password) + if medium != "email": + defer.returnValue(None) + logger.debug("Searching for email {} in diaspora db.".format(address)) + users = yield self.exec_query("SELECT username FROM users WHERE email=%s", address) + if not users: + defer.returnValue(None) + username = users[0][0] + logger.debug("Found username! {}".format(username)) + logger.debug("Registering user {}".format(username)) + self.register_user(username, address) + logger.debug("Registration complete") + defer.returnValue(username) + + @defer.inlineCallbacks + def register_user(self, local_part, email): + if (yield self.account_handler.check_user_exists(local_part)): + yield self.sync_email(local_part, email) + defer.returnValue(local_part) + user_id = ( + yield self.account_handler.register_user(localpart=local_part, emails=[email]) + ) + defer.returnValue(user_id) + @defer.inlineCallbacks def sync_email(self, user_id, email): logger.info("Syncing emails of {}".format(user_id)) @@ -128,7 +152,7 @@ class DiasporaAuthProvider: @defer.inlineCallbacks def add_email(self, user_id, email): logger.info("Adding 3pid: {} for {}".format(email, user_id)) - validated_at = self.account_handler.hs.get_clock().time_msec() + validated_at = self.account_handler._hs.get_clock().time_msec() yield self.auth_handler.add_threepid(user_id, 'email', email, validated_at) @staticmethod -- GitLab