import gym
from PIL import Image
BABYAI_ACTION_SPACE = [
    "turn left",
    "turn right",
    "go forward",
    "pick up",
    "drop",
    "toggle",
]


class BabyAITextCleanLangWrapper(gym.Wrapper):
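    """Expose a BabyAI-Text environment through natural-language observations
    and actions: observations carry a textual scene description plus an RGB
    rendering of the grid, and actions are the strings in BABYAI_ACTION_SPACE.
    """
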
    def __init__(self, env, vlm=False, **kwargs):
        super().__init__(env)
        self.language_action_space = BABYAI_ACTION_SPACE[:]
        self._mission = None
        self.progression = 0.0
        # Token used when interleaving text and image inputs; None until set.
        self._interleaving_token = None
        # Render one RGB frame so the underlying renderer is initialized.
        self.renderer = self.env.render("rgb_array")

    @property
    def interleaving_token(self):
        return self._interleaving_token

    @property
    def default_action(self):
        return "go forward"

    def get_text_action(self, action):
        # Map a discrete action (e.g. an IntEnum member) to its text form.
        return self.language_action_space[action.value]

    def get_prompt(self, obs, infos):
        image = Image.fromarray(self.env.get_obs_render(obs["image"])).convert("RGB")

        def _form_prompt(description):
            return "\n".join([d.replace("You see ", "") for d in description])

        prompt = _form_prompt(infos["descriptions"])
        return prompt, image
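
    # Illustrative example (the exact strings in `infos["descriptions"]` come
    # from BabyAI-Text; these are assumed values for the sake of the example):
    #   ["You see a wall 3 steps forward", "You see a red ball 1 step left"]
    # becomes the prompt:
    #   "a wall 3 steps forward\na red ball 1 step left"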
    def reset(self):
        obs, infos = self.env.reset()
        prompt, image = self.get_prompt(obs, infos)
        self._mission = obs["mission"]
        # Following the convention from the NetHack Language Wrapper for
        # specifying short-term vs. long-term context. BabyAI-Text has no
        # equivalent long-term context such as an inventory, so it stays empty.
        obs["text"] = {"long_term_context": prompt, "short_term_context": ""}
        obs["image"] = image
        return obs

    def step(self, action):
        # Translate the text action back to its integer index; raises
        # ValueError if the action is not in the language action space.
        action_int = self.language_action_space.index(action)
        obs, reward, done, infos = self.env.step(action_int)
        # BabyAI only gives a positive reward on task success.
        if reward > 0:
            self.progression = 1.0
        prompt, image = self.get_prompt(obs, infos)
        obs["text"] = {"long_term_context": prompt, "short_term_context": ""}
        obs["image"] = image
        return obs, reward, done, infos

    def get_stats(self):
        # No special stats tracking implemented for now.
        return {"mission": self._mission, "progression": self.progression}