From 0cbb3fa1d67ca1f620e802d7eedbc20945f8ba65 Mon Sep 17 00:00:00 2001
From: Bady <bady@disroot.org>
Date: Thu, 31 Oct 2024 14:20:18 +0000
Subject: [PATCH] Port module to the new interface

Synapse introduced new generic modules system in v1.37.0. For more details see:
- https://element-hq.github.io/synapse/latest/upgrade.html#upgrading-to-v1390
- https://element-hq.github.io/synapse/latest/modules/porting_legacy_module.html
- https://element-hq.github.io/synapse/latest/modules/password_auth_provider_callbacks.html

This commit only updates the code needed for login. Once it is tested
and found working, the code needed to register users which used to work
with the legacy module api also needs to be updated.
---
 diaspora_auth_provider.py | 150 +++++++++++---------------------------
 1 file changed, 44 insertions(+), 106 deletions(-)

diff --git a/diaspora_auth_provider.py b/diaspora_auth_provider.py
index 6690a80..e72b62a 100644
--- a/diaspora_auth_provider.py
+++ b/diaspora_auth_provider.py
@@ -16,17 +16,17 @@ You should have received a copy of the GNU General Public License
 along with this program.  If not, see <http://www.gnu.org/licenses/>.
 """
 
-from twisted.internet import defer, threads
+from typing import Any, Awaitable, Callable, Optional, Tuple
+
 import synapse
+from synapse import module_api
 
 import bcrypt
 
 import logging
 
-from pkg_resources import parse_version
-
 
-__VERSION__ = "0.2.2"
+__VERSION__ = "0.4.0"
 
 logger = logging.getLogger(__name__)
 
@@ -34,10 +34,9 @@ logger = logging.getLogger(__name__)
 class DiasporaAuthProvider:
     __version__ = __VERSION__
 
-    def __init__(self, config, account_handler):
-        self.account_handler = account_handler
+    def __init__(self, config: dict, api: module_api):
+        self.api = api
         self.config = config
-        self.auth_handler = self.account_handler._auth_handler
         if self.config.engine == "mysql":
             import pymysql
 
@@ -47,8 +46,17 @@ class DiasporaAuthProvider:
 
             self.module = psycopg2
 
-    @defer.inlineCallbacks
-    def exec_query(self, query, *args):
+        api.register_password_auth_provider_callbacks(
+            auth_checkers={
+                ("m.login.password", ("password",)): self.check_password,
+            },
+        )
+
+    async def exec_query(
+        self,
+        query: str,
+        *args: str,
+    ) -> Optional[Tuple[Tuple[Any]]]:
         self.connection = self.module.connect(
             database=self.config.db_name,
             user=self.config.db_username,
@@ -59,26 +67,36 @@ class DiasporaAuthProvider:
         try:
             with self.connection:
                 with self.connection.cursor() as cursor:
-                    yield threads.deferToThread(  # Don't think this is needed, but w/e
-                        cursor.execute,
-                        query,
-                        args,
-                    )
-                    results = yield threads.deferToThread(cursor.fetchall)
+                    cursor.execute(query, args)
+                    results = cursor.fetchall()
                     cursor.close()
-            defer.returnValue(results)
+                    return results
         except self.module.Error as e:
             logger.warning("Error during execution of query: {}: {}".format(query, e))
-            defer.returnValue(None)
+            return None
         finally:
             self.connection.close()
 
-    @defer.inlineCallbacks
-    def check_password(self, user_id, password):
+    async def check_password(
+        self,
+        username: str,
+        login_type: str,
+        login_dict: "synapse.module_api.JsonDict",
+    ) -> Optional[
+        Tuple[
+            str,
+            Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
+        ]
+    ]:
+        if login_type != "m.login.password":
+            return None
+
+        password = login_dict.get("password")
         if not password:
-            defer.returnValue(False)
-        local_part = user_id.split(":", 1)[0][1:]
-        users = yield self.exec_query(
+            return None
+
+        local_part = username.split(":", 1)[0][1:] if ":" in username else username
+        users = await self.exec_query(
             "SELECT username, encrypted_password, email FROM users WHERE username=%s",
             local_part,
         )
@@ -90,7 +108,8 @@ class DiasporaAuthProvider:
             logger.info(
                 "User {} does not exist. Rejecting auth request".format(local_part)
             )
-            defer.returnValue(False)
+            return None
+
         user = users[0]
         logger.debug("User {} exists. Checking password".format(local_part))
         # user exists, check if the password is correct.
@@ -108,86 +127,5 @@ class DiasporaAuthProvider:
                     local_part
                 )
             )
-            defer.returnValue(False)
-        self.register_user(local_part, email)
-        logger.info("Confirming authentication request.")
-        defer.returnValue(True)
-
-    @defer.inlineCallbacks
-    def check_3pid_auth(self, medium, address, password):
-        logger.info(medium, address, password)
-        if medium != "email":
-            defer.returnValue(None)
-        logger.debug("Searching for email {} in diaspora db.".format(address))
-        users = yield self.exec_query(
-            "SELECT username FROM users WHERE email=%s", address
-        )
-        if not users:
-            defer.returnValue(None)
-        username = users[0][0]
-        logger.debug("Found username! {}".format(username))
-        logger.debug("Registering user {}".format(username))
-        self.register_user(username, address)
-        logger.debug("Registration complete")
-        defer.returnValue(username)
-
-    @defer.inlineCallbacks
-    def register_user(self, local_part, email):
-        if (yield self.account_handler.check_user_exists(local_part)):
-            yield self.sync_email(local_part, email)
-            defer.returnValue(local_part)
-        else:
-            user_id = yield self.account_handler.register_user(
-                localpart=local_part, emails=[email]
-            )
-            defer.returnValue(user_id)
-
-    @defer.inlineCallbacks
-    def sync_email(self, user_id, email):
-        logger.info("Syncing emails of {}".format(user_id))
-        email = email.lower()
-        store = self.account_handler._store  # Need access to datastore
-        threepids = yield store.user_get_threepids(user_id)
-        if not threepids:
-            logger.info("No 3pids found.")
-            yield self.add_email(user_id, email)
-        for threepid in threepids:
-            if not threepid["medium"] == "email":
-                logger.debug("Not an email: {}".format(str(threepid)))
-                pass
-            address = threepid["address"]
-            if address != email:
-                logger.info(
-                    "Existing 3pid doesn't match {} != {}. Deleting".format(
-                        address, email
-                    )
-                )
-                yield self.auth_handler.delete_threepid(user_id, "email", address)
-                yield self.add_email(user_id, email)
-                break
-        logger.info("Sync completed.")
-
-    @defer.inlineCallbacks
-    def add_email(self, user_id, email):
-        logger.info("Adding 3pid: {} for {}".format(email, user_id))
-        validated_at = self.account_handler._hs.get_clock().time_msec()
-        yield self.auth_handler.add_threepid(user_id, "email", email, validated_at)
-
-    @staticmethod
-    def parse_config(config):
-        class _Conf:
-            pass
-
-        Conf = _Conf()
-        Conf.engine = config["database"]["engine"]
-        Conf.db_name = (
-            "diaspora_production"
-            if not config["database"]["name"]
-            else config["database"]["name"]
-        )
-        Conf.db_host = config["database"]["host"]
-        Conf.db_port = config["database"]["port"]
-        Conf.db_username = config["database"]["username"]
-        Conf.db_password = config["database"]["password"]
-        Conf.pepper = config["pepper"]
-        return Conf
+            return None
+        return (self.api.get_qualified_user_id(username), None)
-- 
GitLab