diff --git a/diaspora_auth_provider.py b/diaspora_auth_provider.py index 6690a80028d97aa183620ba5cc5ec622230bb7e4..77660bca292e59f72b220dc03f87c10e3ccfcdd0 100644 --- a/diaspora_auth_provider.py +++ b/diaspora_auth_provider.py @@ -16,17 +16,17 @@ You should have received a copy of the GNU General Public License along with this program. If not, see <http://www.gnu.org/licenses/>. """ -from twisted.internet import defer, threads +from typing import Any, Awaitable, Callable, Optional, Tuple + import synapse +from synapse import module_api import bcrypt import logging -from pkg_resources import parse_version - -__VERSION__ = "0.2.2" +__VERSION__ = "0.4.0" logger = logging.getLogger(__name__) @@ -34,10 +34,9 @@ logger = logging.getLogger(__name__) class DiasporaAuthProvider: __version__ = __VERSION__ - def __init__(self, config, account_handler): - self.account_handler = account_handler + def __init__(self, config: dict, api: module_api): + self.api = api self.config = config - self.auth_handler = self.account_handler._auth_handler if self.config.engine == "mysql": import pymysql @@ -47,8 +46,17 @@ class DiasporaAuthProvider: self.module = psycopg2 - @defer.inlineCallbacks - def exec_query(self, query, *args): + api.register_password_auth_provider_callbacks( + auth_checkers={ + ("m.login.password", ("password",)): self.check_password, + }, + ) + + async def exec_query( + self, + query: str, + *args: str, + ) -> Optional[Tuple[Tuple[Any]]]: self.connection = self.module.connect( database=self.config.db_name, user=self.config.db_username, @@ -59,26 +67,36 @@ class DiasporaAuthProvider: try: with self.connection: with self.connection.cursor() as cursor: - yield threads.deferToThread( # Don't think this is needed, but w/e - cursor.execute, - query, - args, - ) - results = yield threads.deferToThread(cursor.fetchall) + cursor.execute(query, args) + results = cursor.fetchall() cursor.close() - defer.returnValue(results) + return results except self.module.Error as e: logger.warning("Error during execution of query: {}: {}".format(query, e)) - defer.returnValue(None) + return None finally: self.connection.close() - @defer.inlineCallbacks - def check_password(self, user_id, password): + async def check_password( + self, + username: str, + login_type: str, + login_dict: "synapse.module_api.JsonDict", + ) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], + ] + ]: + if login_type != "m.login.password": + return None + + password = login_dict.get("password") if not password: - defer.returnValue(False) - local_part = user_id.split(":", 1)[0][1:] - users = yield self.exec_query( + return None + + local_part = username.split(":", 1)[0][1:] if ":" in username else username + users = await self.exec_query( "SELECT username, encrypted_password, email FROM users WHERE username=%s", local_part, ) @@ -90,7 +108,8 @@ class DiasporaAuthProvider: logger.info( "User {} does not exist. Rejecting auth request".format(local_part) ) - defer.returnValue(False) + return None + user = users[0] logger.debug("User {} exists. Checking password".format(local_part)) # user exists, check if the password is correct. @@ -108,70 +127,8 @@ class DiasporaAuthProvider: local_part ) ) - defer.returnValue(False) - self.register_user(local_part, email) - logger.info("Confirming authentication request.") - defer.returnValue(True) - - @defer.inlineCallbacks - def check_3pid_auth(self, medium, address, password): - logger.info(medium, address, password) - if medium != "email": - defer.returnValue(None) - logger.debug("Searching for email {} in diaspora db.".format(address)) - users = yield self.exec_query( - "SELECT username FROM users WHERE email=%s", address - ) - if not users: - defer.returnValue(None) - username = users[0][0] - logger.debug("Found username! {}".format(username)) - logger.debug("Registering user {}".format(username)) - self.register_user(username, address) - logger.debug("Registration complete") - defer.returnValue(username) - - @defer.inlineCallbacks - def register_user(self, local_part, email): - if (yield self.account_handler.check_user_exists(local_part)): - yield self.sync_email(local_part, email) - defer.returnValue(local_part) - else: - user_id = yield self.account_handler.register_user( - localpart=local_part, emails=[email] - ) - defer.returnValue(user_id) - - @defer.inlineCallbacks - def sync_email(self, user_id, email): - logger.info("Syncing emails of {}".format(user_id)) - email = email.lower() - store = self.account_handler._store # Need access to datastore - threepids = yield store.user_get_threepids(user_id) - if not threepids: - logger.info("No 3pids found.") - yield self.add_email(user_id, email) - for threepid in threepids: - if not threepid["medium"] == "email": - logger.debug("Not an email: {}".format(str(threepid))) - pass - address = threepid["address"] - if address != email: - logger.info( - "Existing 3pid doesn't match {} != {}. Deleting".format( - address, email - ) - ) - yield self.auth_handler.delete_threepid(user_id, "email", address) - yield self.add_email(user_id, email) - break - logger.info("Sync completed.") - - @defer.inlineCallbacks - def add_email(self, user_id, email): - logger.info("Adding 3pid: {} for {}".format(email, user_id)) - validated_at = self.account_handler._hs.get_clock().time_msec() - yield self.auth_handler.add_threepid(user_id, "email", email, validated_at) + return None + return (self.api.get_qualified_user_id(username), None) @staticmethod def parse_config(config):