diff --git a/diaspora_auth_provider.py b/diaspora_auth_provider.py index 6690a80028d97aa183620ba5cc5ec622230bb7e4..4a3c531fb43abd43a9bd0f6aa37c414b4097d868 100644 --- a/diaspora_auth_provider.py +++ b/diaspora_auth_provider.py @@ -16,15 +16,18 @@ You should have received a copy of the GNU General Public License along with this program. If not, see <http://www.gnu.org/licenses/>. """ -from twisted.internet import defer, threads +from collections.abc import Awaitable, Callable +from typing import Any + +import asyncio + import synapse +from synapse import module_api import bcrypt import logging -from pkg_resources import parse_version - __VERSION__ = "0.2.2" @@ -34,53 +37,78 @@ logger = logging.getLogger(__name__) class DiasporaAuthProvider: __version__ = __VERSION__ - def __init__(self, config, account_handler): - self.account_handler = account_handler + def __init__(self, config: dict, api: module_api): self.config = config - self.auth_handler = self.account_handler._auth_handler + self.api = api if self.config.engine == "mysql": - import pymysql + import aiomysql - self.module = pymysql + self.module = aiomysql elif self.config.engine == "postgres": - import psycopg2 + import aiopg - self.module = psycopg2 + self.module = aiopg - @defer.inlineCallbacks - def exec_query(self, query, *args): - self.connection = self.module.connect( - database=self.config.db_name, + api.register_password_auth_provider_callbacks( + auth_checkers={ + ("m.login.password", ("password",)): self.check_pass, + }, + ) + + async def exec_query( + self, + loop: asyncio.AbstractEventLoop, + query: str, + *args: list[str], + ) -> tuple[tuple[str]] | None: + pool = await self.module.create_pool( + db=self.config.db_name, user=self.config.db_username, password=self.config.db_password, host=self.config.db_host, port=self.config.db_port, + loop=loop, ) try: - with self.connection: - with self.connection.cursor() as cursor: - yield threads.deferToThread( # Don't think this is needed, but w/e - cursor.execute, - query, - args, - ) - results = yield threads.deferToThread(cursor.fetchall) - cursor.close() - defer.returnValue(results) + async with pool.acquire() as connection: + async with connection.cursor() as cursor: + await cursor.execute(query, username) + results = await cursor.fetchall() + return results except self.module.Error as e: logger.warning("Error during execution of query: {}: {}".format(query, e)) - defer.returnValue(None) + return None finally: - self.connection.close() - - @defer.inlineCallbacks - def check_password(self, user_id, password): + pool.close() + await pool.wait_closed() + + async def check_pass( + self, + username: str, + login_type: str, + login_dict: "synapse.module_api.JsonDict", + ) -> ( + tuple[ + str, + Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None, + ] + | None + ): + if login_type != "m.login.password": + return None + + password = login_dict.get("password") if not password: - defer.returnValue(False) - local_part = user_id.split(":", 1)[0][1:] - users = yield self.exec_query( - "SELECT username, encrypted_password, email FROM users WHERE username=%s", - local_part, + return None + + local_part = username.split(":", 1)[0][1:] if ":" in username else username + loop = asyncio.get_event_loop() + user = await self.loop.run_until_complete( + self.get_user( + loop, + "SELECT username, encrypted_password, email FROM users WHERE username=%s", + local_part, + ) ) # user_id is @localpart:hs_bare. we only need the localpart. logger.info("Checking if user {} exists.".format(local_part)) @@ -90,10 +118,11 @@ class DiasporaAuthProvider: logger.info( "User {} does not exist. Rejecting auth request".format(local_part) ) - defer.returnValue(False) - user = users[0] + return None + logger.debug("User {} exists. Checking password".format(local_part)) # user exists, check if the password is correct. + user = users[0] encrypted_password = user[1] email = user[2] peppered_pass = "{}{}".format(password, self.config.pepper) @@ -108,70 +137,8 @@ class DiasporaAuthProvider: local_part ) ) - defer.returnValue(False) - self.register_user(local_part, email) - logger.info("Confirming authentication request.") - defer.returnValue(True) - - @defer.inlineCallbacks - def check_3pid_auth(self, medium, address, password): - logger.info(medium, address, password) - if medium != "email": - defer.returnValue(None) - logger.debug("Searching for email {} in diaspora db.".format(address)) - users = yield self.exec_query( - "SELECT username FROM users WHERE email=%s", address - ) - if not users: - defer.returnValue(None) - username = users[0][0] - logger.debug("Found username! {}".format(username)) - logger.debug("Registering user {}".format(username)) - self.register_user(username, address) - logger.debug("Registration complete") - defer.returnValue(username) - - @defer.inlineCallbacks - def register_user(self, local_part, email): - if (yield self.account_handler.check_user_exists(local_part)): - yield self.sync_email(local_part, email) - defer.returnValue(local_part) - else: - user_id = yield self.account_handler.register_user( - localpart=local_part, emails=[email] - ) - defer.returnValue(user_id) - - @defer.inlineCallbacks - def sync_email(self, user_id, email): - logger.info("Syncing emails of {}".format(user_id)) - email = email.lower() - store = self.account_handler._store # Need access to datastore - threepids = yield store.user_get_threepids(user_id) - if not threepids: - logger.info("No 3pids found.") - yield self.add_email(user_id, email) - for threepid in threepids: - if not threepid["medium"] == "email": - logger.debug("Not an email: {}".format(str(threepid))) - pass - address = threepid["address"] - if address != email: - logger.info( - "Existing 3pid doesn't match {} != {}. Deleting".format( - address, email - ) - ) - yield self.auth_handler.delete_threepid(user_id, "email", address) - yield self.add_email(user_id, email) - break - logger.info("Sync completed.") - - @defer.inlineCallbacks - def add_email(self, user_id, email): - logger.info("Adding 3pid: {} for {}".format(email, user_id)) - validated_at = self.account_handler._hs.get_clock().time_msec() - yield self.auth_handler.add_threepid(user_id, "email", email, validated_at) + return None + return (self.api.get_qualified_user_id(username), None) @staticmethod def parse_config(config):