Skip to content
Snippets Groups Projects
Commit 3e219f5b authored by Bady's avatar Bady
Browse files
parent ec85cbb2
Branches
1 merge request!1Port module to the new interface
...@@ -16,15 +16,18 @@ You should have received a copy of the GNU General Public License ...@@ -16,15 +16,18 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from twisted.internet import defer, threads from collections.abc import Awaitable, Callable
from typing import Any
import asyncio
import synapse import synapse
from synapse import module_api
import bcrypt import bcrypt
import logging import logging
from pkg_resources import parse_version
__VERSION__ = "0.2.2" __VERSION__ = "0.2.2"
...@@ -34,53 +37,78 @@ logger = logging.getLogger(__name__) ...@@ -34,53 +37,78 @@ logger = logging.getLogger(__name__)
class DiasporaAuthProvider: class DiasporaAuthProvider:
__version__ = __VERSION__ __version__ = __VERSION__
def __init__(self, config, account_handler): def __init__(self, config: dict, api: module_api):
self.account_handler = account_handler
self.config = config self.config = config
self.auth_handler = self.account_handler._auth_handler self.api = api
if self.config.engine == "mysql": if self.config.engine == "mysql":
import pymysql import aiomysql
self.module = pymysql self.module = aiomysql
elif self.config.engine == "postgres": elif self.config.engine == "postgres":
import psycopg2 import aiopg
self.module = psycopg2 self.module = aiopg
api.register_password_auth_provider_callbacks(
auth_checkers={
("m.login.password", ("password",)): self.check_pass,
},
)
@defer.inlineCallbacks async def exec_query(
def exec_query(self, query, *args): self,
self.connection = self.module.connect( loop: asyncio.AbstractEventLoop,
database=self.config.db_name, query: str,
*args: list[str],
) -> tuple[tuple[str]] | None:
pool = await self.module.create_pool(
db=self.config.db_name,
user=self.config.db_username, user=self.config.db_username,
password=self.config.db_password, password=self.config.db_password,
host=self.config.db_host, host=self.config.db_host,
port=self.config.db_port, port=self.config.db_port,
loop=loop,
) )
try: try:
with self.connection: async with pool.acquire() as connection:
with self.connection.cursor() as cursor: async with connection.cursor() as cursor:
yield threads.deferToThread( # Don't think this is needed, but w/e await cursor.execute(query, args)
cursor.execute, results = cursor.fetchall()
query, return results
args,
)
results = yield threads.deferToThread(cursor.fetchall)
cursor.close()
defer.returnValue(results)
except self.module.Error as e: except self.module.Error as e:
logger.warning("Error during execution of query: {}: {}".format(query, e)) logger.warning("Error during execution of query: {}: {}".format(query, e))
defer.returnValue(None) return None
finally: finally:
self.connection.close() pool.close()
await pool.wait_closed()
@defer.inlineCallbacks
def check_password(self, user_id, password): async def check_pass(
self,
username: str,
login_type: str,
login_dict: "synapse.module_api.JsonDict",
) -> (
tuple[
str,
Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None,
]
| None
):
if login_type != "m.login.password":
return None
password = login_dict.get("password")
if not password: if not password:
defer.returnValue(False) return None
local_part = user_id.split(":", 1)[0][1:]
users = yield self.exec_query( local_part = username.split(":", 1)[0][1:] if ":" in username else username
"SELECT username, encrypted_password, email FROM users WHERE username=%s", loop = asyncio.get_event_loop()
local_part, users = await loop.run_until_complete(
self.exec_query(
loop,
"SELECT username, encrypted_password, email FROM users WHERE username=%s",
local_part,
)
) )
# user_id is @localpart:hs_bare. we only need the localpart. # user_id is @localpart:hs_bare. we only need the localpart.
logger.info("Checking if user {} exists.".format(local_part)) logger.info("Checking if user {} exists.".format(local_part))
...@@ -90,10 +118,11 @@ class DiasporaAuthProvider: ...@@ -90,10 +118,11 @@ class DiasporaAuthProvider:
logger.info( logger.info(
"User {} does not exist. Rejecting auth request".format(local_part) "User {} does not exist. Rejecting auth request".format(local_part)
) )
defer.returnValue(False) return None
user = users[0]
logger.debug("User {} exists. Checking password".format(local_part)) logger.debug("User {} exists. Checking password".format(local_part))
# user exists, check if the password is correct. # user exists, check if the password is correct.
user = users[0]
encrypted_password = user[1] encrypted_password = user[1]
email = user[2] email = user[2]
peppered_pass = "{}{}".format(password, self.config.pepper) peppered_pass = "{}{}".format(password, self.config.pepper)
...@@ -108,9 +137,8 @@ class DiasporaAuthProvider: ...@@ -108,9 +137,8 @@ class DiasporaAuthProvider:
local_part local_part
) )
) )
defer.returnValue(False) return None
logger.info("Confirming authentication request.") return (self.api.get_qualified_user_id(username), None)
defer.returnValue(True)
@staticmethod @staticmethod
def parse_config(config): def parse_config(config):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment