From 0cbb3fa1d67ca1f620e802d7eedbc20945f8ba65 Mon Sep 17 00:00:00 2001 From: Bady <bady@disroot.org> Date: Thu, 31 Oct 2024 14:20:18 +0000 Subject: [PATCH] Port module to the new interface Synapse introduced new generic modules system in v1.37.0. For more details see: - https://element-hq.github.io/synapse/latest/upgrade.html#upgrading-to-v1390 - https://element-hq.github.io/synapse/latest/modules/porting_legacy_module.html - https://element-hq.github.io/synapse/latest/modules/password_auth_provider_callbacks.html This commit only updates the code needed for login. Once it is tested and found working, the code needed to register users which used to work with the legacy module api also needs to be updated. --- diaspora_auth_provider.py | 150 +++++++++++--------------------------- 1 file changed, 44 insertions(+), 106 deletions(-) diff --git a/diaspora_auth_provider.py b/diaspora_auth_provider.py index 6690a80..e72b62a 100644 --- a/diaspora_auth_provider.py +++ b/diaspora_auth_provider.py @@ -16,17 +16,17 @@ You should have received a copy of the GNU General Public License along with this program. If not, see <http://www.gnu.org/licenses/>. """ -from twisted.internet import defer, threads +from typing import Any, Awaitable, Callable, Optional, Tuple + import synapse +from synapse import module_api import bcrypt import logging -from pkg_resources import parse_version - -__VERSION__ = "0.2.2" +__VERSION__ = "0.4.0" logger = logging.getLogger(__name__) @@ -34,10 +34,9 @@ logger = logging.getLogger(__name__) class DiasporaAuthProvider: __version__ = __VERSION__ - def __init__(self, config, account_handler): - self.account_handler = account_handler + def __init__(self, config: dict, api: module_api): + self.api = api self.config = config - self.auth_handler = self.account_handler._auth_handler if self.config.engine == "mysql": import pymysql @@ -47,8 +46,17 @@ class DiasporaAuthProvider: self.module = psycopg2 - @defer.inlineCallbacks - def exec_query(self, query, *args): + api.register_password_auth_provider_callbacks( + auth_checkers={ + ("m.login.password", ("password",)): self.check_password, + }, + ) + + async def exec_query( + self, + query: str, + *args: str, + ) -> Optional[Tuple[Tuple[Any]]]: self.connection = self.module.connect( database=self.config.db_name, user=self.config.db_username, @@ -59,26 +67,36 @@ class DiasporaAuthProvider: try: with self.connection: with self.connection.cursor() as cursor: - yield threads.deferToThread( # Don't think this is needed, but w/e - cursor.execute, - query, - args, - ) - results = yield threads.deferToThread(cursor.fetchall) + cursor.execute(query, args) + results = cursor.fetchall() cursor.close() - defer.returnValue(results) + return results except self.module.Error as e: logger.warning("Error during execution of query: {}: {}".format(query, e)) - defer.returnValue(None) + return None finally: self.connection.close() - @defer.inlineCallbacks - def check_password(self, user_id, password): + async def check_password( + self, + username: str, + login_type: str, + login_dict: "synapse.module_api.JsonDict", + ) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], + ] + ]: + if login_type != "m.login.password": + return None + + password = login_dict.get("password") if not password: - defer.returnValue(False) - local_part = user_id.split(":", 1)[0][1:] - users = yield self.exec_query( + return None + + local_part = username.split(":", 1)[0][1:] if ":" in username else username + users = await self.exec_query( "SELECT username, encrypted_password, email FROM users WHERE username=%s", local_part, ) @@ -90,7 +108,8 @@ class DiasporaAuthProvider: logger.info( "User {} does not exist. Rejecting auth request".format(local_part) ) - defer.returnValue(False) + return None + user = users[0] logger.debug("User {} exists. Checking password".format(local_part)) # user exists, check if the password is correct. @@ -108,86 +127,5 @@ class DiasporaAuthProvider: local_part ) ) - defer.returnValue(False) - self.register_user(local_part, email) - logger.info("Confirming authentication request.") - defer.returnValue(True) - - @defer.inlineCallbacks - def check_3pid_auth(self, medium, address, password): - logger.info(medium, address, password) - if medium != "email": - defer.returnValue(None) - logger.debug("Searching for email {} in diaspora db.".format(address)) - users = yield self.exec_query( - "SELECT username FROM users WHERE email=%s", address - ) - if not users: - defer.returnValue(None) - username = users[0][0] - logger.debug("Found username! {}".format(username)) - logger.debug("Registering user {}".format(username)) - self.register_user(username, address) - logger.debug("Registration complete") - defer.returnValue(username) - - @defer.inlineCallbacks - def register_user(self, local_part, email): - if (yield self.account_handler.check_user_exists(local_part)): - yield self.sync_email(local_part, email) - defer.returnValue(local_part) - else: - user_id = yield self.account_handler.register_user( - localpart=local_part, emails=[email] - ) - defer.returnValue(user_id) - - @defer.inlineCallbacks - def sync_email(self, user_id, email): - logger.info("Syncing emails of {}".format(user_id)) - email = email.lower() - store = self.account_handler._store # Need access to datastore - threepids = yield store.user_get_threepids(user_id) - if not threepids: - logger.info("No 3pids found.") - yield self.add_email(user_id, email) - for threepid in threepids: - if not threepid["medium"] == "email": - logger.debug("Not an email: {}".format(str(threepid))) - pass - address = threepid["address"] - if address != email: - logger.info( - "Existing 3pid doesn't match {} != {}. Deleting".format( - address, email - ) - ) - yield self.auth_handler.delete_threepid(user_id, "email", address) - yield self.add_email(user_id, email) - break - logger.info("Sync completed.") - - @defer.inlineCallbacks - def add_email(self, user_id, email): - logger.info("Adding 3pid: {} for {}".format(email, user_id)) - validated_at = self.account_handler._hs.get_clock().time_msec() - yield self.auth_handler.add_threepid(user_id, "email", email, validated_at) - - @staticmethod - def parse_config(config): - class _Conf: - pass - - Conf = _Conf() - Conf.engine = config["database"]["engine"] - Conf.db_name = ( - "diaspora_production" - if not config["database"]["name"] - else config["database"]["name"] - ) - Conf.db_host = config["database"]["host"] - Conf.db_port = config["database"]["port"] - Conf.db_username = config["database"]["username"] - Conf.db_password = config["database"]["password"] - Conf.pepper = config["pepper"] - return Conf + return None + return (self.api.get_qualified_user_id(username), None) -- GitLab