mirror of
https://github.com/kurbezz/discord-bot.git
synced 2026-03-03 14:00:46 +01:00
Update
This commit is contained in:
12
src/domain/auth.py
Normal file
12
src/domain/auth.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProvider(StrEnum):
|
||||||
|
TWITCH = "twitch"
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthData(BaseModel):
|
||||||
|
id: str
|
||||||
|
email: str | None
|
||||||
17
src/domain/users.py
Normal file
17
src/domain/users.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from domain.auth import OAuthProvider, OAuthData
|
||||||
|
|
||||||
|
|
||||||
|
class User(BaseModel):
|
||||||
|
id: str
|
||||||
|
|
||||||
|
oauths: dict[OAuthProvider, OAuthData]
|
||||||
|
|
||||||
|
is_admin: bool
|
||||||
|
|
||||||
|
|
||||||
|
class CreateUser(BaseModel):
|
||||||
|
oauths: dict[OAuthProvider, OAuthData]
|
||||||
|
|
||||||
|
is_admin: bool = False
|
||||||
5
src/modules/web_app/serializers/auth.py
Normal file
5
src/modules/web_app/serializers/auth.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class GetAuthorizationUrlResponse(BaseModel):
|
||||||
|
authorization_url: str
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
from core.config import config
|
||||||
|
|
||||||
|
from domain.auth import OAuthProvider
|
||||||
|
|
||||||
|
from .providers import get_client
|
||||||
|
|
||||||
|
|
||||||
|
REDIRECT_URI_TEMPLATE = f"https://{config.WEB_APP_HOST}/" + "auth/callback/{service}/"
|
||||||
|
|
||||||
|
|
||||||
|
async def get_authorization_url(provider: OAuthProvider) -> str:
|
||||||
|
client = get_client(provider)
|
||||||
|
|
||||||
|
return await client.get_authorization_url(
|
||||||
|
redirect_uri=REDIRECT_URI_TEMPLATE.format(
|
||||||
|
service=provider.value
|
||||||
|
),
|
||||||
|
)
|
||||||
16
src/modules/web_app/services/oauth/process_callback.py
Normal file
16
src/modules/web_app/services/oauth/process_callback.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from domain.auth import OAuthProvider
|
||||||
|
|
||||||
|
from .providers import get_client
|
||||||
|
from .authorization_url_getter import REDIRECT_URI_TEMPLATE
|
||||||
|
|
||||||
|
|
||||||
|
async def process_callback(provider: OAuthProvider, code: str) -> tuple[str, str | None]:
|
||||||
|
client = get_client(provider)
|
||||||
|
token = await client.get_access_token(
|
||||||
|
code,
|
||||||
|
redirect_uri=REDIRECT_URI_TEMPLATE.format(service=provider.value),
|
||||||
|
)
|
||||||
|
|
||||||
|
user_data = await client.get_id_email(token["access_token"])
|
||||||
|
|
||||||
|
return user_data
|
||||||
10
src/modules/web_app/services/oauth/providers/__init__.py
Normal file
10
src/modules/web_app/services/oauth/providers/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from .enums import OAuthProvider
|
||||||
|
from .twitch import twitch_oauth_client
|
||||||
|
from .getter import get_client
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"OAuthProvider",
|
||||||
|
"twitch_oauth_client",
|
||||||
|
"get_client"
|
||||||
|
]
|
||||||
11
src/modules/web_app/services/oauth/providers/getter.py
Normal file
11
src/modules/web_app/services/oauth/providers/getter.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from httpx_oauth.oauth2 import OAuth2
|
||||||
|
|
||||||
|
from domain.auth import OAuthProvider
|
||||||
|
from .twitch import twitch_oauth_client
|
||||||
|
|
||||||
|
|
||||||
|
def get_client(provider: OAuthProvider) -> OAuth2:
|
||||||
|
if provider == OAuthProvider.TWITCH:
|
||||||
|
return twitch_oauth_client
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Provider is not implemented")
|
||||||
34
src/modules/web_app/services/oauth/providers/twitch.py
Normal file
34
src/modules/web_app/services/oauth/providers/twitch.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from twitchAPI.twitch import Twitch, AuthScope
|
||||||
|
from twitchAPI.helper import first
|
||||||
|
|
||||||
|
from httpx_oauth.oauth2 import OAuth2
|
||||||
|
|
||||||
|
from core.config import config
|
||||||
|
|
||||||
|
|
||||||
|
class TwithOAuth2(OAuth2):
|
||||||
|
async def get_id_email(self, token: str):
|
||||||
|
twitch_client = Twitch(config.TWITCH_CLIENT_ID, config.TWITCH_CLIENT_SECRET)
|
||||||
|
twitch_client.auto_refresh_auth = False
|
||||||
|
|
||||||
|
await twitch_client.set_user_authentication(
|
||||||
|
token,
|
||||||
|
[AuthScope.USER_READ_EMAIL],
|
||||||
|
validate=True
|
||||||
|
)
|
||||||
|
|
||||||
|
me = await first(twitch_client.get_users())
|
||||||
|
|
||||||
|
if me is None:
|
||||||
|
raise Exception("Failed to get user data")
|
||||||
|
|
||||||
|
return me.id, me.email
|
||||||
|
|
||||||
|
|
||||||
|
twitch_oauth_client = TwithOAuth2(
|
||||||
|
config.TWITCH_CLIENT_ID,
|
||||||
|
config.TWITCH_CLIENT_SECRET,
|
||||||
|
"https://id.twitch.tv/oauth2/authorize",
|
||||||
|
"https://id.twitch.tv/oauth2/token",
|
||||||
|
base_scopes=[AuthScope.USER_READ_EMAIL.value],
|
||||||
|
)
|
||||||
@@ -1,68 +1,32 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from twitchAPI.twitch import Twitch, AuthScope
|
from domain.auth import OAuthProvider, OAuthData
|
||||||
from twitchAPI.helper import first
|
from domain.users import CreateUser
|
||||||
from httpx_oauth.oauth2 import OAuth2
|
from modules.web_app.services.oauth.process_callback import process_callback
|
||||||
|
from modules.web_app.services.oauth.authorization_url_getter import get_authorization_url as gen_auth_link
|
||||||
from core.config import config
|
from modules.web_app.serializers.auth import GetAuthorizationUrlResponse
|
||||||
|
from repositories.users import UserRepository
|
||||||
|
|
||||||
|
|
||||||
auth_router = APIRouter(prefix="/auth", tags=["auth"])
|
auth_router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
class TwithOAuth2(OAuth2):
|
@auth_router.get("/get_authorization_url/{provider}/")
|
||||||
async def get_id_email(self, token: str):
|
async def get_authorization_url(provider: OAuthProvider) -> GetAuthorizationUrlResponse:
|
||||||
twitch_client = Twitch(config.TWITCH_CLIENT_ID, config.TWITCH_CLIENT_SECRET)
|
link = await gen_auth_link(provider)
|
||||||
twitch_client.auto_refresh_auth = False
|
|
||||||
|
|
||||||
await twitch_client.set_user_authentication(
|
return GetAuthorizationUrlResponse(authorization_url=link)
|
||||||
token,
|
|
||||||
[AuthScope.USER_READ_EMAIL],
|
|
||||||
validate=True
|
@auth_router.get("/callback/{provider}/")
|
||||||
|
async def callback(provider: OAuthProvider, code: str):
|
||||||
|
user_data = await process_callback(provider, code)
|
||||||
|
|
||||||
|
user = await UserRepository.get_or_create_user(
|
||||||
|
CreateUser(
|
||||||
|
oauths={provider: OAuthData(id=user_data[0], email=user_data[1])},
|
||||||
|
is_admin=False,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
me = await first(twitch_client.get_users())
|
return {"user": user.model_dump()}
|
||||||
|
|
||||||
if me is None:
|
|
||||||
raise Exception("Failed to get user data")
|
|
||||||
|
|
||||||
return me.id, me.email
|
|
||||||
|
|
||||||
|
|
||||||
twitch_oauth = TwithOAuth2(
|
|
||||||
config.TWITCH_CLIENT_ID,
|
|
||||||
config.TWITCH_CLIENT_SECRET,
|
|
||||||
"https://id.twitch.tv/oauth2/authorize",
|
|
||||||
"https://id.twitch.tv/oauth2/token",
|
|
||||||
base_scopes=[AuthScope.USER_READ_EMAIL.value],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
REDIRECT_URI_TEMPLATE = f"https://{config.WEB_APP_HOST}/" + "auth/callback/{service}/"
|
|
||||||
|
|
||||||
|
|
||||||
@auth_router.get("/get_authorization_url/{service}/")
|
|
||||||
async def get_authorization_url(service: str):
|
|
||||||
link = None
|
|
||||||
|
|
||||||
if service == "twitch":
|
|
||||||
link = await twitch_oauth.get_authorization_url(
|
|
||||||
redirect_uri=REDIRECT_URI_TEMPLATE.format(service="twitch"),
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"link": link}
|
|
||||||
|
|
||||||
|
|
||||||
@auth_router.get("/callback/{service}/")
|
|
||||||
async def callback(service: str, code: str):
|
|
||||||
user_data = None
|
|
||||||
|
|
||||||
if service == "twitch":
|
|
||||||
token = await twitch_oauth.get_access_token(
|
|
||||||
code,
|
|
||||||
redirect_uri=REDIRECT_URI_TEMPLATE.format(service="twitch"),
|
|
||||||
)
|
|
||||||
|
|
||||||
user_data = await twitch_oauth.get_id_email(token["access_token"])
|
|
||||||
|
|
||||||
return {"user_data": user_data}
|
|
||||||
|
|||||||
18
src/repositories/base.py
Normal file
18
src/repositories/base.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
import abc
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from core.mongo import mongo_manager
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRepository(abc.ABC):
|
||||||
|
COLLECTION_NAME: str
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
@classmethod
|
||||||
|
async def connect(cls):
|
||||||
|
async with mongo_manager.connect() as client:
|
||||||
|
db = client.get_default_database()
|
||||||
|
collection = db[cls.COLLECTION_NAME]
|
||||||
|
|
||||||
|
yield collection
|
||||||
@@ -1,17 +1,14 @@
|
|||||||
from domain.streamers import StreamerConfig
|
from domain.streamers import StreamerConfig
|
||||||
|
|
||||||
from core.mongo import mongo_manager
|
from .base import BaseRepository
|
||||||
|
|
||||||
|
|
||||||
class StreamerConfigRepository:
|
class StreamerConfigRepository(BaseRepository):
|
||||||
COLLECTION_NAME = "streamers"
|
COLLECTION_NAME = "streamers"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_twitch_id(cls, twitch_id: int) -> StreamerConfig:
|
async def get_by_twitch_id(cls, twitch_id: int) -> StreamerConfig:
|
||||||
async with mongo_manager.connect() as client:
|
async with cls.connect() as collection:
|
||||||
db = client.get_default_database()
|
|
||||||
collection = db[cls.COLLECTION_NAME]
|
|
||||||
|
|
||||||
doc = await collection.find_one({"twitch.id": twitch_id})
|
doc = await collection.find_one({"twitch.id": twitch_id})
|
||||||
if doc is None:
|
if doc is None:
|
||||||
raise ValueError(f"Streamer with twitch id {twitch_id} not found")
|
raise ValueError(f"Streamer with twitch id {twitch_id} not found")
|
||||||
@@ -34,10 +31,7 @@ class StreamerConfigRepository:
|
|||||||
"integrations.discord.games_list.channel_id"
|
"integrations.discord.games_list.channel_id"
|
||||||
] = integration_discord_games_list_channel_id
|
] = integration_discord_games_list_channel_id
|
||||||
|
|
||||||
async with mongo_manager.connect() as client:
|
async with cls.connect() as collection:
|
||||||
db = client.get_default_database()
|
|
||||||
collection = db[cls.COLLECTION_NAME]
|
|
||||||
|
|
||||||
doc = await collection.find_one(filters)
|
doc = await collection.find_one(filters)
|
||||||
if doc is None:
|
if doc is None:
|
||||||
return None
|
return None
|
||||||
@@ -46,9 +40,6 @@ class StreamerConfigRepository:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def all(cls) -> list[StreamerConfig]:
|
async def all(cls) -> list[StreamerConfig]:
|
||||||
async with mongo_manager.connect() as client:
|
async with cls.connect() as collection:
|
||||||
db = client.get_default_database()
|
|
||||||
collection = db[cls.COLLECTION_NAME]
|
|
||||||
|
|
||||||
cursor = await collection.find()
|
cursor = await collection.find()
|
||||||
return [StreamerConfig(**doc) async for doc in cursor]
|
return [StreamerConfig(**doc) async for doc in cursor]
|
||||||
|
|||||||
29
src/repositories/users.py
Normal file
29
src/repositories/users.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from domain.users import CreateUser, User
|
||||||
|
|
||||||
|
from .base import BaseRepository
|
||||||
|
|
||||||
|
|
||||||
|
class UserRepository(BaseRepository):
|
||||||
|
COLLECTION_NAME = "users"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_or_create_user(cls, newUser: CreateUser) -> User:
|
||||||
|
filter_data = {}
|
||||||
|
|
||||||
|
for provider, data in newUser.oauths.items():
|
||||||
|
filter_data[f"oauths.{provider}.id"] = data.id
|
||||||
|
|
||||||
|
async with cls.connect() as collection:
|
||||||
|
await collection.update_one(
|
||||||
|
filter_data,
|
||||||
|
{
|
||||||
|
"$setOnInsert": {
|
||||||
|
**newUser.model_dump(),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
upsert=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
user = await collection.find_one(filter_data)
|
||||||
|
|
||||||
|
return User(**user)
|
||||||
Reference in New Issue
Block a user