diff --git a/src/modules/stream_notifications/messages_proc.py b/src/modules/stream_notifications/messages_proc.py index 0fd7d5c..f27425a 100644 --- a/src/modules/stream_notifications/messages_proc.py +++ b/src/modules/stream_notifications/messages_proc.py @@ -151,10 +151,10 @@ class MessagesProc: cls.MESSAGE_HISTORY = cls.MESSAGE_HISTORY[-cls.MESSAGE_LIMIT:] @classmethod - def get_message_history_with_thread(cls, message_id: str) -> list[dict]: + def get_message_history_with_thread(cls, message_id: str, thread_id: str | None = None) -> list[dict]: logger.info(f"HISTORY: {cls.MESSAGE_HISTORY}") - return [m for m in cls.MESSAGE_HISTORY if m["thread_id"] == message_id or m["id"] == message_id] + return [m for m in cls.MESSAGE_HISTORY if m["thread_id"] == thread_id or m["id"] == message_id] @classmethod async def on_message(cls, event: MessageEvent): @@ -185,7 +185,10 @@ class MessagesProc: if event.message.text.lower().startswith("!ai"): try: - messages = cls.get_message_history_with_thread(event.message_id) + messages = cls.get_message_history_with_thread( + event.message_id, + thread_id=event.reply.thread_message_id if event.reply is not None else None + ) completion = await get_completion(messages) max_length = 255 @@ -216,7 +219,10 @@ class MessagesProc: "булат" in event.message.text.lower()): try: - messages = cls.get_message_history_with_thread(event.message_id) + messages = cls.get_message_history_with_thread( + event.message_id, + thread_id=event.reply.thread_message_id if event.reply is not None else None + ) completion = await get_completion(messages) max_length = 255