From 3da689730333bc46bc4665b337f949c56d936666 Mon Sep 17 00:00:00 2001
From: Bady <bady@disroot.org>
Date: Thu, 31 Oct 2024 14:20:18 +0000
Subject: [PATCH] Port module to the new interface

Synapse introduced new generic modules system in v1.37.0. For more details see:
- https://element-hq.github.io/synapse/latest/upgrade.html#upgrading-to-v1390
- https://element-hq.github.io/synapse/latest/modules/porting_legacy_module.html
- https://element-hq.github.io/synapse/latest/modules/password_auth_provider_callbacks.html
---
 diaspora_auth_provider.py | 150 +++++++++++---------------------------
 1 file changed, 44 insertions(+), 106 deletions(-)

diff --git a/diaspora_auth_provider.py b/diaspora_auth_provider.py
index 6690a80..e72b62a 100644
--- a/diaspora_auth_provider.py
+++ b/diaspora_auth_provider.py
@@ -16,17 +16,17 @@ You should have received a copy of the GNU General Public License
 along with this program.  If not, see <http://www.gnu.org/licenses/>.
 """
 
-from twisted.internet import defer, threads
+from typing import Any, Awaitable, Callable, Optional, Tuple
+
 import synapse
+from synapse import module_api
 
 import bcrypt
 
 import logging
 
-from pkg_resources import parse_version
-
 
-__VERSION__ = "0.2.2"
+__VERSION__ = "0.4.0"
 
 logger = logging.getLogger(__name__)
 
@@ -34,10 +34,9 @@ logger = logging.getLogger(__name__)
 class DiasporaAuthProvider:
     __version__ = __VERSION__
 
-    def __init__(self, config, account_handler):
-        self.account_handler = account_handler
+    def __init__(self, config: dict, api: module_api):
+        self.api = api
         self.config = config
-        self.auth_handler = self.account_handler._auth_handler
         if self.config.engine == "mysql":
             import pymysql
 
@@ -47,8 +46,17 @@ class DiasporaAuthProvider:
 
             self.module = psycopg2
 
-    @defer.inlineCallbacks
-    def exec_query(self, query, *args):
+        api.register_password_auth_provider_callbacks(
+            auth_checkers={
+                ("m.login.password", ("password",)): self.check_password,
+            },
+        )
+
+    async def exec_query(
+        self,
+        query: str,
+        *args: str,
+    ) -> Optional[Tuple[Tuple[Any]]]:
         self.connection = self.module.connect(
             database=self.config.db_name,
             user=self.config.db_username,
@@ -59,26 +67,36 @@ class DiasporaAuthProvider:
         try:
             with self.connection:
                 with self.connection.cursor() as cursor:
-                    yield threads.deferToThread(  # Don't think this is needed, but w/e
-                        cursor.execute,
-                        query,
-                        args,
-                    )
-                    results = yield threads.deferToThread(cursor.fetchall)
+                    cursor.execute(query, args)
+                    results = cursor.fetchall()
                     cursor.close()
-            defer.returnValue(results)
+                    return results
         except self.module.Error as e:
             logger.warning("Error during execution of query: {}: {}".format(query, e))
-            defer.returnValue(None)
+            return None
         finally:
             self.connection.close()
 
-    @defer.inlineCallbacks
-    def check_password(self, user_id, password):
+    async def check_password(
+        self,
+        username: str,
+        login_type: str,
+        login_dict: "synapse.module_api.JsonDict",
+    ) -> Optional[
+        Tuple[
+            str,
+            Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
+        ]
+    ]:
+        if login_type != "m.login.password":
+            return None
+
+        password = login_dict.get("password")
         if not password:
-            defer.returnValue(False)
-        local_part = user_id.split(":", 1)[0][1:]
-        users = yield self.exec_query(
+            return None
+
+        local_part = username.split(":", 1)[0][1:] if ":" in username else username
+        users = await self.exec_query(
             "SELECT username, encrypted_password, email FROM users WHERE username=%s",
             local_part,
         )
@@ -90,7 +108,8 @@ class DiasporaAuthProvider:
             logger.info(
                 "User {} does not exist. Rejecting auth request".format(local_part)
             )
-            defer.returnValue(False)
+            return None
+
         user = users[0]
         logger.debug("User {} exists. Checking password".format(local_part))
         # user exists, check if the password is correct.
@@ -108,86 +127,5 @@ class DiasporaAuthProvider:
                     local_part
                 )
             )
-            defer.returnValue(False)
-        self.register_user(local_part, email)
-        logger.info("Confirming authentication request.")
-        defer.returnValue(True)
-
-    @defer.inlineCallbacks
-    def check_3pid_auth(self, medium, address, password):
-        logger.info(medium, address, password)
-        if medium != "email":
-            defer.returnValue(None)
-        logger.debug("Searching for email {} in diaspora db.".format(address))
-        users = yield self.exec_query(
-            "SELECT username FROM users WHERE email=%s", address
-        )
-        if not users:
-            defer.returnValue(None)
-        username = users[0][0]
-        logger.debug("Found username! {}".format(username))
-        logger.debug("Registering user {}".format(username))
-        self.register_user(username, address)
-        logger.debug("Registration complete")
-        defer.returnValue(username)
-
-    @defer.inlineCallbacks
-    def register_user(self, local_part, email):
-        if (yield self.account_handler.check_user_exists(local_part)):
-            yield self.sync_email(local_part, email)
-            defer.returnValue(local_part)
-        else:
-            user_id = yield self.account_handler.register_user(
-                localpart=local_part, emails=[email]
-            )
-            defer.returnValue(user_id)
-
-    @defer.inlineCallbacks
-    def sync_email(self, user_id, email):
-        logger.info("Syncing emails of {}".format(user_id))
-        email = email.lower()
-        store = self.account_handler._store  # Need access to datastore
-        threepids = yield store.user_get_threepids(user_id)
-        if not threepids:
-            logger.info("No 3pids found.")
-            yield self.add_email(user_id, email)
-        for threepid in threepids:
-            if not threepid["medium"] == "email":
-                logger.debug("Not an email: {}".format(str(threepid)))
-                pass
-            address = threepid["address"]
-            if address != email:
-                logger.info(
-                    "Existing 3pid doesn't match {} != {}. Deleting".format(
-                        address, email
-                    )
-                )
-                yield self.auth_handler.delete_threepid(user_id, "email", address)
-                yield self.add_email(user_id, email)
-                break
-        logger.info("Sync completed.")
-
-    @defer.inlineCallbacks
-    def add_email(self, user_id, email):
-        logger.info("Adding 3pid: {} for {}".format(email, user_id))
-        validated_at = self.account_handler._hs.get_clock().time_msec()
-        yield self.auth_handler.add_threepid(user_id, "email", email, validated_at)
-
-    @staticmethod
-    def parse_config(config):
-        class _Conf:
-            pass
-
-        Conf = _Conf()
-        Conf.engine = config["database"]["engine"]
-        Conf.db_name = (
-            "diaspora_production"
-            if not config["database"]["name"]
-            else config["database"]["name"]
-        )
-        Conf.db_host = config["database"]["host"]
-        Conf.db_port = config["database"]["port"]
-        Conf.db_username = config["database"]["username"]
-        Conf.db_password = config["database"]["password"]
-        Conf.pepper = config["pepper"]
-        return Conf
+            return None
+        return (self.api.get_qualified_user_id(username), None)
-- 
GitLab