import glob
import importlib.resources
import os
from collections import defaultdict
from pathlib import Path
import gym
import textworld
import textworld.gym
# Parent directory of the installed ``iclbench`` package; relative game paths
# handed to TextWorldFactory are joined onto this directory.
workspace_dir = os.path.dirname(str(importlib.resources.files("iclbench")))
[docs]
class TextWorldFactory:
    """
    A factory class for creating TextWorld environments.

    This class manages the creation of TextWorld environments for different tasks,
    cycling through available games for each task or allowing specific game selection.

    The class is a singleton. The first construction runs ``initialize`` with its
    keyword arguments. Every later construction returns that same instance, and
    its keyword arguments are ignored.
    """

    _instance = None

    def __new__(cls, **kwargs):
        if cls._instance is None:
            instance = super().__new__(cls)
            # Publish the singleton only after initialization succeeds. Otherwise a
            # failed attempt would leave a half-built instance cached for good.
            instance.initialize(**kwargs)
            cls._instance = instance
        return cls._instance

    def initialize(self, textworld_games_path, tasks, max_episode_steps=40, **kwargs):
        """
        Discover game files and register one gym environment id per game.

        Args:
            textworld_games_path (str): Directory searched recursively for ``*.ulx``
                and ``*.z8`` games. It is joined onto ``workspace_dir``; an absolute
                path overrides that prefix (``os.path.join`` semantics).
            tasks: Container of task names. A game belongs to a task when the name
                of the directory directly containing the game file is ``in tasks``.
            max_episode_steps (int): Step limit passed to each registered game. It
                is also stored as ``max_steps`` for the wrappers this factory creates.
            **kwargs: Forwarded to ``textworld.EnvInfos``. The flags ``objective``,
                ``description``, ``score``, ``max_score`` and ``won`` must all be
                present and truthy. ``TextWorldWrapper`` reads ``objective``,
                ``score``, ``max_score`` and ``won`` from the step info.

        Raises:
            ValueError: If any required info flag is missing or falsy.
        """
        # Use a real exception rather than ``assert`` (asserts are stripped under
        # ``python -O``). Validate before any filesystem or registration work.
        required_kwargs = ["objective", "description", "score", "max_score", "won"]
        missing = [kwarg for kwarg in required_kwargs if not kwargs.get(kwarg)]
        if missing:
            raise ValueError(f"The following info flags must be enabled: {missing}")

        self.max_steps = max_episode_steps
        self.count = defaultdict(int)  # per-task rotation counter for unseeded requests
        self.request_infos = textworld.EnvInfos(**kwargs)
        self.env_ids = defaultdict(list)  # task name -> registered env ids, in discovery order

        textworld_games_path = os.path.join(workspace_dir, textworld_games_path)
        # One glob per extension: all sorted .ulx games are registered before .z8 games.
        for pattern in ["*.ulx", "*.z8"]:
            search = os.path.join(textworld_games_path, f"**/{pattern}")
            for entry in sorted(glob.glob(search, recursive=True)):
                task = Path(entry).parent.name
                if task in tasks:
                    env_id = textworld.gym.register_game(
                        entry, self.request_infos, max_episode_steps=max_episode_steps
                    )
                    self.env_ids[task].append(env_id)

    def _select_env_id(self, task, seed=None):
        """Return the env id for ``task``: seeded pick, or the next game in rotation."""
        if task not in self.env_ids:
            raise KeyError(f"Task '{task}' not found. Available tasks are: {list(self.env_ids.keys())}")
        games = self.env_ids[task]
        if seed is not None:
            return games[seed % len(games)]
        # Read the counter before advancing it, so the rotation starts at the first game.
        index = self.count[task]
        self.count[task] += 1
        return games[index % len(games)]

    def get_textworld_env(self, task, seed=None, **kwargs):
        """
        Create and return a TextWorld environment for the specified task.

        Args:
            task (str): The name of the task for which to create an environment.
            seed (int, optional): If provided, selects the game at index
                ``seed % number_of_games`` and leaves the rotation untouched. If
                None, cycles through the available games, starting with the first.
            **kwargs: Forwarded to ``textworld.gym.make``.

        Returns:
            gym.Env: A TextWorld gym environment wrapped in ``TextWorldWrapper``.

        Raises:
            KeyError: If the specified task is not found in the available tasks.
        """
        env_id = self._select_env_id(task, seed)
        env = textworld.gym.make(env_id, **kwargs)
        return TextWorldWrapper(env, max_steps=self.max_steps)

    def __call__(self, task, **kwargs):
        """Shorthand for ``get_textworld_env(task, **kwargs)``."""
        return self.get_textworld_env(task, **kwargs)
[docs]
class AlwaysTrue:
    """Permissive container: ``x in AlwaysTrue()`` is ``True`` for every ``x``.

    ``TextWorldWrapper`` uses it as its language action space, so any text is accepted.
    """

    def __contains__(self, item):
        accepted = True  # no candidate is ever rejected
        return accepted
[docs]
class TextWorldWrapper(gym.Wrapper):
    """
    Gym wrapper that adapts a TextWorld environment to a text/image observation dict.

    Observations have the objective text removed. They are returned as
    ``{"text": {"long_term_context": obs, "short_term_context": ""}, "image": None}``.
    When an episode ends, ``progression`` is set to the larger of
    ``score / max_score`` and ``1.0 if won else 0.0``. It is exposed via ``get_stats``.
    """

    def __init__(self, env: gym.Env, max_steps=40):
        super().__init__(env)
        # Membership container that accepts anything, i.e. any free-form text action.
        self.language_action_space = AlwaysTrue()
        self.progression = 0.0
        # Stored for callers; this wrapper does not enforce the limit itself.
        self.max_steps = max_steps

    @property
    def default_action(self):
        """Fallback text action ("help")."""
        return "help"

    def get_text_action(self, action):
        """Return ``action`` unchanged; actions are already text."""
        return action

    def textworld_process_obsv(self, textworld_obsv):
        """Pack a raw observation string into the observation dict format."""
        return {
            "text": {"long_term_context": textworld_obsv, "short_term_context": ""},
            "image": None,
        }

    def filter_objective(self, obs, info):
        """
        Drop everything up to and including the last occurrence of the objective.

        Returns the stripped remainder. Returns the whole stripped observation when
        the objective is missing, empty, or does not occur in ``obs``.
        """
        objective = info.get("objective")
        # str.split("") raises ValueError, and str.split(None) splits on whitespace.
        # An empty or absent objective therefore must not be used as a separator.
        if not objective:
            return obs.strip()
        # rpartition yields ('', '', obs) when the separator is absent, matching split()[-1].
        return obs.rpartition(objective)[2].strip()

    def reset(self):
        """Reset the episode and return the filtered observation dict (info is not returned)."""
        obs, info = self.env.reset()
        obs = self.filter_objective(obs, info)
        self.progression = 0.0
        return self.textworld_process_obsv(obs)

    def step(self, action):
        """
        Apply a text action to the wrapped environment.

        Returns:
            tuple: ``(observation dict, reward, done, info)``.
        """
        obs, reward, done, info = self.env.step(action)
        obs = self.filter_objective(obs, info)
        if done:
            max_score = info["max_score"]
            # Guard against games reporting max_score == 0, which would raise ZeroDivisionError.
            score_ratio = info["score"] / max_score if max_score else 0.0
            self.progression = max(score_ratio, 1.0 if info["won"] else 0.0)
        return self.textworld_process_obsv(obs), reward, done, info

    def get_stats(self):
        """Return episode statistics (currently only ``progression``)."""
        return {"progression": self.progression}