from iclbench.agents.base import BaseAgent
from typing import List, Optional
[docs]
class Message:
    """A chat-style message made of a role, text content and an optional attachment."""

    def __init__(self, role: str, content: str, attachment: Optional[object] = None):
        """Store the message fields on the instance.

        Args:
            role: Speaker of the message: 'system', 'user' or 'assistant'.
            content: String content of the message.
            attachment: Optional extra payload kept alongside the text;
                defaults to ``None``.
        """
        self.role, self.content, self.attachment = role, content, attachment

    def __repr__(self):
        """Return a ``Message(role=..., content=..., attachment=...)`` debug string."""
        template = "Message(role={}, content={}, attachment={})"
        return template.format(self.role, self.content, self.attachment)
[docs]
class ICLAgent(BaseAgent):
    """Agent that prepends in-context-learning (ICL) demonstrations to its prompt.

    Demonstration observations and actions are recorded as events, grouped
    into episodes by ``wrap_episode`` and flattened into a message list by
    ``get_icl_prompt``. ``act`` sends that list (unless it was cached on the
    client) followed by the prompt builder's messages to the client.
    """

    def __init__(self, client_factory, prompt_builder):
        """Create the agent, its client and empty ICL buffers.

        Args:
            client_factory: Zero-argument callable; its return value is
                stored as ``self.client``.
            prompt_builder: Forwarded to ``BaseAgent.__init__``. This class
                later reads ``self.prompt_builder`` (presumably assigned by
                the base class -- TODO confirm).
        """
        super().__init__(client_factory, prompt_builder)
        self.client = client_factory()
        self.icl_episodes = []  # completed demonstration episodes (lists of Message)
        self.icl_events = []  # events of the demonstration episode being recorded
        self.cached_icl = False  # set to True by cache_icl()

    def update_icl_observation(self, obs: dict):
        """Record a demonstration observation event.

        Args:
            obs: Mapping whose ``"text"`` entry is itself a mapping; its
                ``"long_term_context"`` value (``""`` when absent) is stored.
        """
        long_term_context = obs["text"].get("long_term_context", "")
        self.icl_events.append(
            {
                "type": "icl_observation",
                "text": long_term_context,
            }
        )

    def update_icl_action(self, action: str):
        """Record a demonstration action event.

        Args:
            action: The action to store for the current demonstration step.
        """
        self.icl_events.append(
            {
                "type": "icl_action",
                "action": action,
            }
        )

    def cache_icl(self):
        """Pass the ICL prompt to ``client.cache_icl_demo`` and mark it as cached.

        Once ``cached_icl`` is set, ``act`` no longer prepends the
        demonstration messages itself.
        """
        self.client.cache_icl_demo(self.get_icl_prompt())
        self.cached_icl = True

    def wrap_episode(self):
        """Turn the recorded events into one demonstration episode.

        The events are converted to messages (observations as ``user``,
        actions as ``assistant``), framed by START/END marker messages,
        appended to ``self.icl_episodes``, and the event buffer is cleared
        so the next episode starts empty.

        Raises:
            ValueError: If an event has a type other than
                ``"icl_observation"`` or ``"icl_action"``. No state is
                modified in that case.
        """
        episode_number = len(self.icl_episodes) + 1
        icl_episode = [
            Message(
                role="user",
                content=f"****** START OF DEMONSTRATION EPISODE {episode_number} ******",
            )
        ]
        for event in self.icl_events:
            event_type = event["type"]
            if event_type == "icl_observation":
                message = Message(role="user", content="Observation:\n" + event["text"])
            elif event_type == "icl_action":
                message = Message(role="assistant", content=event["action"])
            else:
                # Previously this fell through and either hit an unbound local
                # or silently re-appended the previous message.
                raise ValueError(f"Unknown ICL event type: {event_type!r}")
            icl_episode.append(message)
        icl_episode.append(
            Message(
                role="user",
                content=f"****** END OF DEMONSTRATION EPISODE {episode_number} ******",
            )
        )
        self.icl_episodes.append(icl_episode)
        # Start the next episode with a clean buffer; otherwise every later
        # episode would repeat the events of all earlier ones.
        self.icl_events = []

    def get_icl_prompt(self) -> List[Message]:
        """Build the flat list of demonstration messages.

        Returns:
            A new list holding, in order: the prompt builder's system prompt
            with every ``"PLAY"`` substring replaced by the "observe the
            demonstrations" instruction (sent with role ``user``), the
            messages of every wrapped episode, and a closing "your turn"
            message.
        """
        icl_instruction = Message(
            role="user",
            content=self.prompt_builder.system_prompt.replace(
                "PLAY",
                "First, observe the demonstrations provided and learn from them!",
            ),
        )
        # Unroll the wrapped ICL episodes into a single message sequence.
        icl_messages = [icl_instruction]
        for icl_episode in self.icl_episodes:
            icl_messages.extend(icl_episode)
        icl_messages.append(
            Message(
                role="user",
                content="****** Now it's your turn to play the game! ******",
            )
        )
        return icl_messages

    def act(self, obs, prev_action=None):
        """Query the client for the next action.

        Args:
            obs: Current observation, forwarded to
                ``prompt_builder.update_observation``.
            prev_action: Previously taken action; when truthy it is forwarded
                to ``prompt_builder.update_action`` before the observation.

        Returns:
            Whatever ``self.client.generate`` returns for the assembled messages.
        """
        if prev_action:
            self.prompt_builder.update_action(prev_action)
        self.prompt_builder.update_observation(obs)

        # The demonstrations are only sent inline when they were not cached
        # on the client via cache_icl().
        messages = [] if self.cached_icl else self.get_icl_prompt()
        messages.extend(self.prompt_builder.get_prompt(icl_episodes=True))

        # Add naive instructions to the last user message.
        naive_instruction = (
            "You can only output one of the above actions at a time, "
            "and always have to output an action until the episode terminates.\n"
            "Action:"
        )
        if messages and messages[-1].role == "user":
            # NOTE(review): this mutates the message object in place; if the
            # prompt builder reuses its message objects across calls the
            # suffix would accumulate -- confirm against get_prompt().
            messages[-1].content += "\n\n" + naive_instruction

        response = self.client.generate(messages)
        return response