From 3e219f5b59a4a04e4246d32753b95c8a5ee0ce30 Mon Sep 17 00:00:00 2001 From: Bady <bady@disroot.org> Date: Fri, 1 Nov 2024 16:32:12 +0000 Subject: [PATCH] Port module to the new interface Synapse introduced new generic modules system in v1.37.0. For more details see: - https://element-hq.github.io/synapse/latest/upgrade.html#upgrading-to-v1390 - https://element-hq.github.io/synapse/latest/modules/porting_legacy_module.html - https://element-hq.github.io/synapse/latest/modules/password_auth_provider_callbacks.html --- diaspora_auth_provider.py | 106 ++++++++++++++++++++++++-------------- 1 file changed, 67 insertions(+), 39 deletions(-) diff --git a/diaspora_auth_provider.py b/diaspora_auth_provider.py index d845d3b..0bf5312 100644 --- a/diaspora_auth_provider.py +++ b/diaspora_auth_provider.py @@ -16,15 +16,18 @@ You should have received a copy of the GNU General Public License along with this program. If not, see <http://www.gnu.org/licenses/>. """ -from twisted.internet import defer, threads +from collections.abc import Awaitable, Callable +from typing import Any + +import asyncio + import synapse +from synapse import module_api import bcrypt import logging -from pkg_resources import parse_version - __VERSION__ = "0.2.2" @@ -34,53 +37,78 @@ logger = logging.getLogger(__name__) class DiasporaAuthProvider: __version__ = __VERSION__ - def __init__(self, config, account_handler): - self.account_handler = account_handler + def __init__(self, config: dict, api: module_api): self.config = config - self.auth_handler = self.account_handler._auth_handler + self.api = api if self.config.engine == "mysql": - import pymysql + import aiomysql - self.module = pymysql + self.module = aiomysql elif self.config.engine == "postgres": - import psycopg2 + import aiopg - self.module = psycopg2 + self.module = aiopg + + api.register_password_auth_provider_callbacks( + auth_checkers={ + ("m.login.password", ("password",)): self.check_pass, + }, + ) - @defer.inlineCallbacks - def exec_query(self, query, *args): - self.connection = self.module.connect( - database=self.config.db_name, + async def exec_query( + self, + loop: asyncio.AbstractEventLoop, + query: str, + *args: list[str], + ) -> tuple[tuple[str]] | None: + pool = await self.module.create_pool( + db=self.config.db_name, user=self.config.db_username, password=self.config.db_password, host=self.config.db_host, port=self.config.db_port, + loop=loop, ) try: - with self.connection: - with self.connection.cursor() as cursor: - yield threads.deferToThread( # Don't think this is needed, but w/e - cursor.execute, - query, - args, - ) - results = yield threads.deferToThread(cursor.fetchall) - cursor.close() - defer.returnValue(results) + async with pool.acquire() as connection: + async with connection.cursor() as cursor: + await cursor.execute(query, args) + results = cursor.fetchall() + return results except self.module.Error as e: logger.warning("Error during execution of query: {}: {}".format(query, e)) - defer.returnValue(None) + return None finally: - self.connection.close() - - @defer.inlineCallbacks - def check_password(self, user_id, password): + pool.close() + await pool.wait_closed() + + async def check_pass( + self, + username: str, + login_type: str, + login_dict: "synapse.module_api.JsonDict", + ) -> ( + tuple[ + str, + Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None, + ] + | None + ): + if login_type != "m.login.password": + return None + + password = login_dict.get("password") if not password: - defer.returnValue(False) - local_part = user_id.split(":", 1)[0][1:] - users = yield self.exec_query( - "SELECT username, encrypted_password, email FROM users WHERE username=%s", - local_part, + return None + + local_part = username.split(":", 1)[0][1:] if ":" in username else username + loop = asyncio.get_event_loop() + users = await loop.run_until_complete( + self.exec_query( + loop, + "SELECT username, encrypted_password, email FROM users WHERE username=%s", + local_part, + ) ) # user_id is @localpart:hs_bare. we only need the localpart. logger.info("Checking if user {} exists.".format(local_part)) @@ -90,10 +118,11 @@ class DiasporaAuthProvider: logger.info( "User {} does not exist. Rejecting auth request".format(local_part) ) - defer.returnValue(False) - user = users[0] + return None + logger.debug("User {} exists. Checking password".format(local_part)) # user exists, check if the password is correct. + user = users[0] encrypted_password = user[1] email = user[2] peppered_pass = "{}{}".format(password, self.config.pepper) @@ -108,9 +137,8 @@ class DiasporaAuthProvider: local_part ) ) - defer.returnValue(False) - logger.info("Confirming authentication request.") - defer.returnValue(True) + return None + return (self.api.get_qualified_user_id(username), None) @staticmethod def parse_config(config): -- GitLab