from collections import defaultdict
import baba
import gym
import numpy as np
from baba.world_object import name_mapping
from PIL import Image
# Text names of the underlying environment's actions, in enum iteration order.
# BabaIsAIWrapper.step() turns a text action back into an int via its position
# in (a copy of) this list.
BABAISAI_ACTION_SPACE = [member.name for member in baba.grid.BabaIsYouEnv.Actions]
class BabaIsAIWrapper(gym.Wrapper):
    """Gym wrapper exposing a BabaIsAI environment through text actions and
    text/image observations.

    Actions are exchanged as strings from ``language_action_space``.
    Observations are dicts holding a text prompt (optionally the active rules,
    plus every relevant object's position relative to the controllable object)
    and an RGB rendering of the grid.
    """

    def __init__(self, env: gym.Env, add_ruleset=True, vlm=False):
        """Wrap ``env``.

        Args:
            env: The BabaIsAI environment to wrap.
            add_ruleset: If true, the active rules are prepended to the text
                prompt built by ``textworld_process_obsv``.
            vlm: Accepted for interface compatibility; not used by this class.
        """
        super().__init__(env)
        self.add_ruleset = add_ruleset
        # Shallow copy so per-instance changes never touch the module-level list.
        self.language_action_space = BABAISAI_ACTION_SPACE[:]
        self.progression = 0.0
        self.target_plan = None

    @property
    def default_action(self):
        """The first text action in the action space."""
        return self.language_action_space[0]

    def get_text_action(self, action):
        """Return the text name of ``action``.

        ``action.value`` is used as an index into ``language_action_space``;
        presumably ``action`` is a ``BabaIsYouEnv.Actions`` member -- confirm
        against callers.
        """
        return self.language_action_space[action.value]

    def get_ruleset(self):
        """Return the active rules, one ``"<object> is <property>"`` per line.

        Rules are read from ``self.env.grid._ruleset["_rule_"]``. The leading
        ``f`` of object names is stripped, and properties are translated
        through ``name_mapping``.
        """
        rules = []
        for rule in self.env.grid._ruleset["_rule_"]:
            # All objects start with f, e.g. `fwall`, `fkey`... these are objects
            # that can be manipulated; `wall` is used to indicate end of map.
            name = rule["object"].removeprefix("f")
            named_property = name_mapping[rule["property"]]
            rules.append(f"{name} is {named_property}")
        return "\n".join(rules)

    def _find_objects(self, object_types):
        """Return ``((x, y), label)`` for every grid cell whose type is in ``object_types``.

        Cells are scanned row by row (``y`` outer, ``x`` inner). Rule blocks are
        labelled ``rule `<word>```; other objects are labelled with their type.
        """
        found = []
        for j in range(self.env.height):
            for i in range(self.env.width):
                cell = self.env.grid.get(i, j)
                if cell is None or cell.type not in object_types:
                    continue
                if cell.type == "rule_object":
                    label = f"rule `{cell.name}`"
                elif cell.type == "rule_is":
                    label = f"rule `{name_mapping[cell.name]}`"
                elif cell.type == "rule_property":
                    label = f"rule `{name_mapping[cell.property]}`"
                else:
                    label = cell.type
                found.append(((i, j), label))
        return found

    @staticmethod
    def _describe_offset(dx, dy, name):
        """Render an offset as text, e.g. ``"key 1 step to the left and 3 steps up"``.

        Positive ``dx`` means right and positive ``dy`` means down. Returns ``""``
        when both offsets are zero.
        """

        def count(n):
            # n is a non-negative magnitude; only exactly 1 is singular.
            return f"{n} {'steps' if n > 1 else 'step'}"

        parts = []
        if dx > 0:
            parts.append(f"{count(dx)} to the right")
        elif dx < 0:
            parts.append(f"{count(-dx)} to the left")
        if dy > 0:
            parts.append(f"{count(dy)} down")
        elif dy < 0:
            parts.append(f"{count(-dy)} up")
        if not parts:
            return ""
        return f"{name} " + " and ".join(parts)

    def get_text_observation(self, obs):
        """Describe where objects lie relative to the controllable object.

        The controllable object type is taken from the last rule whose property
        maps to ``"you"``. Balls, walls, doors, keys and rule blocks are each
        reported on their own line as an offset from it, e.g.
        ``"key 2 steps to the right and 1 step up"``. Objects on the same cell
        as the controllable object (offset ``(0, 0)``) are omitted.

        Args:
            obs: Unused; the description is computed from ``self.env.grid``.

        Returns:
            A newline-separated string with one line per described object.

        Raises:
            IndexError: If no controllable object is found on the grid
                (including when no rule maps to ``"you"``).
        """
        you = None
        for rule in self.env.grid._ruleset["_rule_"]:
            if name_mapping[rule["property"]] == "you":
                you = rule["object"]
        # TODO: handle multiple controllable objects; currently the last "you"
        # rule and the first matching cell in scan order are used.
        (my_x, my_y), _ = self._find_objects([you])[0]
        others = self._find_objects(
            [
                "fball",
                "fwall",
                "fdoor",
                "fkey",
                "rule_object",
                "rule_is",
                "rule_property",
            ]
        )
        descriptions = []
        for (x, y), label in others:
            line = self._describe_offset(x - my_x, y - my_y, label.removeprefix("f"))
            if line:  # zero offset: shares the player's cell, nothing to say
                descriptions.append(line)
        return "\n".join(descriptions)

    def textworld_process_obsv(self, textworld_obsv):
        """Build the wrapped observation dict.

        Returns a ``defaultdict`` (missing keys yield ``None``) with:
            ``"text"``: ``{"long_term_context": prompt, "short_term_context": ""}``,
                where ``prompt`` holds the active rules (if ``add_ruleset``)
                and the object description.
            ``"image"``: the environment's ``rgb_array`` render as an RGB PIL image.
        """
        ruleset = self.get_ruleset()
        text_observation = self.get_text_observation(textworld_obsv)
        prompt = ""
        if self.add_ruleset:
            prompt += f"Active rules:\n{ruleset}\n\n"
        prompt += f"Objects on the map:\n{text_observation}"
        obs = defaultdict(lambda: None)
        obs["text"] = {"long_term_context": prompt, "short_term_context": ""}
        image = Image.fromarray(self.env.render(mode="rgb_array")).convert("RGB")
        obs["image"] = image
        return obs

    def reset(self, **kwargs):
        """Reset the environment, record its ``target_plan``, zero progression,
        and return the wrapped observation."""
        obs = self.env.reset(**kwargs)
        self.target_plan = self.env.target_plan
        self.progression = 0.0
        return self.textworld_process_obsv(obs)

    def step(self, action):
        """Step with a text action.

        ``action`` must be an entry of ``language_action_space``; otherwise
        ``list.index`` raises ``ValueError``. Progression becomes 1.0 when the
        episode ends with ``env.is_win`` set.

        Returns:
            ``(wrapped_obs, reward, done, info)``.
        """
        action_int = self.language_action_space.index(action)
        obs, reward, done, info = self.env.step(action_int)
        if done and self.env.is_win:
            self.progression = 1.0
        return self.textworld_process_obsv(obs), reward, done, info

    def get_stats(self):
        """Return the episode's target plan and progression."""
        return {"target_plan": self.target_plan, "progression": self.progression}