# Source code for iclbench.environments.env_wrapper

import gym


class EnvWrapper(gym.Wrapper):
    """``gym.Wrapper`` that adds iclbench-specific helpers to an environment.

    Args:
        env: The environment to wrap. Several helpers delegate to attributes
            the wrapped env is expected to provide (``max_steps``,
            ``get_text_action``, ``language_action_space``,
            ``default_action``, ``get_stats``).
        env_name: Environment family. One of ``"nle"``, ``"minihack"``,
            ``"babyai"``, ``"textworld"``, ``"babaisai"``, ``"crafter"``,
            ``"craftax"``.
        task_name: Task identifier forwarded to the per-environment
            instruction-prompt builders.
    """

    # Environment families accepted by _process_observation.
    _KNOWN_ENVS = (
        "nle",
        "minihack",
        "babyai",
        "textworld",
        "babaisai",
        "crafter",
        "craftax",
    )

    def __init__(self, env, env_name, task_name):
        super().__init__(env)
        self.env_name = env_name
        self.task_name = task_name
        # Actions rejected by check_action_validity, in the order they were seen.
        self.failed_candidates = []

    @property
    def max_steps(self):
        """Maximum episode length, as reported by the wrapped environment."""
        return self.env.max_steps

    def reset(self, **kwargs):
        """Reset the wrapped env and return its processed initial observation.

        Keyword arguments (e.g. ``seed`` / ``options`` in newer gym versions)
        are forwarded unchanged to the wrapped env's ``reset``. Whatever the
        wrapped ``reset`` returns is passed through ``_process_observation``.
        """
        obs = self.env.reset(**kwargs)
        return self._process_observation(obs)

    def step(self, action):
        """Step the wrapped env and process the resulting observation.

        Returns:
            ``(obs, reward, done, info)``. Expects the wrapped ``step`` to
            return exactly four values (old-style gym API).
        """
        obs, reward, done, info = self.env.step(action)
        processed_obs = self._process_observation(obs)
        return processed_obs, reward, done, info

    def _process_observation(self, obs):
        """Return ``obs`` unchanged for every known environment family.

        This is a hook for per-environment observation processing. Currently
        every supported family passes observations through as-is.

        Raises:
            ValueError: If ``self.env_name`` is not a known environment.
        """
        if self.env_name not in self._KNOWN_ENVS:
            raise ValueError(f"Unknown environment: {self.env_name}")
        return obs

    @property
    def actions(self):
        """List of available actions.

        Uses the wrapped env's ``actions`` attribute when present. Otherwise
        it enumerates the action space. ``gym.spaces.Discrete`` has no
        ``len()``, so its ``n`` (and optional ``start`` offset) are used when
        available. Other spaces fall back to ``len()``.
        """
        if hasattr(self.env, "actions"):
            return self.env.actions
        space = self.env.action_space
        if hasattr(space, "n"):
            start = int(getattr(space, "start", 0))
            return list(range(start, start + int(space.n)))
        return list(range(len(space)))

    def get_text_action(self, action):
        """Delegate to the wrapped env's ``get_text_action``."""
        return self.env.get_text_action(action)

    def get_instruction_prompt(self, instructions=None):
        """Build the instruction prompt for the configured environment family.

        The per-environment builder is imported lazily, so only the selected
        environment's module needs to be importable.

        Args:
            instructions: Only used for ``"babyai"``, where it is passed as
                ``mission``.

        Raises:
            ValueError: If ``self.env_name`` is not a known environment.
        """
        if self.env_name == "nle":
            from iclbench.environments.nle import get_instruction_prompt

            return get_instruction_prompt()
        elif self.env_name == "minihack":
            from iclbench.environments.minihack import get_instruction_prompt

            return get_instruction_prompt(self.env, self.task_name)
        elif self.env_name == "babyai":
            from iclbench.environments.babyai_text import get_instruction_prompt

            return get_instruction_prompt(self.env, mission=instructions)
        elif self.env_name == "textworld":
            from iclbench.environments.textworld import get_instruction_prompt

            return get_instruction_prompt(self.env, self.task_name)
        elif self.env_name == "babaisai":
            from iclbench.environments.baba_is_ai import get_instruction_prompt

            return get_instruction_prompt(self.env, self.task_name)
        elif self.env_name == "crafter":
            from iclbench.environments.crafter import get_instruction_prompt

            return get_instruction_prompt(self.task_name)
        elif self.env_name == "craftax":
            from iclbench.environments.craftax import get_instruction_prompt

            return get_instruction_prompt(self.task_name)
        else:
            # Previously referenced the misspelled attribute ``env_namee``,
            # which raised AttributeError instead of this ValueError.
            raise ValueError(f"Unknown environment: {self.env_name}")

    def check_action_validity(self, candidate_action):
        """Return ``candidate_action`` if it is legal, else the default action.

        Legality is membership in the wrapped env's
        ``language_action_space``. Rejected candidates are appended to
        ``self.failed_candidates``.
        """
        if candidate_action in self.env.language_action_space:
            return candidate_action
        # Read the fallback before recording the failure (original ordering).
        fallback = self.env.default_action
        self.failed_candidates.append(candidate_action)
        return fallback

    def get_stats(self):
        """Delegate to the wrapped env's ``get_stats``."""
        return self.env.get_stats()