This commit is contained in:
2025-02-18 23:42:57 +01:00
parent f9c1f7e77b
commit aec3939816

View File

@@ -93,25 +93,26 @@ class MessageEvent(BaseModel):
async def get_completion(message: str, reply_to: str | None = None) -> str:
logger.info(f"Getting completion for message: {message} with reply to: {reply_to}")
async def get_completion(messages: list[dict]) -> str:
logger.info(f"Getting completion for message: {messages}")
messages = [
data_messages = [
*(
{
"role": "assistant" if message["user"] == "kurbezz" else "user",
"content": message["text"]
}
for message in messages
),
{
"role": "system",
"content": "Don't use markdown! Don't use blocked words on Twitch! Make answers short and clear!"
}
]
if reply_to:
messages.append({
"role": "assistant",
"content": reply_to
})
messages.append({
"role": "user",
"content": message
"content": data_messages
})
async with AsyncClient() as client:
@@ -140,10 +141,46 @@ class MessagesProc:
"kurbezz",
]
MESSAGE_LIMIT = 1000
MESSAGE_HISTORY = []
@classmethod
def update_message_history(cls, id: str, text: str, user: str, reply_to: str | None = None):
cls.MESSAGE_HISTORY.append({
"id": id,
"text": text,
"user": user,
"reply_to": reply_to
})
if len(cls.MESSAGE_HISTORY) > cls.MESSAGE_LIMIT:
cls.MESSAGE_HISTORY = cls.MESSAGE_HISTORY[-cls.MESSAGE_LIMIT:]
@classmethod
def get_message_history_with_thread(cls, thread_id: str | None, current_deep: int = 5) -> list[dict]:
if thread_id is None:
return []
if current_deep > 5:
return []
message = next((msg for msg in cls.MESSAGE_HISTORY if msg["id"] == thread_id), None)
if message is None:
return []
return cls.get_message_history_with_thread(message["reply_to"], current_deep + 1) + [message]
@classmethod
async def on_message(cls, event: MessageEvent):
logging.info(f"Received message: {event}")
cls.update_message_history(
id=event.message_id,
text=event.message.text,
user=event.chatter_user_login,
reply_to=event.reply.parent_message_id if event.reply is not None else None
)
if event.chatter_user_name == "pahangor":
return
@@ -164,10 +201,8 @@ class MessagesProc:
event.reply and event.reply.parent_message_body.lower().startswith("!ai")
):
try:
completion = await get_completion(
event.message.text.replace("!ai", "").strip(),
reply_to=event.reply.parent_message_body.replace("!ai", "").strip() if event.reply is not None else None
)
messages = cls.get_message_history_with_thread(event.message_id)
completion = await get_completion(messages)
max_length = 255
completion_parts = [completion[i:i + max_length] for i in range(0, len(completion), max_length)]
@@ -197,6 +232,7 @@ class MessagesProc:
"булат" in event.message.text.lower()):
try:
messages = cls.get_message_history_with_thread(event.message_id)
completion = await get_completion(event.message.text)
max_length = 255