From 3e219f5b59a4a04e4246d32753b95c8a5ee0ce30 Mon Sep 17 00:00:00 2001
From: Bady <bady@disroot.org>
Date: Fri, 1 Nov 2024 16:32:12 +0000
Subject: [PATCH] Port module to the new interface

Synapse introduced new generic modules system in v1.37.0. For more details see:
- https://element-hq.github.io/synapse/latest/upgrade.html#upgrading-to-v1390
- https://element-hq.github.io/synapse/latest/modules/porting_legacy_module.html
- https://element-hq.github.io/synapse/latest/modules/password_auth_provider_callbacks.html
---
 diaspora_auth_provider.py | 106 ++++++++++++++++++++++++--------------
 1 file changed, 67 insertions(+), 39 deletions(-)

diff --git a/diaspora_auth_provider.py b/diaspora_auth_provider.py
index d845d3b..0bf5312 100644
--- a/diaspora_auth_provider.py
+++ b/diaspora_auth_provider.py
@@ -16,15 +16,18 @@ You should have received a copy of the GNU General Public License
 along with this program.  If not, see <http://www.gnu.org/licenses/>.
 """
 
-from twisted.internet import defer, threads
+from collections.abc import Awaitable, Callable
+from typing import Any
+
+import asyncio
+
 import synapse
+from synapse import module_api
 
 import bcrypt
 
 import logging
 
-from pkg_resources import parse_version
-
 
 __VERSION__ = "0.2.2"
 
@@ -34,53 +37,78 @@ logger = logging.getLogger(__name__)
 class DiasporaAuthProvider:
     __version__ = __VERSION__
 
-    def __init__(self, config, account_handler):
-        self.account_handler = account_handler
+    def __init__(self, config: dict, api: module_api):
         self.config = config
-        self.auth_handler = self.account_handler._auth_handler
+        self.api = api
         if self.config.engine == "mysql":
-            import pymysql
+            import aiomysql
 
-            self.module = pymysql
+            self.module = aiomysql
         elif self.config.engine == "postgres":
-            import psycopg2
+            import aiopg
 
-            self.module = psycopg2
+            self.module = aiopg
+
+        api.register_password_auth_provider_callbacks(
+            auth_checkers={
+                ("m.login.password", ("password",)): self.check_pass,
+            },
+        )
 
-    @defer.inlineCallbacks
-    def exec_query(self, query, *args):
-        self.connection = self.module.connect(
-            database=self.config.db_name,
+    async def exec_query(
+        self,
+        loop: asyncio.AbstractEventLoop,
+        query: str,
+        *args: list[str],
+    ) -> tuple[tuple[str]] | None:
+        pool = await self.module.create_pool(
+            db=self.config.db_name,
             user=self.config.db_username,
             password=self.config.db_password,
             host=self.config.db_host,
             port=self.config.db_port,
+            loop=loop,
         )
         try:
-            with self.connection:
-                with self.connection.cursor() as cursor:
-                    yield threads.deferToThread(  # Don't think this is needed, but w/e
-                        cursor.execute,
-                        query,
-                        args,
-                    )
-                    results = yield threads.deferToThread(cursor.fetchall)
-                    cursor.close()
-            defer.returnValue(results)
+            async with pool.acquire() as connection:
+                async with connection.cursor() as cursor:
+                    await cursor.execute(query, args)
+                    results = cursor.fetchall()
+            return results
         except self.module.Error as e:
             logger.warning("Error during execution of query: {}: {}".format(query, e))
-            defer.returnValue(None)
+            return None
         finally:
-            self.connection.close()
-
-    @defer.inlineCallbacks
-    def check_password(self, user_id, password):
+            pool.close()
+            await pool.wait_closed()
+
+    async def check_pass(
+        self,
+        username: str,
+        login_type: str,
+        login_dict: "synapse.module_api.JsonDict",
+    ) -> (
+        tuple[
+            str,
+            Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None,
+        ]
+        | None
+    ):
+        if login_type != "m.login.password":
+            return None
+
+        password = login_dict.get("password")
         if not password:
-            defer.returnValue(False)
-        local_part = user_id.split(":", 1)[0][1:]
-        users = yield self.exec_query(
-            "SELECT username, encrypted_password, email FROM users WHERE username=%s",
-            local_part,
+            return None
+
+        local_part = username.split(":", 1)[0][1:] if ":" in username else username
+        loop = asyncio.get_event_loop()
+        users = await loop.run_until_complete(
+            self.exec_query(
+                loop,
+                "SELECT username, encrypted_password, email FROM users WHERE username=%s",
+                local_part,
+            )
         )
         # user_id is @localpart:hs_bare. we only need the localpart.
         logger.info("Checking if user {} exists.".format(local_part))
@@ -90,10 +118,11 @@ class DiasporaAuthProvider:
             logger.info(
                 "User {} does not exist. Rejecting auth request".format(local_part)
             )
-            defer.returnValue(False)
-        user = users[0]
+            return None
+
         logger.debug("User {} exists. Checking password".format(local_part))
         # user exists, check if the password is correct.
+        user = users[0]
         encrypted_password = user[1]
         email = user[2]
         peppered_pass = "{}{}".format(password, self.config.pepper)
@@ -108,9 +137,8 @@ class DiasporaAuthProvider:
                     local_part
                 )
             )
-            defer.returnValue(False)
-        logger.info("Confirming authentication request.")
-        defer.returnValue(True)
+            return None
+        return (self.api.get_qualified_user_id(username), None)
 
     @staticmethod
     def parse_config(config):
-- 
GitLab