diff --git a/src/domain/auth.py b/src/domain/auth.py new file mode 100644 index 0000000..c0f0c2a --- /dev/null +++ b/src/domain/auth.py @@ -0,0 +1,12 @@ +from enum import StrEnum + +from pydantic import BaseModel + + +class OAuthProvider(StrEnum): + TWITCH = "twitch" + + +class OAuthData(BaseModel): + id: str + email: str | None diff --git a/src/domain/users.py b/src/domain/users.py new file mode 100644 index 0000000..fbd448e --- /dev/null +++ b/src/domain/users.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel + +from domain.auth import OAuthProvider, OAuthData + + +class User(BaseModel): + id: str + + oauths: dict[OAuthProvider, OAuthData] + + is_admin: bool + + +class CreateUser(BaseModel): + oauths: dict[OAuthProvider, OAuthData] + + is_admin: bool = False diff --git a/src/modules/web_app/serializers/auth.py b/src/modules/web_app/serializers/auth.py new file mode 100644 index 0000000..468454a --- /dev/null +++ b/src/modules/web_app/serializers/auth.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class GetAuthorizationUrlResponse(BaseModel): + authorization_url: str diff --git a/src/modules/web_app/services/oauth/authorization_url_getter.py b/src/modules/web_app/services/oauth/authorization_url_getter.py new file mode 100644 index 0000000..c5d2a3f --- /dev/null +++ b/src/modules/web_app/services/oauth/authorization_url_getter.py @@ -0,0 +1,18 @@ +from core.config import config + +from domain.auth import OAuthProvider + +from .providers import get_client + + +REDIRECT_URI_TEMPLATE = f"https://{config.WEB_APP_HOST}/" + "auth/callback/{service}/" + + +async def get_authorization_url(provider: OAuthProvider) -> str: + client = get_client(provider) + + return await client.get_authorization_url( + redirect_uri=REDIRECT_URI_TEMPLATE.format( + service=provider.value + ), + ) diff --git a/src/modules/web_app/services/oauth/process_callback.py b/src/modules/web_app/services/oauth/process_callback.py new file mode 100644 index 0000000..3f21cc8 --- /dev/null +++ b/src/modules/web_app/services/oauth/process_callback.py @@ -0,0 +1,16 @@ +from domain.auth import OAuthProvider + +from .providers import get_client +from .authorization_url_getter import REDIRECT_URI_TEMPLATE + + +async def process_callback(provider: OAuthProvider, code: str) -> tuple[str, str | None]: + client = get_client(provider) + token = await client.get_access_token( + code, + redirect_uri=REDIRECT_URI_TEMPLATE.format(service=provider.value), + ) + + user_data = await client.get_id_email(token["access_token"]) + + return user_data diff --git a/src/modules/web_app/services/oauth/providers/__init__.py b/src/modules/web_app/services/oauth/providers/__init__.py new file mode 100644 index 0000000..f5520a5 --- /dev/null +++ b/src/modules/web_app/services/oauth/providers/__init__.py @@ -0,0 +1,10 @@ +from .enums import OAuthProvider +from .twitch import twitch_oauth_client +from .getter import get_client + + +__all__ = [ + "OAuthProvider", + "twitch_oauth_client", + "get_client" +] diff --git a/src/modules/web_app/services/oauth/providers/getter.py b/src/modules/web_app/services/oauth/providers/getter.py new file mode 100644 index 0000000..0273c66 --- /dev/null +++ b/src/modules/web_app/services/oauth/providers/getter.py @@ -0,0 +1,11 @@ +from httpx_oauth.oauth2 import OAuth2 + +from domain.auth import OAuthProvider +from .twitch import twitch_oauth_client + + +def get_client(provider: OAuthProvider) -> OAuth2: + if provider == OAuthProvider.TWITCH: + return twitch_oauth_client + else: + raise NotImplementedError("Provider is not implemented") diff --git a/src/modules/web_app/services/oauth/providers/twitch.py b/src/modules/web_app/services/oauth/providers/twitch.py new file mode 100644 index 0000000..43bd9ec --- /dev/null +++ b/src/modules/web_app/services/oauth/providers/twitch.py @@ -0,0 +1,34 @@ +from twitchAPI.twitch import Twitch, AuthScope +from twitchAPI.helper import first + +from httpx_oauth.oauth2 import OAuth2 + +from core.config import config + + +class TwithOAuth2(OAuth2): + async def get_id_email(self, token: str): + twitch_client = Twitch(config.TWITCH_CLIENT_ID, config.TWITCH_CLIENT_SECRET) + twitch_client.auto_refresh_auth = False + + await twitch_client.set_user_authentication( + token, + [AuthScope.USER_READ_EMAIL], + validate=True + ) + + me = await first(twitch_client.get_users()) + + if me is None: + raise Exception("Failed to get user data") + + return me.id, me.email + + +twitch_oauth_client = TwithOAuth2( + config.TWITCH_CLIENT_ID, + config.TWITCH_CLIENT_SECRET, + "https://id.twitch.tv/oauth2/authorize", + "https://id.twitch.tv/oauth2/token", + base_scopes=[AuthScope.USER_READ_EMAIL.value], +) diff --git a/src/modules/web_app/views/auth.py b/src/modules/web_app/views/auth.py index d88440e..7b43520 100644 --- a/src/modules/web_app/views/auth.py +++ b/src/modules/web_app/views/auth.py @@ -1,68 +1,32 @@ from fastapi import APIRouter -from twitchAPI.twitch import Twitch, AuthScope -from twitchAPI.helper import first -from httpx_oauth.oauth2 import OAuth2 - -from core.config import config +from domain.auth import OAuthProvider, OAuthData +from domain.users import CreateUser +from modules.web_app.services.oauth.process_callback import process_callback +from modules.web_app.services.oauth.authorization_url_getter import get_authorization_url as gen_auth_link +from modules.web_app.serializers.auth import GetAuthorizationUrlResponse +from repositories.users import UserRepository auth_router = APIRouter(prefix="/auth", tags=["auth"]) -class TwithOAuth2(OAuth2): - async def get_id_email(self, token: str): - twitch_client = Twitch(config.TWITCH_CLIENT_ID, config.TWITCH_CLIENT_SECRET) - twitch_client.auto_refresh_auth = False +@auth_router.get("/get_authorization_url/{provider}/") +async def get_authorization_url(provider: OAuthProvider) -> GetAuthorizationUrlResponse: + link = await gen_auth_link(provider) - await twitch_client.set_user_authentication( - token, - [AuthScope.USER_READ_EMAIL], - validate=True + return GetAuthorizationUrlResponse(authorization_url=link) + + +@auth_router.get("/callback/{provider}/") +async def callback(provider: OAuthProvider, code: str): + user_data = await process_callback(provider, code) + + user = await UserRepository.get_or_create_user( + CreateUser( + oauths={provider: OAuthData(id=user_data[0], email=user_data[1])}, + is_admin=False, ) + ) - me = await first(twitch_client.get_users()) - - if me is None: - raise Exception("Failed to get user data") - - return me.id, me.email - - -twitch_oauth = TwithOAuth2( - config.TWITCH_CLIENT_ID, - config.TWITCH_CLIENT_SECRET, - "https://id.twitch.tv/oauth2/authorize", - "https://id.twitch.tv/oauth2/token", - base_scopes=[AuthScope.USER_READ_EMAIL.value], -) - - -REDIRECT_URI_TEMPLATE = f"https://{config.WEB_APP_HOST}/" + "auth/callback/{service}/" - - -@auth_router.get("/get_authorization_url/{service}/") -async def get_authorization_url(service: str): - link = None - - if service == "twitch": - link = await twitch_oauth.get_authorization_url( - redirect_uri=REDIRECT_URI_TEMPLATE.format(service="twitch"), - ) - - return {"link": link} - - -@auth_router.get("/callback/{service}/") -async def callback(service: str, code: str): - user_data = None - - if service == "twitch": - token = await twitch_oauth.get_access_token( - code, - redirect_uri=REDIRECT_URI_TEMPLATE.format(service="twitch"), - ) - - user_data = await twitch_oauth.get_id_email(token["access_token"]) - - return {"user_data": user_data} + return {"user": user.model_dump()} diff --git a/src/repositories/base.py b/src/repositories/base.py new file mode 100644 index 0000000..5067e61 --- /dev/null +++ b/src/repositories/base.py @@ -0,0 +1,18 @@ +import abc + +from contextlib import asynccontextmanager + +from core.mongo import mongo_manager + + +class BaseRepository(abc.ABC): + COLLECTION_NAME: str + + @asynccontextmanager + @classmethod + async def connect(cls): + async with mongo_manager.connect() as client: + db = client.get_default_database() + collection = db[cls.COLLECTION_NAME] + + yield collection diff --git a/src/repositories/streamers.py b/src/repositories/streamers.py index ed00c72..3cbace1 100644 --- a/src/repositories/streamers.py +++ b/src/repositories/streamers.py @@ -1,17 +1,14 @@ from domain.streamers import StreamerConfig -from core.mongo import mongo_manager +from .base import BaseRepository -class StreamerConfigRepository: +class StreamerConfigRepository(BaseRepository): COLLECTION_NAME = "streamers" @classmethod async def get_by_twitch_id(cls, twitch_id: int) -> StreamerConfig: - async with mongo_manager.connect() as client: - db = client.get_default_database() - collection = db[cls.COLLECTION_NAME] - + async with cls.connect() as collection: doc = await collection.find_one({"twitch.id": twitch_id}) if doc is None: raise ValueError(f"Streamer with twitch id {twitch_id} not found") @@ -34,10 +31,7 @@ class StreamerConfigRepository: "integrations.discord.games_list.channel_id" ] = integration_discord_games_list_channel_id - async with mongo_manager.connect() as client: - db = client.get_default_database() - collection = db[cls.COLLECTION_NAME] - + async with cls.connect() as collection: doc = await collection.find_one(filters) if doc is None: return None @@ -46,9 +40,6 @@ class StreamerConfigRepository: @classmethod async def all(cls) -> list[StreamerConfig]: - async with mongo_manager.connect() as client: - db = client.get_default_database() - collection = db[cls.COLLECTION_NAME] - + async with cls.connect() as collection: cursor = await collection.find() return [StreamerConfig(**doc) async for doc in cursor] diff --git a/src/repositories/users.py b/src/repositories/users.py new file mode 100644 index 0000000..d6df911 --- /dev/null +++ b/src/repositories/users.py @@ -0,0 +1,29 @@ +from domain.users import CreateUser, User + +from .base import BaseRepository + + +class UserRepository(BaseRepository): + COLLECTION_NAME = "users" + + @classmethod + async def get_or_create_user(cls, newUser: CreateUser) -> User: + filter_data = {} + + for provider, data in newUser.oauths.items(): + filter_data[f"oauths.{provider}.id"] = data.id + + async with cls.connect() as collection: + await collection.update_one( + filter_data, + { + "$setOnInsert": { + **newUser.model_dump(), + } + }, + upsert=True, + ) + + user = await collection.find_one(filter_data) + + return User(**user)