# Environment factory: `make_env` instantiates each supported environment and wraps it
# with the matching language/task wrappers.
import random
import gym
import numpy as np
from gym import spaces
from iclbench.environments.env_wrapper import EnvWrapper
[docs]
def make_env(env_name, task, config):
"""
Creates and initializes an environment based on the specified environment name and task.
This function supports multiple environment types, each with its own configuration
requirements. The function will return a wrapped environment suitable for use with
agents in the context of the ICLBench framework.
Args:
env_name (str): The name of the environment to create. Supported values include:
- "nle"
- "minihack"
- "babyai"
- "crafter"
- "craftax"
- "textworld"
- "babaisai"
task (str): The specific task to be performed within the environment.
config (Config): An object containing configuration settings, which must include
environment-specific keys such as:
- envs.nle_kwargs (dict): Arguments specific to the NLE environment.
- envs.minihack_kwargs (dict): Arguments specific to the MiniHack environment.
- envs.babyai_kwargs (dict): Arguments specific to the BabyAI environment.
- envs.crafter_kwargs (dict): Arguments specific to the Crafter environment.
- envs.craftax_kwargs (dict): Arguments specific to the Craftax environment.
- envs.textworld_kwargs (dict): Arguments specific to the TextWorld environment.
- envs.babaisai_kwargs (dict): Arguments specific to the Baba Is AI environment.
Returns:
EnvWrapper: A wrapped environment instance that includes language processing capabilities
and task-specific functionality.
Raises:
ValueError: If the provided environment name is not recognized.
"""
if env_name == "nle":
from iclbench.environments.nle import NLELanguageWrapper
nle_kwargs = dict(config.envs.nle_kwargs)
skip_more = nle_kwargs.pop("skip_more", False)
vlm = True if config.agent.max_image_history > 0 else False
env = gym.make(task, **nle_kwargs)
base_env = NLELanguageWrapper(env, vlm=vlm, skip_more=skip_more)
elif env_name == "minihack":
import minihack
from iclbench.environments.nle import NLELanguageWrapper
minihack_kwargs = dict(config.envs.minihack_kwargs)
skip_more = minihack_kwargs.pop("skip_more", False)
vlm = True if config.agent.max_image_history > 0 else False
env = gym.make(
task,
observation_keys=[
"glyphs",
"blstats",
"tty_chars",
"inv_letters",
"inv_strs",
"tty_cursor",
"tty_colors",
],
**minihack_kwargs,
)
base_env = NLELanguageWrapper(env, vlm=vlm, skip_more=skip_more)
elif env_name == "babyai":
import babyai_text
from iclbench.environments.babyai_text import BabyAITextCleanLangWrapper
base_task, goal = task.split("/")
while 1:
env = gym.make(base_task)
if env.env.action_kinds[0].replace(" ", "_") == goal:
break
base_env = BabyAITextCleanLangWrapper(env, **config.envs.babyai_kwargs)
elif env_name == "crafter":
import crafter
from iclbench.environments.crafter import CrafterLanguageWrapper
crafter_kwargs = dict(config.envs.crafter_kwargs)
max_episode_steps = crafter_kwargs.pop("max_episode_steps", 2)
for param in ["area", "view", "size"]:
if param in crafter_kwargs:
crafter_kwargs[param] = tuple(crafter_kwargs[param])
env = crafter.Env(**crafter_kwargs)
base_env = CrafterLanguageWrapper(env, task, max_episode_steps=max_episode_steps)
elif env_name == "craftax":
from iclbench.environments.craftax import CraftaxLanguageWrapper
base_env = CraftaxLanguageWrapper(task, **config.envs.craftax_kwargs)
elif env_name == "textworld":
from iclbench.environments.textworld import global_textworld_context
textworld_context = global_textworld_context(tasks=config.tasks.textworld_tasks, **config.envs.textworld_kwargs)
base_env = textworld_context(task, **config.envs.env_kwargs)
elif env_name == "babaisai":
from baba import make
from iclbench.environments.baba_is_ai import BabaIsAIWrapper
base_env = BabaIsAIWrapper(make(task, **config.envs.babaisai_kwargs))
else:
raise ValueError(f"Unknown environment: {env_name}")
return EnvWrapper(base_env, env_name, task)
class Strings(spaces.Space):
    """A finite space over an explicit collection of hashable values (typically strings).

    Each value is associated with its position in ``values``; :meth:`map`
    converts a value to that integer index.
    """

    def __init__(self, values, seed=None):
        """
        Args:
            values: Sequence (e.g. a list) of the allowed values; must support
                ``len`` and integer indexing. Its order defines the indices
                returned by :meth:`map`.
            seed: Optional seed forwarded to ``gym.spaces.Space``.
        """
        super().__init__((len(values),), str, seed)
        # Value -> position; if a value occurs more than once, its last position wins.
        self._dict = {value: i for i, value in enumerate(values)}
        self._values = values

    def sample(self):
        """Return one of the allowed values, drawn uniformly with the space's RNG.

        A random index is drawn and the stored value is returned as-is. The
        draw is the same one ``choice(values)`` would make, but it avoids
        building a NumPy array of every value per call. The result also keeps
        its original type (e.g. ``str`` rather than a coerced ``numpy.str_``).
        """
        index = self.np_random.choice(len(self._values))
        return self._values[int(index)]

    def map(self, action):
        """Return the integer index of ``action``; raises KeyError if it is not allowed."""
        return self._dict[action]

    def contains(self, value):
        """Return True if ``value`` is one of the allowed values.

        Unhashable inputs (lists, arrays, ...) cannot be members, so they yield
        False instead of raising TypeError.
        """
        try:
            return value in self._dict
        except TypeError:
            return False

    def __iter__(self):
        return iter(self._values)

    def __repr__(self):
        return f"{type(self).__name__}({self._values!r})"