From df1a9fda8d3ebb6f40d11ffada5d0afe9b3b9034 Mon Sep 17 00:00:00 2001
From: Bady <bady@disroot.org>
Date: Thu, 31 Oct 2024 14:20:18 +0000
Subject: [PATCH] Port module to the new interface

Synapse introduced new generic modules system in v1.37.0. For more details see:
- https://element-hq.github.io/synapse/latest/upgrade.html#upgrading-to-v1390
- https://element-hq.github.io/synapse/latest/modules/porting_legacy_module.html
- https://element-hq.github.io/synapse/latest/modules/password_auth_provider_callbacks.html

This commit only updates the code needed for login. Once it is tested
and found working, the code needed to register users which used to work
with the legacy module api also needs to be updated.
---
 diaspora_auth_provider.py | 167 +++++++++++++++-----------------------
 1 file changed, 67 insertions(+), 100 deletions(-)

diff --git a/diaspora_auth_provider.py b/diaspora_auth_provider.py
index 6690a80..4a3c531 100644
--- a/diaspora_auth_provider.py
+++ b/diaspora_auth_provider.py
@@ -16,15 +16,18 @@ You should have received a copy of the GNU General Public License
 along with this program.  If not, see <http://www.gnu.org/licenses/>.
 """
 
-from twisted.internet import defer, threads
+from collections.abc import Awaitable, Callable
+from typing import Any
+
+import asyncio
+
 import synapse
+from synapse import module_api
 
 import bcrypt
 
 import logging
 
-from pkg_resources import parse_version
-
 
 __VERSION__ = "0.2.2"
 
@@ -34,53 +37,78 @@ logger = logging.getLogger(__name__)
 class DiasporaAuthProvider:
     __version__ = __VERSION__
 
-    def __init__(self, config, account_handler):
-        self.account_handler = account_handler
+    def __init__(self, config: dict, api: module_api):
         self.config = config
-        self.auth_handler = self.account_handler._auth_handler
+        self.api = api
         if self.config.engine == "mysql":
-            import pymysql
+            import aiomysql
 
-            self.module = pymysql
+            self.module = aiomysql
         elif self.config.engine == "postgres":
-            import psycopg2
+            import aiopg
 
-            self.module = psycopg2
+            self.module = aiopg
 
-    @defer.inlineCallbacks
-    def exec_query(self, query, *args):
-        self.connection = self.module.connect(
-            database=self.config.db_name,
+        api.register_password_auth_provider_callbacks(
+            auth_checkers={
+                ("m.login.password", ("password",)): self.check_pass,
+            },
+        )
+
+    async def exec_query(
+        self,
+        loop: asyncio.AbstractEventLoop,
+        query: str,
+        *args: list[str],
+    ) -> tuple[tuple[str]] | None:
+        pool = await self.module.create_pool(
+            db=self.config.db_name,
             user=self.config.db_username,
             password=self.config.db_password,
             host=self.config.db_host,
             port=self.config.db_port,
+            loop=loop,
         )
         try:
-            with self.connection:
-                with self.connection.cursor() as cursor:
-                    yield threads.deferToThread(  # Don't think this is needed, but w/e
-                        cursor.execute,
-                        query,
-                        args,
-                    )
-                    results = yield threads.deferToThread(cursor.fetchall)
-                    cursor.close()
-            defer.returnValue(results)
+            async with pool.acquire() as connection:
+                async with connection.cursor() as cursor:
+                    await cursor.execute(query, username)
+                    results = await cursor.fetchall()
+            return results
         except self.module.Error as e:
             logger.warning("Error during execution of query: {}: {}".format(query, e))
-            defer.returnValue(None)
+            return None
         finally:
-            self.connection.close()
-
-    @defer.inlineCallbacks
-    def check_password(self, user_id, password):
+            pool.close()
+            await pool.wait_closed()
+
+    async def check_pass(
+        self,
+        username: str,
+        login_type: str,
+        login_dict: "synapse.module_api.JsonDict",
+    ) -> (
+        tuple[
+            str,
+            Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None,
+        ]
+        | None
+    ):
+        if login_type != "m.login.password":
+            return None
+
+        password = login_dict.get("password")
         if not password:
-            defer.returnValue(False)
-        local_part = user_id.split(":", 1)[0][1:]
-        users = yield self.exec_query(
-            "SELECT username, encrypted_password, email FROM users WHERE username=%s",
-            local_part,
+            return None
+
+        local_part = username.split(":", 1)[0][1:] if ":" in username else username
+        loop = asyncio.get_event_loop()
+        user = await self.loop.run_until_complete(
+            self.get_user(
+                loop,
+                "SELECT username, encrypted_password, email FROM users WHERE username=%s",
+                local_part,
+            )
         )
         # user_id is @localpart:hs_bare. we only need the localpart.
         logger.info("Checking if user {} exists.".format(local_part))
@@ -90,10 +118,11 @@ class DiasporaAuthProvider:
             logger.info(
                 "User {} does not exist. Rejecting auth request".format(local_part)
             )
-            defer.returnValue(False)
-        user = users[0]
+            return None
+
         logger.debug("User {} exists. Checking password".format(local_part))
         # user exists, check if the password is correct.
+        user = users[0]
         encrypted_password = user[1]
         email = user[2]
         peppered_pass = "{}{}".format(password, self.config.pepper)
@@ -108,70 +137,8 @@ class DiasporaAuthProvider:
                     local_part
                 )
             )
-            defer.returnValue(False)
-        self.register_user(local_part, email)
-        logger.info("Confirming authentication request.")
-        defer.returnValue(True)
-
-    @defer.inlineCallbacks
-    def check_3pid_auth(self, medium, address, password):
-        logger.info(medium, address, password)
-        if medium != "email":
-            defer.returnValue(None)
-        logger.debug("Searching for email {} in diaspora db.".format(address))
-        users = yield self.exec_query(
-            "SELECT username FROM users WHERE email=%s", address
-        )
-        if not users:
-            defer.returnValue(None)
-        username = users[0][0]
-        logger.debug("Found username! {}".format(username))
-        logger.debug("Registering user {}".format(username))
-        self.register_user(username, address)
-        logger.debug("Registration complete")
-        defer.returnValue(username)
-
-    @defer.inlineCallbacks
-    def register_user(self, local_part, email):
-        if (yield self.account_handler.check_user_exists(local_part)):
-            yield self.sync_email(local_part, email)
-            defer.returnValue(local_part)
-        else:
-            user_id = yield self.account_handler.register_user(
-                localpart=local_part, emails=[email]
-            )
-            defer.returnValue(user_id)
-
-    @defer.inlineCallbacks
-    def sync_email(self, user_id, email):
-        logger.info("Syncing emails of {}".format(user_id))
-        email = email.lower()
-        store = self.account_handler._store  # Need access to datastore
-        threepids = yield store.user_get_threepids(user_id)
-        if not threepids:
-            logger.info("No 3pids found.")
-            yield self.add_email(user_id, email)
-        for threepid in threepids:
-            if not threepid["medium"] == "email":
-                logger.debug("Not an email: {}".format(str(threepid)))
-                pass
-            address = threepid["address"]
-            if address != email:
-                logger.info(
-                    "Existing 3pid doesn't match {} != {}. Deleting".format(
-                        address, email
-                    )
-                )
-                yield self.auth_handler.delete_threepid(user_id, "email", address)
-                yield self.add_email(user_id, email)
-                break
-        logger.info("Sync completed.")
-
-    @defer.inlineCallbacks
-    def add_email(self, user_id, email):
-        logger.info("Adding 3pid: {} for {}".format(email, user_id))
-        validated_at = self.account_handler._hs.get_clock().time_msec()
-        yield self.auth_handler.add_threepid(user_id, "email", email, validated_at)
+            return None
+        return (self.api.get_qualified_user_id(username), None)
 
     @staticmethod
     def parse_config(config):
-- 
GitLab