mirror of
https://github.com/kurbezz/discord-bot.git
synced 2025-12-06 07:05:36 +01:00
Update
This commit is contained in:
12
src/domain/auth.py
Normal file
12
src/domain/auth.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OAuthProvider(StrEnum):
|
||||
TWITCH = "twitch"
|
||||
|
||||
|
||||
class OAuthData(BaseModel):
|
||||
id: str
|
||||
email: str | None
|
||||
17
src/domain/users.py
Normal file
17
src/domain/users.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from domain.auth import OAuthProvider, OAuthData
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: str
|
||||
|
||||
oauths: dict[OAuthProvider, OAuthData]
|
||||
|
||||
is_admin: bool
|
||||
|
||||
|
||||
class CreateUser(BaseModel):
|
||||
oauths: dict[OAuthProvider, OAuthData]
|
||||
|
||||
is_admin: bool = False
|
||||
5
src/modules/web_app/serializers/auth.py
Normal file
5
src/modules/web_app/serializers/auth.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GetAuthorizationUrlResponse(BaseModel):
|
||||
authorization_url: str
|
||||
@@ -0,0 +1,18 @@
|
||||
from core.config import config
|
||||
|
||||
from domain.auth import OAuthProvider
|
||||
|
||||
from .providers import get_client
|
||||
|
||||
|
||||
REDIRECT_URI_TEMPLATE = f"https://{config.WEB_APP_HOST}/" + "auth/callback/{service}/"
|
||||
|
||||
|
||||
async def get_authorization_url(provider: OAuthProvider) -> str:
|
||||
client = get_client(provider)
|
||||
|
||||
return await client.get_authorization_url(
|
||||
redirect_uri=REDIRECT_URI_TEMPLATE.format(
|
||||
service=provider.value
|
||||
),
|
||||
)
|
||||
16
src/modules/web_app/services/oauth/process_callback.py
Normal file
16
src/modules/web_app/services/oauth/process_callback.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from domain.auth import OAuthProvider
|
||||
|
||||
from .providers import get_client
|
||||
from .authorization_url_getter import REDIRECT_URI_TEMPLATE
|
||||
|
||||
|
||||
async def process_callback(provider: OAuthProvider, code: str) -> tuple[str, str | None]:
|
||||
client = get_client(provider)
|
||||
token = await client.get_access_token(
|
||||
code,
|
||||
redirect_uri=REDIRECT_URI_TEMPLATE.format(service=provider.value),
|
||||
)
|
||||
|
||||
user_data = await client.get_id_email(token["access_token"])
|
||||
|
||||
return user_data
|
||||
10
src/modules/web_app/services/oauth/providers/__init__.py
Normal file
10
src/modules/web_app/services/oauth/providers/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .enums import OAuthProvider
|
||||
from .twitch import twitch_oauth_client
|
||||
from .getter import get_client
|
||||
|
||||
|
||||
__all__ = [
|
||||
"OAuthProvider",
|
||||
"twitch_oauth_client",
|
||||
"get_client"
|
||||
]
|
||||
11
src/modules/web_app/services/oauth/providers/getter.py
Normal file
11
src/modules/web_app/services/oauth/providers/getter.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from httpx_oauth.oauth2 import OAuth2
|
||||
|
||||
from domain.auth import OAuthProvider
|
||||
from .twitch import twitch_oauth_client
|
||||
|
||||
|
||||
def get_client(provider: OAuthProvider) -> OAuth2:
|
||||
if provider == OAuthProvider.TWITCH:
|
||||
return twitch_oauth_client
|
||||
else:
|
||||
raise NotImplementedError("Provider is not implemented")
|
||||
34
src/modules/web_app/services/oauth/providers/twitch.py
Normal file
34
src/modules/web_app/services/oauth/providers/twitch.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from twitchAPI.twitch import Twitch, AuthScope
|
||||
from twitchAPI.helper import first
|
||||
|
||||
from httpx_oauth.oauth2 import OAuth2
|
||||
|
||||
from core.config import config
|
||||
|
||||
|
||||
class TwithOAuth2(OAuth2):
|
||||
async def get_id_email(self, token: str):
|
||||
twitch_client = Twitch(config.TWITCH_CLIENT_ID, config.TWITCH_CLIENT_SECRET)
|
||||
twitch_client.auto_refresh_auth = False
|
||||
|
||||
await twitch_client.set_user_authentication(
|
||||
token,
|
||||
[AuthScope.USER_READ_EMAIL],
|
||||
validate=True
|
||||
)
|
||||
|
||||
me = await first(twitch_client.get_users())
|
||||
|
||||
if me is None:
|
||||
raise Exception("Failed to get user data")
|
||||
|
||||
return me.id, me.email
|
||||
|
||||
|
||||
twitch_oauth_client = TwithOAuth2(
|
||||
config.TWITCH_CLIENT_ID,
|
||||
config.TWITCH_CLIENT_SECRET,
|
||||
"https://id.twitch.tv/oauth2/authorize",
|
||||
"https://id.twitch.tv/oauth2/token",
|
||||
base_scopes=[AuthScope.USER_READ_EMAIL.value],
|
||||
)
|
||||
@@ -1,68 +1,32 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from twitchAPI.twitch import Twitch, AuthScope
|
||||
from twitchAPI.helper import first
|
||||
from httpx_oauth.oauth2 import OAuth2
|
||||
|
||||
from core.config import config
|
||||
from domain.auth import OAuthProvider, OAuthData
|
||||
from domain.users import CreateUser
|
||||
from modules.web_app.services.oauth.process_callback import process_callback
|
||||
from modules.web_app.services.oauth.authorization_url_getter import get_authorization_url as gen_auth_link
|
||||
from modules.web_app.serializers.auth import GetAuthorizationUrlResponse
|
||||
from repositories.users import UserRepository
|
||||
|
||||
|
||||
auth_router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
class TwithOAuth2(OAuth2):
|
||||
async def get_id_email(self, token: str):
|
||||
twitch_client = Twitch(config.TWITCH_CLIENT_ID, config.TWITCH_CLIENT_SECRET)
|
||||
twitch_client.auto_refresh_auth = False
|
||||
@auth_router.get("/get_authorization_url/{provider}/")
|
||||
async def get_authorization_url(provider: OAuthProvider) -> GetAuthorizationUrlResponse:
|
||||
link = await gen_auth_link(provider)
|
||||
|
||||
await twitch_client.set_user_authentication(
|
||||
token,
|
||||
[AuthScope.USER_READ_EMAIL],
|
||||
validate=True
|
||||
return GetAuthorizationUrlResponse(authorization_url=link)
|
||||
|
||||
|
||||
@auth_router.get("/callback/{provider}/")
|
||||
async def callback(provider: OAuthProvider, code: str):
|
||||
user_data = await process_callback(provider, code)
|
||||
|
||||
user = await UserRepository.get_or_create_user(
|
||||
CreateUser(
|
||||
oauths={provider: OAuthData(id=user_data[0], email=user_data[1])},
|
||||
is_admin=False,
|
||||
)
|
||||
)
|
||||
|
||||
me = await first(twitch_client.get_users())
|
||||
|
||||
if me is None:
|
||||
raise Exception("Failed to get user data")
|
||||
|
||||
return me.id, me.email
|
||||
|
||||
|
||||
twitch_oauth = TwithOAuth2(
|
||||
config.TWITCH_CLIENT_ID,
|
||||
config.TWITCH_CLIENT_SECRET,
|
||||
"https://id.twitch.tv/oauth2/authorize",
|
||||
"https://id.twitch.tv/oauth2/token",
|
||||
base_scopes=[AuthScope.USER_READ_EMAIL.value],
|
||||
)
|
||||
|
||||
|
||||
REDIRECT_URI_TEMPLATE = f"https://{config.WEB_APP_HOST}/" + "auth/callback/{service}/"
|
||||
|
||||
|
||||
@auth_router.get("/get_authorization_url/{service}/")
|
||||
async def get_authorization_url(service: str):
|
||||
link = None
|
||||
|
||||
if service == "twitch":
|
||||
link = await twitch_oauth.get_authorization_url(
|
||||
redirect_uri=REDIRECT_URI_TEMPLATE.format(service="twitch"),
|
||||
)
|
||||
|
||||
return {"link": link}
|
||||
|
||||
|
||||
@auth_router.get("/callback/{service}/")
|
||||
async def callback(service: str, code: str):
|
||||
user_data = None
|
||||
|
||||
if service == "twitch":
|
||||
token = await twitch_oauth.get_access_token(
|
||||
code,
|
||||
redirect_uri=REDIRECT_URI_TEMPLATE.format(service="twitch"),
|
||||
)
|
||||
|
||||
user_data = await twitch_oauth.get_id_email(token["access_token"])
|
||||
|
||||
return {"user_data": user_data}
|
||||
return {"user": user.model_dump()}
|
||||
|
||||
18
src/repositories/base.py
Normal file
18
src/repositories/base.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import abc
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from core.mongo import mongo_manager
|
||||
|
||||
|
||||
class BaseRepository(abc.ABC):
|
||||
COLLECTION_NAME: str
|
||||
|
||||
@asynccontextmanager
|
||||
@classmethod
|
||||
async def connect(cls):
|
||||
async with mongo_manager.connect() as client:
|
||||
db = client.get_default_database()
|
||||
collection = db[cls.COLLECTION_NAME]
|
||||
|
||||
yield collection
|
||||
@@ -1,17 +1,14 @@
|
||||
from domain.streamers import StreamerConfig
|
||||
|
||||
from core.mongo import mongo_manager
|
||||
from .base import BaseRepository
|
||||
|
||||
|
||||
class StreamerConfigRepository:
|
||||
class StreamerConfigRepository(BaseRepository):
|
||||
COLLECTION_NAME = "streamers"
|
||||
|
||||
@classmethod
|
||||
async def get_by_twitch_id(cls, twitch_id: int) -> StreamerConfig:
|
||||
async with mongo_manager.connect() as client:
|
||||
db = client.get_default_database()
|
||||
collection = db[cls.COLLECTION_NAME]
|
||||
|
||||
async with cls.connect() as collection:
|
||||
doc = await collection.find_one({"twitch.id": twitch_id})
|
||||
if doc is None:
|
||||
raise ValueError(f"Streamer with twitch id {twitch_id} not found")
|
||||
@@ -34,10 +31,7 @@ class StreamerConfigRepository:
|
||||
"integrations.discord.games_list.channel_id"
|
||||
] = integration_discord_games_list_channel_id
|
||||
|
||||
async with mongo_manager.connect() as client:
|
||||
db = client.get_default_database()
|
||||
collection = db[cls.COLLECTION_NAME]
|
||||
|
||||
async with cls.connect() as collection:
|
||||
doc = await collection.find_one(filters)
|
||||
if doc is None:
|
||||
return None
|
||||
@@ -46,9 +40,6 @@ class StreamerConfigRepository:
|
||||
|
||||
@classmethod
|
||||
async def all(cls) -> list[StreamerConfig]:
|
||||
async with mongo_manager.connect() as client:
|
||||
db = client.get_default_database()
|
||||
collection = db[cls.COLLECTION_NAME]
|
||||
|
||||
async with cls.connect() as collection:
|
||||
cursor = await collection.find()
|
||||
return [StreamerConfig(**doc) async for doc in cursor]
|
||||
|
||||
29
src/repositories/users.py
Normal file
29
src/repositories/users.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from domain.users import CreateUser, User
|
||||
|
||||
from .base import BaseRepository
|
||||
|
||||
|
||||
class UserRepository(BaseRepository):
|
||||
COLLECTION_NAME = "users"
|
||||
|
||||
@classmethod
|
||||
async def get_or_create_user(cls, newUser: CreateUser) -> User:
|
||||
filter_data = {}
|
||||
|
||||
for provider, data in newUser.oauths.items():
|
||||
filter_data[f"oauths.{provider}.id"] = data.id
|
||||
|
||||
async with cls.connect() as collection:
|
||||
await collection.update_one(
|
||||
filter_data,
|
||||
{
|
||||
"$setOnInsert": {
|
||||
**newUser.model_dump(),
|
||||
}
|
||||
},
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
user = await collection.find_one(filter_data)
|
||||
|
||||
return User(**user)
|
||||
Reference in New Issue
Block a user