diff --git a/src/modules/stream_notifications/messages_proc.py b/src/modules/stream_notifications/messages_proc.py index dc03679..de36b30 100644 --- a/src/modules/stream_notifications/messages_proc.py +++ b/src/modules/stream_notifications/messages_proc.py @@ -93,25 +93,26 @@ class MessageEvent(BaseModel): -async def get_completion(message: str, reply_to: str | None = None) -> str: - logger.info(f"Getting completion for message: {message} with reply to: {reply_to}") +async def get_completion(messages: list[dict]) -> str: + logger.info(f"Getting completion for message: {messages}") - messages = [ + data_messages = [ + *( + { + "role": "assistant" if message["user"] == "kurbezz" else "user", + "content": message["text"] + } + for message in messages + ), { "role": "system", "content": "Don't use markdown! Don't use blocked words on Twitch! Make answers short and clear!" } ] - if reply_to: - messages.append({ - "role": "assistant", - "content": reply_to - }) - messages.append({ "role": "user", - "content": message + "content": data_messages }) async with AsyncClient() as client: @@ -140,10 +141,46 @@ class MessagesProc: "kurbezz", ] + MESSAGE_LIMIT = 1000 + MESSAGE_HISTORY = [] + + @classmethod + def update_message_history(cls, id: str, text: str, user: str, reply_to: str | None = None): + cls.MESSAGE_HISTORY.append({ + "id": id, + "text": text, + "user": user, + "reply_to": reply_to + }) + + if len(cls.MESSAGE_HISTORY) > cls.MESSAGE_LIMIT: + cls.MESSAGE_HISTORY = cls.MESSAGE_HISTORY[-cls.MESSAGE_LIMIT:] + + @classmethod + def get_message_history_with_thread(cls, thread_id: str | None, current_deep: int = 5) -> list[dict]: + if thread_id is None: + return [] + + if current_deep > 5: + return [] + + message = next((msg for msg in cls.MESSAGE_HISTORY if msg["id"] == thread_id), None) + if message is None: + return [] + + return cls.get_message_history_with_thread(message["reply_to"], current_deep + 1) + [message] + @classmethod async def on_message(cls, event: MessageEvent): logging.info(f"Received message: {event}") + cls.update_message_history( + id=event.message_id, + text=event.message.text, + user=event.chatter_user_login, + reply_to=event.reply.parent_message_id if event.reply is not None else None + ) + if event.chatter_user_name == "pahangor": return @@ -164,10 +201,8 @@ class MessagesProc: event.reply and event.reply.parent_message_body.lower().startswith("!ai") ): try: - completion = await get_completion( - event.message.text.replace("!ai", "").strip(), - reply_to=event.reply.parent_message_body.replace("!ai", "").strip() if event.reply is not None else None - ) + messages = cls.get_message_history_with_thread(event.message_id) + completion = await get_completion(messages) max_length = 255 completion_parts = [completion[i:i + max_length] for i in range(0, len(completion), max_length)] @@ -197,6 +232,7 @@ class MessagesProc: "булат" in event.message.text.lower()): try: + messages = cls.get_message_history_with_thread(event.message_id) completion = await get_completion(event.message.text) max_length = 255