Source code for iclbench.environments.craftax.env

from collections import defaultdict

import craftax
import gym
import jax
import jax.numpy as jnp
import numpy as np
from craftax.craftax.constants import BLOCK_PIXEL_SIZE_HUMAN, INVENTORY_OBS_HEIGHT, OBS_DIM
from craftax.craftax.renderer import render_craftax_pixels, render_craftax_text
from craftax.craftax_env import make_craftax_env_from_name
from PIL import Image

from iclbench.environments import Strings

# Names of the language actions accepted by CraftaxLanguageWrapper.
#
# The list is wrapped in ``Strings(USEFUL_ACTION)`` to form the wrapper's
# ``language_action_space``; ``step()`` validates an incoming action string
# against that space and passes ``language_action_space.map(name)`` directly
# to the Craftax ``env.step`` call as the action.
#
# NOTE(review): presumably ``Strings.map`` returns the position of the name in
# this list, in which case the order here must match the integer values of
# Craftax's ``Action`` enum. Upstream Craftax appears to order the movement
# actions as left, right, up, down and to place sleep/place_* before the
# make_* actions, which differs from the order below -- confirm against
# ``craftax.craftax.constants.Action`` before relying on this mapping.
USEFUL_ACTION = [
    "noop",
    "up",
    "right",
    "down",
    "left",
    "do",
    "make_wood_pickaxe",
    "make_stone_pickaxe",
    "make_iron_pickaxe",
    "make_diamond_pickaxe",
    "make_wood_sword",
    "make_stone_sword",
    "make_iron_sword",
    "make_diamond_sword",
    "place_table",
    "sleep",
    "place_stone",
    "place_furnace",
    "place_plant",
    "rest",
    "ascend",
    "descend",
    "make_iron_armour",
    "make_diamond_armour",
    "shoot_arrow",
    "make_arrow",
    "cast_fireball",
    "cast_iceball",
    "place_torch",
    "drink_potion_red",
    "drink_potion_green",
    "drink_potion_blue",
    "drink_potion_pink",
    "drink_potion_cyan",
    "drink_potion_yellow",
    "read_book",
    "enchant_sword",
    "enchant_armour",
    "make_torch",
    "level_up_dexterity",
    "level_up_strength",
    "level_up_intelligence",
    "enchant_bow",
]


class CraftaxLanguageWrapper(gym.Env):
    """Gym-style wrapper that drives a Craftax environment with string actions.

    Actions are given as strings from ``USEFUL_ACTION``. Observations are
    returned as a ``defaultdict`` (missing keys yield ``None``) holding:

    * ``"text"``: a dict with ``"long_term_context"`` and
      ``"short_term_context"``, both currently empty strings;
    * ``"image"``: an RGB ``PIL.Image`` of the rendered environment state.
    """

    def __init__(self, env_id: str = "Craftax-Symbolic-v1", seed=None):
        """Create the underlying Craftax environment and jit its entry points.

        Args:
            env_id: Name passed to ``make_craftax_env_from_name``.
            seed: Seed for the JAX PRNG key. When ``None``, a seed is drawn
                from ``np.random.randint(2**31)``.
        """
        super().__init__()

        @jax.jit
        def render_state(env_state):
            # Render at the human block size and convert the rounded pixel
            # values to uint8.
            image = render_craftax_pixels(env_state, block_pixel_size=BLOCK_PIXEL_SIZE_HUMAN)
            return jnp.round(image).astype(jnp.uint8)

        env = make_craftax_env_from_name(env_id, auto_reset=True)
        self._step = jax.jit(env.step)
        self._reset = jax.jit(env.reset)
        self._render = render_state
        self._env_params = env.default_params

        if seed is None:
            seed = np.random.randint(2**31)
        self._rng = jax.random.PRNGKey(seed)

        # Populated by reset(); None means no episode has been started yet.
        self._env_state = None

        self.language_action_space = Strings(USEFUL_ACTION)

    @property
    def default_action(self):
        """Action string to use when no other action is chosen."""
        return "noop"

    def get_text_action(self, action):
        """Not implemented for this environment.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def _require_state(self):
        """Raise a clear error if the environment has not been reset yet."""
        if self._env_state is None:
            raise RuntimeError("Environment state is not initialised; call reset() first.")

    def _build_observation(self):
        """Build the observation dict for the current environment state."""
        # TODO: decide whether craftax has long and short term context
        # observations; both are left empty for now.
        obs_dict = defaultdict(lambda: None)
        obs_dict["text"] = {
            "long_term_context": "",
            "short_term_context": "",
        }
        obs_dict["image"] = Image.fromarray(self.render()).convert("RGB")
        return obs_dict

    def reset(self):
        """Reset the environment to an initial state.

        Returns:
            The observation dict for the freshly reset state.
        """
        self._rng, reset_key = jax.random.split(self._rng)
        # The raw observation returned by the environment is not used; the
        # observation handed to callers is built from the rendered state.
        _, self._env_state = self._reset(reset_key, self._env_params)
        return self._build_observation()

    def step(self, language_action):
        """Apply one string action to the environment.

        Args:
            language_action: One of the strings in ``language_action_space``.

        Returns:
            A tuple ``(obs_dict, reward, done, info)`` where ``reward`` is the
            environment reward converted with ``.item()`` and ``done`` and
            ``info`` are returned as produced by the environment's step.

        Raises:
            ValueError: If ``language_action`` is not in
                ``language_action_space``.
            RuntimeError: If called before ``reset()``.
        """
        if language_action not in self.language_action_space:
            raise ValueError(f"Action {repr(language_action)} not recognized / supported by this environment.")
        self._require_state()

        action = jnp.array(self.language_action_space.map(language_action))
        self._rng, step_key = jax.random.split(self._rng)
        _, self._env_state, reward, done, info = self._step(
            step_key, self._env_state, action, self._env_params
        )

        return self._build_observation(), reward.item(), done, info

    def render(self, mode="human"):
        """Return the current state rendered as a NumPy array.

        Args:
            mode: Accepted for interface compatibility; not used.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        self._require_state()
        return np.array(self._render(self._env_state))

    def get_stats(self):
        """Return the current state's achievements as a list of ints.

        Returns:
            ``{"achievements": [...]}`` with one int per entry of
            ``env_state.achievements``.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        self._require_state()
        # TODO: convert to string list rather than bool list
        achievements = list(map(int, np.array(self._env_state.achievements)))
        return {"achievements": achievements}