Source code for iclbench.dataset

import copy
import pickle
import random
import re
from pathlib import Path


def natural_sort_key(s):
    """Build a key that sorts embedded integers numerically.

    ``str(s)`` is split into alternating text and digit runs. Digit runs become
    ``int`` and text runs are lower-cased, so ``"ep2"`` sorts before ``"ep10"``.
    """

    def _convert(chunk):
        # Digit runs compare as numbers; text compares case-insensitively.
        return int(chunk) if chunk.isdigit() else chunk.lower()

    return list(map(_convert, re.split(r"(\d+)", str(s))))
def choice_excluding(lst, excluded_element):
    """Return a random element of ``lst`` that is not equal to ``excluded_element``.

    Raises:
        IndexError: if ``lst`` is empty or every element equals ``excluded_element``
            (propagated from ``random.choice``).
    """
    candidates = list(filter(lambda item: item != excluded_element, lst))
    return random.choice(candidates)
class InContextDataset:
    """Locate and load recorded demonstration episodes for in-context learning.

    Demo episodes are read from
    ``<original_cwd>/<config.eval.icl_dataset>/<env_name>/<task>/``. Each entry in
    that directory is one episode whose file stem contains ``seed_<n>``.
    """

    def __init__(self, config, env_name, original_cwd) -> None:
        """Store the configuration used to resolve and replay demo episodes.

        Args:
            config: Experiment config, read via ``config.eval.icl_dataset``,
                ``config.tasks`` and ``config.envs`` (presumably a Hydra/OmegaConf
                config -- TODO confirm).
            env_name: Environment identifier, e.g. ``"nle"``, ``"textworld"``.
            original_cwd: Directory the dataset path is resolved against.
        """
        self.config = config
        self.env_name = env_name
        self.original_cwd = original_cwd

    def icl_episodes(self, task):
        """Return the demo episode paths for ``task``, naturally sorted by name.

        Raises:
            FileNotFoundError: if the task's demo directory does not exist.
        """
        demos_dir = Path(self.original_cwd) / self.config.eval.icl_dataset / self.env_name / task
        return sorted(demos_dir.iterdir(), key=natural_sort_key)

    def check_seed(self, demo_path):
        """Return the integer seed encoded as ``seed_<n>`` in ``demo_path``'s stem."""
        return int(demo_path.stem.split("seed_")[1])

    def demo_task(self, task):
        """Return the task whose demos should be shown when evaluating ``task``.

        For ``babaisai``, a different task is drawn at random from
        ``config.tasks["babaisai_tasks"]``. Otherwise ``task`` is returned unchanged.
        """
        # use different task - avoid the case where we put the solution into the context
        if self.env_name == "babaisai":
            task = choice_excluding(self.config.tasks[f"{self.env_name}_tasks"], task)
        return task

    def demo_path(self, i, task, demo_config):
        """Pick the episode to use as the ``i``-th in-context demo for ``task``.

        Episodes are taken cyclically: ``i`` wraps around the available episodes.
        For NLE, an episode whose name contains the role given in ``demo_config``'s
        character string is preferred when one exists. For TextWorld, an episode
        whose seed equals ``textworld_context.count[task]`` is skipped in favour of
        the following one (cyclically), so the solution is not leaked into the
        context.

        Raises:
            ValueError: if no demo episodes exist for ``task``.
        """
        icl_episodes = self.icl_episodes(task)
        if not icl_episodes:
            raise ValueError(f"No in-context demo episodes found for task {task!r}")
        demo_path = icl_episodes[i % len(icl_episodes)]

        # use the same role
        if self.env_name == "nle":
            from iclbench.environments.nle import Role

            character = demo_config.envs.nle_kwargs.character
            if character != "@":
                role_names = [e.value for e in Role]  # built once, not per part
                for part in character.split("-"):
                    # check if there is specified role
                    if part.lower() in role_names:
                        # check if we have games played with this role
                        new_demo_paths = [
                            path for path in icl_episodes if part.lower() in path.stem.lower()
                        ]
                        if new_demo_paths:
                            demo_path = random.choice(new_demo_paths)

        # use different seed - avoid the case where we put the solution into the context
        if self.env_name == "textworld":
            from iclbench.environments.textworld import global_textworld_context

            textworld_context = global_textworld_context(
                tasks=self.config.tasks.textworld_tasks, **self.config.envs.textworld_kwargs
            )
            next_seed = textworld_context.count[task]
            demo_seed = self.check_seed(demo_path)
            if next_seed == demo_seed:
                # Step to the following episode with the same wrap-around as the
                # initial pick, so the last index (or i >= len) stays in range.
                # NOTE(review): with a single episode this returns the same one.
                demo_path = icl_episodes[(i + 1) % len(icl_episodes)]
        return demo_path

    def override_incontext_config(self, demo_config, demo_path):
        """Mutate ``demo_config`` in place so the environment replays ``demo_path``.

        Sets ``envs.env_kwargs.seed`` from the seed in the episode file name. It then
        applies environment-specific settings that match how the demos were recorded.
        """
        seed = self.check_seed(demo_path)
        demo_config.envs.env_kwargs.seed = seed
        if self.env_name in ("nle", "minihack"):
            # dataset was collected with "more" action
            demo_config.envs[f"{self.env_name}_kwargs"].skip_more = True
            # TODO: this has to be hardcoded because of the way we've generated the trajectories
            # keep in mind this won't affect the global config, only the demo config
            demo_config.envs.nle_kwargs.character = "@"
        if self.env_name == "crafter":
            # crafter passes seed in a specific fashion
            demo_config.envs.crafter_kwargs.seed = seed

    def load_incontext_actions(self, demo_path):
        """Return the ``"actions"`` entry of the pickled episode at ``demo_path``.

        NOTE: ``pickle.load`` can execute arbitrary code, so only load trusted
        demo files.
        """
        with open(demo_path, "rb") as f:
            data = pickle.load(f)
        return data["actions"]