From 844f68d7ecade0b812f4b692ee3a5ba8760d9576 Mon Sep 17 00:00:00 2001 From: Shamil K Muhammed <noteness@disroot.org> Date: Wed, 1 Nov 2017 21:50:55 +0530 Subject: [PATCH] Reconnect to DB for every auth request Some error with pymysql (perhaps?) causes poddery.com to only accept 1 or 2 auth requests before failing. This should fix that. (probably) --- diaspora_auth_provider.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/diaspora_auth_provider.py b/diaspora_auth_provider.py index 3e2af8c..143b2e5 100644 --- a/diaspora_auth_provider.py +++ b/diaspora_auth_provider.py @@ -22,13 +22,13 @@ import bcrypt import logging -__VERSION__ = "0.0.5" +__VERSION__ = "0.0.6" logger = logging.getLogger(__name__) class DiasporaAuthProvider: - __version__ = "0.0.5" + __version__ = "0.0.6" def __init__(self, config, account_handler): self.account_handler = account_handler @@ -39,6 +39,11 @@ class DiasporaAuthProvider: elif self.config.engine == 'postgres': import psycopg2 self.module = psycopg2 + + @defer.inlineCallbacks + def check_password(self, user_id, password): + if not password: + defer.returnValue(False) self.connection = self.module.connect( database=self.config.db_name, user=self.config.db_username, @@ -47,10 +52,6 @@ class DiasporaAuthProvider: port=self.config.db_port ) - @defer.inlineCallbacks - def check_password(self, user_id, password): - if not password: - defer.returnValue(False) # user_id is @localpart:hs_bare. we only need the localpart. local_part = user_id.split(':', 1)[0][1:] logger.info("Checking if user {} exists.".format(local_part)) @@ -65,6 +66,7 @@ class DiasporaAuthProvider: user = yield threads.deferToThread( cursor.fetchone ) + cursor.close() # check if the user exists. if not user: logger.info("User {} does not exist. Rejecting auth request".format(local_part)) @@ -95,6 +97,8 @@ class DiasporaAuthProvider: except self.module.Error as e: logger.warning("Error during diaspora authentication: {}".format(e)) defer.returnValue(False) + finally: + self.connection.close() @staticmethod def parse_config(config): -- GitLab