Skip to content
Snippets Groups Projects
Unverified Commit 844f68d7 authored by Shamil K Muhammed's avatar Shamil K Muhammed
Browse files

Reconnect to DB for every auth request

Some error with pymysql (perhaps?) causes poddery.com
to only accept 1 or 2 auth requests before failing.
This should fix that. (probably)
parent 3957ea7d
Branches
Tags
No related merge requests found
...@@ -22,13 +22,13 @@ import bcrypt ...@@ -22,13 +22,13 @@ import bcrypt
import logging import logging
__VERSION__ = "0.0.5" __VERSION__ = "0.0.6"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DiasporaAuthProvider: class DiasporaAuthProvider:
__version__ = "0.0.5" __version__ = "0.0.6"
def __init__(self, config, account_handler): def __init__(self, config, account_handler):
self.account_handler = account_handler self.account_handler = account_handler
...@@ -39,6 +39,11 @@ class DiasporaAuthProvider: ...@@ -39,6 +39,11 @@ class DiasporaAuthProvider:
elif self.config.engine == 'postgres': elif self.config.engine == 'postgres':
import psycopg2 import psycopg2
self.module = psycopg2 self.module = psycopg2
@defer.inlineCallbacks
def check_password(self, user_id, password):
if not password:
defer.returnValue(False)
self.connection = self.module.connect( self.connection = self.module.connect(
database=self.config.db_name, database=self.config.db_name,
user=self.config.db_username, user=self.config.db_username,
...@@ -47,10 +52,6 @@ class DiasporaAuthProvider: ...@@ -47,10 +52,6 @@ class DiasporaAuthProvider:
port=self.config.db_port port=self.config.db_port
) )
@defer.inlineCallbacks
def check_password(self, user_id, password):
if not password:
defer.returnValue(False)
# user_id is @localpart:hs_bare. we only need the localpart. # user_id is @localpart:hs_bare. we only need the localpart.
local_part = user_id.split(':', 1)[0][1:] local_part = user_id.split(':', 1)[0][1:]
logger.info("Checking if user {} exists.".format(local_part)) logger.info("Checking if user {} exists.".format(local_part))
...@@ -65,6 +66,7 @@ class DiasporaAuthProvider: ...@@ -65,6 +66,7 @@ class DiasporaAuthProvider:
user = yield threads.deferToThread( user = yield threads.deferToThread(
cursor.fetchone cursor.fetchone
) )
cursor.close()
# check if the user exists. # check if the user exists.
if not user: if not user:
logger.info("User {} does not exist. Rejecting auth request".format(local_part)) logger.info("User {} does not exist. Rejecting auth request".format(local_part))
...@@ -95,6 +97,8 @@ class DiasporaAuthProvider: ...@@ -95,6 +97,8 @@ class DiasporaAuthProvider:
except self.module.Error as e: except self.module.Error as e:
logger.warning("Error during diaspora authentication: {}".format(e)) logger.warning("Error during diaspora authentication: {}".format(e))
defer.returnValue(False) defer.returnValue(False)
finally:
self.connection.close()
@staticmethod @staticmethod
def parse_config(config): def parse_config(config):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment