Source code for torchrl.envs.gym_env

import gym
from torchrl.envs.base_env import BaseEnv
import torchrl.utils as U


class GymEnv(BaseEnv):
    """
    Creates and wraps a gym environment.

    Parameters
    ----------
    env_name: str
        The Gym ID of the env. For a list of available envs check
        `this <https://gym.openai.com/envs/>`_ page.
    wrappers: list
        List of wrappers to be applied on the env.
        Each wrapper should be a function that receives and returns the env.
        NOTE(review): not consumed here; presumably handled by ``BaseEnv``
        through ``**kwargs`` -- confirm against the base class.
    """

    def __init__(self, env_name, **kwargs):
        """Forward ``env_name`` and any extra keyword arguments to ``BaseEnv``."""
        super().__init__(env_name, **kwargs)

    def _create_env(self):
        """Instantiate the underlying gym environment from ``self.env_name``."""
        return gym.make(self.env_name)

    @property
    def simulator(self):
        """The simulator class backing this environment (always ``GymEnv``)."""
        return GymEnv

    def reset(self):
        """
        Calls the reset method on the gym environment.

        Returns
        -------
        state: numpy.ndarray
            A numpy array with the state information.
        """
        return self.env.reset()

    def step(self, action):
        """
        Calls the step method on the gym environment.

        Parameters
        ----------
        action: int or float or numpy.ndarray
            The action to be executed in the environment, it should be an int
            for discrete environments and float for continuous. There's also
            the possibility of executing multiple actions (if the environment
            supports so), in this case it should be a numpy.ndarray.

        Returns
        -------
        next_state: numpy.ndarray
            A numpy array with the state information.
        reward: float
            The reward.
        done: bool
            Flag indicating the termination of the episode.
        info: dict
            Auxiliary information returned by the gym environment.
        """
        # Discrete action spaces are indexed by plain Python ints, so coerce
        # whatever the caller passed (e.g. a numpy scalar) before stepping.
        if self.get_action_info().space == "discrete":
            action = int(action)

        next_state, reward, done, info = self.env.step(action)
        return next_state, reward, done, info

    def record(self, path):
        """
        Wraps the environment in a gym ``Monitor`` so every episode is recorded.

        Parameters
        ----------
        path: str
            Directory where the monitor output is written.
        """
        # Imported locally: the module never imported ``Monitor``, so the bare
        # name raised NameError. Keeping the import here also means a gym
        # build without this wrapper only fails when recording is requested.
        # NOTE(review): newer gym releases replace Monitor with RecordVideo --
        # confirm against the gym version this project pins.
        from gym.wrappers import Monitor

        self.env = Monitor(
            env=self.env,
            directory=path,
            # Record a video for every episode, not only the default subset.
            video_callable=lambda episode_id: True,
        )

    def get_state_info(self):
        """
        Dictionary containing the shape and type of the state space.
        If it is continuous, also contains the minimum and maximum value.
        """
        return GymEnv.get_space_info(self.env.observation_space)

    def get_action_info(self):
        """
        Dictionary containing the shape and type of the action space.
        If it is continuous, also contains the minimum and maximum value.
        """
        return GymEnv.get_space_info(self.env.action_space)

    def sample_random_action(self):
        """Returns an action sampled from the environment's action space."""
        return self.env.action_space.sample()

    def seed(self, value):
        """Forwards ``value`` to the seed method of the gym environment."""
        self.env.seed(value)

    def update_config(self, config):
        """
        Updates a Config object to include information about the environment.

        Parameters
        ----------
        config: Config
            Object used for storing configuration.
        """
        super().update_config(config)
        # NOTE(review): ``self.wrappers`` is not assigned anywhere in this
        # class; presumably ``BaseEnv`` sets it -- confirm.
        config.env.obj.update(dict(wrappers=self.wrappers))

    def close(self):
        """Calls the close method on the gym environment."""
        self.env.close()

    @staticmethod
    def get_space_info(space):
        """
        Gets the shape of the possible types of states in gym.

        Parameters
        ----------
        space: gym.spaces
            Space object that describes the valid actions and observations

        Returns
        -------
        dict
            Dictionary-like ``SimpleMemory`` containing the space shape and
            type. ``None`` is returned for space types other than ``Box``,
            ``Discrete`` and ``MultiDiscrete``.
        """
        if isinstance(space, gym.spaces.Box):
            return U.memories.SimpleMemory(
                shape=space.shape,
                low_bound=space.low,
                high_bound=space.high,
                space="continuous",
                dtype=space.dtype,
            )

        if isinstance(space, gym.spaces.Discrete):
            # For discrete spaces "shape" holds the number of choices
            # (``space.n``), not the shape tuple of a sampled value.
            return U.memories.SimpleMemory(
                shape=space.n, space="discrete", dtype=space.dtype
            )

        if isinstance(space, gym.spaces.MultiDiscrete):
            return U.memories.SimpleMemory(
                shape=space.shape, space="multi_discrete", dtype=space.dtype
            )

        # Unsupported space type: kept as an explicit None so existing callers
        # that relied on the implicit fall-through see no change.
        return None