Skip to content
Snippets Groups Projects
Commit 1e076af6 authored by Bady's avatar Bady
Browse files

Port module to the new interface

Synapse introduced new generic modules system in v1.37.0. For more details see:
- https://element-hq.github.io/synapse/latest/upgrade.html#upgrading-to-v1390
- https://element-hq.github.io/synapse/latest/modules/porting_legacy_module.html
- https://element-hq.github.io/synapse/latest/modules/password_auth_provider_callbacks.html

This commit only updates the code needed for login. Once it is tested
and found working, the code needed to register users which used to work
with the legacy module api also needs to be updated.
parent b9f8cd1d
Branches
Tags v0.1
No related merge requests found
......@@ -16,15 +16,18 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from twisted.internet import defer, threads
from collections.abc import Awaitable, Callable
from typing import Any
import asyncio
import synapse
from synapse import module_api
import bcrypt
import logging
from pkg_resources import parse_version
__VERSION__ = "0.2.2"
......@@ -34,53 +37,78 @@ logger = logging.getLogger(__name__)
class DiasporaAuthProvider:
__version__ = __VERSION__
def __init__(self, config, account_handler):
self.account_handler = account_handler
def __init__(self, config: dict, api: module_api):
self.config = config
self.auth_handler = self.account_handler._auth_handler
self.api = api
if self.config.engine == "mysql":
import pymysql
import aiomysql
self.module = pymysql
self.module = aiomysql
elif self.config.engine == "postgres":
import psycopg2
import aiopg
self.module = psycopg2
self.module = aiopg
@defer.inlineCallbacks
def exec_query(self, query, *args):
self.connection = self.module.connect(
database=self.config.db_name,
api.register_password_auth_provider_callbacks(
auth_checkers={
("m.login.password", ("password",)): self.check_pass,
},
)
async def exec_query(
self,
loop: asyncio.AbstractEventLoop,
query: str,
*args: list[str],
) -> tuple[tuple[str]] | None:
pool = await self.module.create_pool(
db=self.config.db_name,
user=self.config.db_username,
password=self.config.db_password,
host=self.config.db_host,
port=self.config.db_port,
loop=loop,
)
try:
with self.connection:
with self.connection.cursor() as cursor:
yield threads.deferToThread( # Don't think this is needed, but w/e
cursor.execute,
query,
args,
)
results = yield threads.deferToThread(cursor.fetchall)
cursor.close()
defer.returnValue(results)
async with pool.acquire() as connection:
async with connection.cursor() as cursor:
await cursor.execute(query, args)
results = cursor.fetchall()
return results
except self.module.Error as e:
logger.warning("Error during execution of query: {}: {}".format(query, e))
defer.returnValue(None)
return None
finally:
self.connection.close()
@defer.inlineCallbacks
def check_password(self, user_id, password):
pool.close()
await pool.wait_closed()
async def check_pass(
self,
username: str,
login_type: str,
login_dict: "synapse.module_api.JsonDict",
) -> (
tuple[
str,
Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None,
]
| None
):
if login_type != "m.login.password":
return None
password = login_dict.get("password")
if not password:
defer.returnValue(False)
local_part = user_id.split(":", 1)[0][1:]
users = yield self.exec_query(
"SELECT username, encrypted_password, email FROM users WHERE username=%s",
local_part,
return None
local_part = username.split(":", 1)[0][1:] if ":" in username else username
loop = asyncio.get_event_loop()
users = await loop.run_until_complete(
self.exec_query(
loop,
"SELECT username, encrypted_password, email FROM users WHERE username=%s",
local_part,
)
)
# user_id is @localpart:hs_bare. we only need the localpart.
logger.info("Checking if user {} exists.".format(local_part))
......@@ -90,10 +118,11 @@ class DiasporaAuthProvider:
logger.info(
"User {} does not exist. Rejecting auth request".format(local_part)
)
defer.returnValue(False)
user = users[0]
return None
logger.debug("User {} exists. Checking password".format(local_part))
# user exists, check if the password is correct.
user = users[0]
encrypted_password = user[1]
email = user[2]
peppered_pass = "{}{}".format(password, self.config.pepper)
......@@ -108,70 +137,8 @@ class DiasporaAuthProvider:
local_part
)
)
defer.returnValue(False)
self.register_user(local_part, email)
logger.info("Confirming authentication request.")
defer.returnValue(True)
@defer.inlineCallbacks
def check_3pid_auth(self, medium, address, password):
logger.info(medium, address, password)
if medium != "email":
defer.returnValue(None)
logger.debug("Searching for email {} in diaspora db.".format(address))
users = yield self.exec_query(
"SELECT username FROM users WHERE email=%s", address
)
if not users:
defer.returnValue(None)
username = users[0][0]
logger.debug("Found username! {}".format(username))
logger.debug("Registering user {}".format(username))
self.register_user(username, address)
logger.debug("Registration complete")
defer.returnValue(username)
@defer.inlineCallbacks
def register_user(self, local_part, email):
if (yield self.account_handler.check_user_exists(local_part)):
yield self.sync_email(local_part, email)
defer.returnValue(local_part)
else:
user_id = yield self.account_handler.register_user(
localpart=local_part, emails=[email]
)
defer.returnValue(user_id)
@defer.inlineCallbacks
def sync_email(self, user_id, email):
logger.info("Syncing emails of {}".format(user_id))
email = email.lower()
store = self.account_handler._store # Need access to datastore
threepids = yield store.user_get_threepids(user_id)
if not threepids:
logger.info("No 3pids found.")
yield self.add_email(user_id, email)
for threepid in threepids:
if not threepid["medium"] == "email":
logger.debug("Not an email: {}".format(str(threepid)))
pass
address = threepid["address"]
if address != email:
logger.info(
"Existing 3pid doesn't match {} != {}. Deleting".format(
address, email
)
)
yield self.auth_handler.delete_threepid(user_id, "email", address)
yield self.add_email(user_id, email)
break
logger.info("Sync completed.")
@defer.inlineCallbacks
def add_email(self, user_id, email):
logger.info("Adding 3pid: {} for {}".format(email, user_id))
validated_at = self.account_handler._hs.get_clock().time_msec()
yield self.auth_handler.add_threepid(user_id, "email", email, validated_at)
return None
return (self.api.get_qualified_user_id(username), None)
@staticmethod
def parse_config(config):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment