import gym
# [docs]  (Sphinx "viewcode" link marker left over from HTML extraction; not part of the module)
class EnvWrapper(gym.Wrapper):
    """Wrap a gym environment and tag it with an environment and task name.

    The wrapper forwards ``reset``/``step`` to the wrapped environment,
    routes observations through ``_process_observation``, and exposes
    helpers for instruction prompts, text actions, action validation and
    statistics that delegate to the wrapped environment.

    Attributes:
        env_name: Name of the environment family; must be one of
            ``_SUPPORTED_ENVS`` for observation processing and prompt lookup.
        task_name: Name of the task, passed to some prompt builders.
        failed_candidates: Candidate actions rejected by
            ``check_action_validity``, in the order they were seen.
    """

    # Environment names accepted by _process_observation and
    # get_instruction_prompt.
    _SUPPORTED_ENVS = (
        "nle",
        "minihack",
        "babyai",
        "textworld",
        "babaisai",
        "crafter",
        "craftax",
    )

    def __init__(self, env, env_name, task_name):
        """Store the wrapped environment together with its identifying names.

        Args:
            env: The environment to wrap.
            env_name: Name of the environment family.
            task_name: Name of the task within that environment.
        """
        super().__init__(env)
        self.env_name = env_name
        self.task_name = task_name
        self.failed_candidates = []

    @property
    def max_steps(self):
        """Return ``max_steps`` of the wrapped environment."""
        return self.env.max_steps

    def reset(self):
        """Reset the wrapped environment and return the processed observation."""
        obs = self.env.reset()
        return self._process_observation(obs)

    def step(self, action):
        """Step the wrapped environment with ``action``.

        Returns:
            The tuple ``(processed_obs, reward, done, info)`` where the last
            three items are returned by the wrapped environment unchanged.
        """
        obs, reward, done, info = self.env.step(action)
        processed_obs = self._process_observation(obs)
        return processed_obs, reward, done, info

    def _process_observation(self, obs):
        """Return ``obs`` for a supported environment.

        Every supported environment currently passes its observation through
        unchanged; this method is the single place to add per-environment
        processing later.

        Raises:
            ValueError: If ``self.env_name`` is not a supported environment.
        """
        if self.env_name not in self._SUPPORTED_ENVS:
            raise ValueError(f"Unknown environment: {self.env_name}")
        return obs

    @property
    def actions(self):
        """Return the list of available actions.

        Uses the wrapped environment's ``actions`` attribute when it has one;
        otherwise returns ``[0, 1, ..., size - 1]`` derived from the action
        space.
        """
        if hasattr(self.env, "actions"):
            return self.env.actions
        space = self.env.action_space
        # Discrete spaces expose their size as ``n`` and do not implement
        # ``__len__``; fall back to ``len()`` for container-like spaces.
        size = getattr(space, "n", None)
        if size is None:
            size = len(space)
        return list(range(size))

    def get_text_action(self, action):
        """Delegate to the wrapped environment's ``get_text_action``."""
        return self.env.get_text_action(action)

    def get_instruction_prompt(self, instructions=None):
        """Build the instruction prompt for the current environment.

        The environment-specific module is imported lazily so that only the
        package for the selected environment needs to be importable.

        Args:
            instructions: Only used for ``"babyai"``, where it is forwarded
                as the ``mission`` argument.

        Returns:
            Whatever the environment-specific ``get_instruction_prompt``
            returns.

        Raises:
            ValueError: If ``self.env_name`` is not a supported environment.
        """
        if self.env_name == "nle":
            from iclbench.environments.nle import get_instruction_prompt

            return get_instruction_prompt()
        elif self.env_name == "minihack":
            from iclbench.environments.minihack import get_instruction_prompt

            return get_instruction_prompt(self.env, self.task_name)
        elif self.env_name == "babyai":
            from iclbench.environments.babyai_text import get_instruction_prompt

            return get_instruction_prompt(self.env, mission=instructions)
        elif self.env_name == "textworld":
            from iclbench.environments.textworld import get_instruction_prompt

            return get_instruction_prompt(self.env, self.task_name)
        elif self.env_name == "babaisai":
            from iclbench.environments.baba_is_ai import get_instruction_prompt

            return get_instruction_prompt(self.env, self.task_name)
        elif self.env_name == "crafter":
            from iclbench.environments.crafter import get_instruction_prompt

            return get_instruction_prompt(self.task_name)
        elif self.env_name == "craftax":
            from iclbench.environments.craftax import get_instruction_prompt

            return get_instruction_prompt(self.task_name)
        else:
            # Fixed: this previously read the misspelled ``self.env_namee``,
            # which raised AttributeError instead of the intended ValueError.
            raise ValueError(f"Unknown environment: {self.env_name}")

    def check_action_validity(self, candidate_action):
        """Return ``candidate_action`` if valid, else the default action.

        A candidate is valid when it is contained in the wrapped
        environment's ``language_action_space``. Invalid candidates are
        recorded in ``self.failed_candidates``.

        Args:
            candidate_action: The action proposed by the caller.

        Returns:
            ``candidate_action`` when valid, otherwise the wrapped
            environment's ``default_action``.
        """
        if candidate_action in self.env.language_action_space:
            return candidate_action
        fallback = self.env.default_action
        self.failed_candidates.append(candidate_action)
        return fallback

    def get_stats(self):
        """Delegate to the wrapped environment's ``get_stats``."""
        return self.env.get_stats()