Skip to content
Snippets Groups Projects
Commit df1a9fda authored by Bady's avatar Bady
Browse files

Port module to the new interface

Synapse introduced new generic modules system in v1.37.0. For more details see:
- https://element-hq.github.io/synapse/latest/upgrade.html#upgrading-to-v1390
- https://element-hq.github.io/synapse/latest/modules/porting_legacy_module.html
- https://element-hq.github.io/synapse/latest/modules/password_auth_provider_callbacks.html

This commit only updates the code needed for login. Once it is tested
and found working, the code needed to register users which used to work
with the legacy module api also needs to be updated.
parent b9f8cd1d
No related merge requests found
...@@ -16,15 +16,18 @@ You should have received a copy of the GNU General Public License ...@@ -16,15 +16,18 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from twisted.internet import defer, threads from collections.abc import Awaitable, Callable
from typing import Any
import asyncio
import synapse import synapse
from synapse import module_api
import bcrypt import bcrypt
import logging import logging
from pkg_resources import parse_version
__VERSION__ = "0.2.2" __VERSION__ = "0.2.2"
...@@ -34,53 +37,78 @@ logger = logging.getLogger(__name__) ...@@ -34,53 +37,78 @@ logger = logging.getLogger(__name__)
class DiasporaAuthProvider: class DiasporaAuthProvider:
__version__ = __VERSION__ __version__ = __VERSION__
def __init__(self, config, account_handler): def __init__(self, config: dict, api: module_api):
self.account_handler = account_handler
self.config = config self.config = config
self.auth_handler = self.account_handler._auth_handler self.api = api
if self.config.engine == "mysql": if self.config.engine == "mysql":
import pymysql import aiomysql
self.module = pymysql self.module = aiomysql
elif self.config.engine == "postgres": elif self.config.engine == "postgres":
import psycopg2 import aiopg
self.module = psycopg2 self.module = aiopg
@defer.inlineCallbacks api.register_password_auth_provider_callbacks(
def exec_query(self, query, *args): auth_checkers={
self.connection = self.module.connect( ("m.login.password", ("password",)): self.check_pass,
database=self.config.db_name, },
)
async def exec_query(
self,
loop: asyncio.AbstractEventLoop,
query: str,
*args: list[str],
) -> tuple[tuple[str]] | None:
pool = await self.module.create_pool(
db=self.config.db_name,
user=self.config.db_username, user=self.config.db_username,
password=self.config.db_password, password=self.config.db_password,
host=self.config.db_host, host=self.config.db_host,
port=self.config.db_port, port=self.config.db_port,
loop=loop,
) )
try: try:
with self.connection: async with pool.acquire() as connection:
with self.connection.cursor() as cursor: async with connection.cursor() as cursor:
yield threads.deferToThread( # Don't think this is needed, but w/e await cursor.execute(query, username)
cursor.execute, results = await cursor.fetchall()
query, return results
args,
)
results = yield threads.deferToThread(cursor.fetchall)
cursor.close()
defer.returnValue(results)
except self.module.Error as e: except self.module.Error as e:
logger.warning("Error during execution of query: {}: {}".format(query, e)) logger.warning("Error during execution of query: {}: {}".format(query, e))
defer.returnValue(None) return None
finally: finally:
self.connection.close() pool.close()
await pool.wait_closed()
@defer.inlineCallbacks
def check_password(self, user_id, password): async def check_pass(
self,
username: str,
login_type: str,
login_dict: "synapse.module_api.JsonDict",
) -> (
tuple[
str,
Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None,
]
| None
):
if login_type != "m.login.password":
return None
password = login_dict.get("password")
if not password: if not password:
defer.returnValue(False) return None
local_part = user_id.split(":", 1)[0][1:]
users = yield self.exec_query( local_part = username.split(":", 1)[0][1:] if ":" in username else username
"SELECT username, encrypted_password, email FROM users WHERE username=%s", loop = asyncio.get_event_loop()
local_part, user = await self.loop.run_until_complete(
self.get_user(
loop,
"SELECT username, encrypted_password, email FROM users WHERE username=%s",
local_part,
)
) )
# user_id is @localpart:hs_bare. we only need the localpart. # user_id is @localpart:hs_bare. we only need the localpart.
logger.info("Checking if user {} exists.".format(local_part)) logger.info("Checking if user {} exists.".format(local_part))
...@@ -90,10 +118,11 @@ class DiasporaAuthProvider: ...@@ -90,10 +118,11 @@ class DiasporaAuthProvider:
logger.info( logger.info(
"User {} does not exist. Rejecting auth request".format(local_part) "User {} does not exist. Rejecting auth request".format(local_part)
) )
defer.returnValue(False) return None
user = users[0]
logger.debug("User {} exists. Checking password".format(local_part)) logger.debug("User {} exists. Checking password".format(local_part))
# user exists, check if the password is correct. # user exists, check if the password is correct.
user = users[0]
encrypted_password = user[1] encrypted_password = user[1]
email = user[2] email = user[2]
peppered_pass = "{}{}".format(password, self.config.pepper) peppered_pass = "{}{}".format(password, self.config.pepper)
...@@ -108,70 +137,8 @@ class DiasporaAuthProvider: ...@@ -108,70 +137,8 @@ class DiasporaAuthProvider:
local_part local_part
) )
) )
defer.returnValue(False) return None
self.register_user(local_part, email) return (self.api.get_qualified_user_id(username), None)
logger.info("Confirming authentication request.")
defer.returnValue(True)
@defer.inlineCallbacks
def check_3pid_auth(self, medium, address, password):
logger.info(medium, address, password)
if medium != "email":
defer.returnValue(None)
logger.debug("Searching for email {} in diaspora db.".format(address))
users = yield self.exec_query(
"SELECT username FROM users WHERE email=%s", address
)
if not users:
defer.returnValue(None)
username = users[0][0]
logger.debug("Found username! {}".format(username))
logger.debug("Registering user {}".format(username))
self.register_user(username, address)
logger.debug("Registration complete")
defer.returnValue(username)
@defer.inlineCallbacks
def register_user(self, local_part, email):
if (yield self.account_handler.check_user_exists(local_part)):
yield self.sync_email(local_part, email)
defer.returnValue(local_part)
else:
user_id = yield self.account_handler.register_user(
localpart=local_part, emails=[email]
)
defer.returnValue(user_id)
@defer.inlineCallbacks
def sync_email(self, user_id, email):
logger.info("Syncing emails of {}".format(user_id))
email = email.lower()
store = self.account_handler._store # Need access to datastore
threepids = yield store.user_get_threepids(user_id)
if not threepids:
logger.info("No 3pids found.")
yield self.add_email(user_id, email)
for threepid in threepids:
if not threepid["medium"] == "email":
logger.debug("Not an email: {}".format(str(threepid)))
pass
address = threepid["address"]
if address != email:
logger.info(
"Existing 3pid doesn't match {} != {}. Deleting".format(
address, email
)
)
yield self.auth_handler.delete_threepid(user_id, "email", address)
yield self.add_email(user_id, email)
break
logger.info("Sync completed.")
@defer.inlineCallbacks
def add_email(self, user_id, email):
logger.info("Adding 3pid: {} for {}".format(email, user_id))
validated_at = self.account_handler._hs.get_clock().time_msec()
yield self.auth_handler.add_threepid(user_id, "email", email, validated_at)
@staticmethod @staticmethod
def parse_config(config): def parse_config(config):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment