This commit is contained in:
2025-02-18 23:42:57 +01:00
parent f9c1f7e77b
commit aec3939816

View File

@@ -93,25 +93,26 @@ class MessageEvent(BaseModel):
async def get_completion(message: str, reply_to: str | None = None) -> str: async def get_completion(messages: list[dict]) -> str:
logger.info(f"Getting completion for message: {message} with reply to: {reply_to}") logger.info(f"Getting completion for message: {messages}")
messages = [ data_messages = [
*(
{
"role": "assistant" if message["user"] == "kurbezz" else "user",
"content": message["text"]
}
for message in messages
),
{ {
"role": "system", "role": "system",
"content": "Don't use markdown! Don't use blocked words on Twitch! Make answers short and clear!" "content": "Don't use markdown! Don't use blocked words on Twitch! Make answers short and clear!"
} }
] ]
if reply_to:
messages.append({
"role": "assistant",
"content": reply_to
})
messages.append({ messages.append({
"role": "user", "role": "user",
"content": message "content": data_messages
}) })
async with AsyncClient() as client: async with AsyncClient() as client:
@@ -140,10 +141,46 @@ class MessagesProc:
"kurbezz", "kurbezz",
] ]
MESSAGE_LIMIT = 1000
MESSAGE_HISTORY = []
@classmethod
def update_message_history(cls, id: str, text: str, user: str, reply_to: str | None = None):
cls.MESSAGE_HISTORY.append({
"id": id,
"text": text,
"user": user,
"reply_to": reply_to
})
if len(cls.MESSAGE_HISTORY) > cls.MESSAGE_LIMIT:
cls.MESSAGE_HISTORY = cls.MESSAGE_HISTORY[-cls.MESSAGE_LIMIT:]
@classmethod
def get_message_history_with_thread(cls, thread_id: str | None, current_deep: int = 5) -> list[dict]:
if thread_id is None:
return []
if current_deep > 5:
return []
message = next((msg for msg in cls.MESSAGE_HISTORY if msg["id"] == thread_id), None)
if message is None:
return []
return cls.get_message_history_with_thread(message["reply_to"], current_deep + 1) + [message]
@classmethod @classmethod
async def on_message(cls, event: MessageEvent): async def on_message(cls, event: MessageEvent):
logging.info(f"Received message: {event}") logging.info(f"Received message: {event}")
cls.update_message_history(
id=event.message_id,
text=event.message.text,
user=event.chatter_user_login,
reply_to=event.reply.parent_message_id if event.reply is not None else None
)
if event.chatter_user_name == "pahangor": if event.chatter_user_name == "pahangor":
return return
@@ -164,10 +201,8 @@ class MessagesProc:
event.reply and event.reply.parent_message_body.lower().startswith("!ai") event.reply and event.reply.parent_message_body.lower().startswith("!ai")
): ):
try: try:
completion = await get_completion( messages = cls.get_message_history_with_thread(event.message_id)
event.message.text.replace("!ai", "").strip(), completion = await get_completion(messages)
reply_to=event.reply.parent_message_body.replace("!ai", "").strip() if event.reply is not None else None
)
max_length = 255 max_length = 255
completion_parts = [completion[i:i + max_length] for i in range(0, len(completion), max_length)] completion_parts = [completion[i:i + max_length] for i in range(0, len(completion), max_length)]
@@ -197,6 +232,7 @@ class MessagesProc:
"булат" in event.message.text.lower()): "булат" in event.message.text.lower()):
try: try:
messages = cls.get_message_history_with_thread(event.message_id)
completion = await get_completion(event.message.text) completion = await get_completion(event.message.text)
max_length = 255 max_length = 255