-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathavatar_agent.py
111 lines (93 loc) · 4.03 KB
/
avatar_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python3
import os
import sys
import json
import argparse
from typing import Optional, Dict, Any
from datetime import datetime
from llm_client import LlmClient
from prompts import get_system_prompt, get_user_prompt
from image_generator import ImageGenerator
class ConversationMemory:
    """Bounded, in-memory log of the exchanges in one session.

    Every exchange is stored in ``history`` as a dict holding a
    ``timestamp`` (``datetime.now().isoformat()``), the ``agent_response``
    and, only when the user actually supplied text, the ``user_input``.
    """

    def __init__(self, max_history: int = 10):
        """Create an empty memory.

        Args:
            max_history: Maximum number of exchanges retained. A value of
                zero or less means nothing is retained.
        """
        self.max_history = max_history
        self.history = []

    def add_interaction(self, user_input: str, agent_response: str) -> None:
        """Append one exchange, then discard the oldest beyond the limit.

        Args:
            user_input: The user's text. An empty string marks an exchange
                with no user turn; the ``user_input`` key is then omitted.
            agent_response: The reply recorded for this exchange.
        """
        interaction = {"timestamp": datetime.now().isoformat()}
        if user_input != "":
            interaction["user_input"] = user_input
        interaction["agent_response"] = agent_response
        self.history.append(interaction)

        if len(self.history) > self.max_history:
            keep = self.max_history
            # Slicing with ``[-keep:]`` is only correct for keep > 0:
            # ``[-0:]`` is the whole list (nothing would ever be trimmed)
            # and a negative keep would drop entries from the wrong end.
            self.history = self.history[-keep:] if keep > 0 else []

    def get_formatted_history(self, max_entries: Optional[int] = None) -> str:
        """Render the most recent exchanges as a plain-text transcript.

        Args:
            max_entries: How many of the newest exchanges to include.
                ``None`` or ``0`` falls back to ``max_history``.

        Returns:
            ``"User: ..."`` / ``"Assistant: ..."`` lines, exchanges separated
            by a blank line, with surrounding whitespace stripped. The
            ``User:`` line is skipped for exchanges without user input. An
            empty string is returned when there is nothing to show.
        """
        limit = max_entries or self.max_history
        if limit <= 0:
            # Same slicing pitfall as in add_interaction: a non-positive
            # limit cannot be expressed as a "last N" slice.
            return ""

        chunks = []
        for entry in self.history[-limit:]:
            if entry.get("user_input", "") != "":
                chunks.append(f"User: {entry['user_input']}\n")
            chunks.append(f"Assistant: {entry['agent_response']}\n\n")
        return "".join(chunks).strip()
class AvatarGenerator:
    """Interactive avatar session: LLM call, image render, feedback loop.

    For each round the class asks ``LlmClient`` for a description, hands
    that description to ``ImageGenerator`` together with a fresh output
    path, and records the exchange in a ``ConversationMemory`` so later
    rounds can include the conversation so far in the system prompt.
    """

    # Number of output paths handed out so far; becomes the file-name suffix.
    # The class-level value is only the starting default: the first
    # ``self.counter += 1`` creates a per-instance attribute.
    counter: int = 0
    # Nickname used in output file names; set by ``gen``.
    nickname: str = ''

    def __init__(self, config: Dict[str, Any]):
        """Create the memory, LLM client and image generator.

        Args:
            config: Settings object passed as-is to ``LlmClient`` and
                ``ImageGenerator``.
        """
        self.config = config
        self.memory = ConversationMemory()
        self.llm_client = LlmClient(self.config)
        self.image_generator = ImageGenerator(config)

    def image_path(self) -> str:
        """Return the next output path and advance the counter.

        Returns:
            ``output/<nickname>_<counter>.png`` for the current counter
            value; the counter is incremented afterwards.
        """
        path: str = f'output/{self.nickname}_{self.counter}.png'
        self.counter += 1
        return path

    def _build_system_prompt(self) -> str:
        """Return the base system prompt plus the conversation so far, if any."""
        system_prompt = get_system_prompt()
        if self.memory.history:
            conversation_history = self.memory.get_formatted_history()
            system_prompt += f"\n\n# Conversation History\n{conversation_history}"
        return system_prompt

    def _generate(self, nickname: str, feedback: str, *,
                  refinement: bool) -> None:
        """Run one describe-and-render round and record it in memory.

        Args:
            nickname: Nickname passed to the prompt and the image generator.
            feedback: User feedback for this round; empty on the first pass.
            refinement: ``False`` for the initial avatar, ``True`` for a
                feedback round. Only affects the console messages.

        Exits the process with status 1 if rendering or recording fails.
        Errors raised by the LLM call itself are not caught here.
        """
        avatar_description: str = self.llm_client.call(
            self._build_system_prompt(),
            get_user_prompt(nickname, feedback),
        )

        if refinement:
            print("Generating updated avatar image.")
            error_prefix = "Error generating updated avatar image"
        else:
            print(
                f"Generating avatar image for '{nickname}' using description...")
            error_prefix = "Error generating avatar image"

        try:
            self.image_generator.gen(avatar_description, nickname,
                                     self.image_path())
            self.memory.add_interaction(feedback, avatar_description)
        except Exception as e:
            # Broad catch is deliberate: this is the interactive session's
            # top-level boundary, so any failure is reported and ends it.
            print(f"{error_prefix}: {e}")
            sys.exit(1)

        if refinement:
            print("Avatar generated successfully!")

    def gen(self, nickname: str) -> None:
        """Generate an avatar for *nickname*, then refine it interactively.

        After the first image is produced the method reads feedback lines
        from standard input and regenerates the avatar for each one. The
        loop has no normal return: the session ends through ``sys.exit``
        (status 0 on ``exit``/``quit``/``q`` or end of input, status 1 on a
        generation error).

        Args:
            nickname: Nickname the avatar is generated for; also used in the
                output file names.
        """
        self.nickname = nickname
        self._generate(nickname, "", refinement=False)

        print(
            "\nEnter your feedback or requests to refine the avatar.\nType 'exit' or 'quit' to end the session.\n"
        )
        while True:
            try:
                user_input = input("> ")
            except EOFError:
                # Closed or exhausted stdin (Ctrl-D, piped input) ends the
                # session cleanly instead of surfacing a traceback.
                sys.exit(0)
            if user_input.strip().lower() in ('exit', 'quit', 'q'):
                sys.exit(0)
            self._generate(nickname, user_input, refinement=True)