Source code for torchrl.agents.base_agent

from abc import ABC, abstractmethod

import numpy as np

import torchrl.utils as U


class BaseAgent(ABC):
    """
    Basic TorchRL agent. Encapsulates a batcher, an optimizer and the models
    used during training.

    Parameters
    ----------
    batcher: torchrl.batchers
        The batcher used to interact with the environment and generate
        batches of transitions.
    optimizer: torchrl.optimizers
        The optimizer used to update the registered models.
    gamma: float
        Discount factor on future rewards (Default is 0.99).
    log_dir: string
        Directory where logs will be written (Default is `runs`).
    """

    def __init__(self, batcher, optimizer, *, gamma=0.99, log_dir="runs"):
        self.batcher = batcher
        self.opt = optimizer
        self.logger = U.Logger(log_dir)
        self.gamma = gamma
        self.num_iters = 1
        self.models = U.memories.DefaultMemory()
        # Can be changed later by the user; None falls back to the default
        # (the policy's own select_action)
        self.select_action_fn = None
        self.eval_select_action_fn = None
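
    # Illustrative sketch (not part of the class): a concrete agent is built
    # from a batcher and an optimizer. The names below are placeholders; the
    # actual classes depend on the rest of the library and on your setup.
    #
    #     batcher = SomeBatcher(env)            # hypothetical batcher
    #     optimizer = SomeOptimizer(policy)     # hypothetical optimizer
    #     agent = MyAgent(batcher, optimizer, gamma=0.99, log_dir="runs/exp1")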

    @abstractmethod
    def step(self):
        """
        This method is called at each iteration of the training loop and
        defines the training procedure.
        """

    @property
    def num_steps(self):
        return self.batcher.num_steps

    @property
    def num_episodes(self):
        return self.batcher.num_episodes

    def _check_termination(self):
        """
        Check if the training loop reached the end.

        Returns
        -------
        bool
            True if done, False otherwise.
        """
        if (
            self.num_iters // self.max_iters >= 1
            or self.num_episodes // self.max_episodes >= 1
            or self.num_steps // self.max_steps >= 1
        ):
            return True
        return False

    def _check_evaluation(self, env):
        if self.eval_freq is not None and self.num_steps >= self.next_eval:
            action_fn = self.eval_select_action_fn or self.models.policy.select_action
            action_fn_pre = lambda state: action_fn(
                model=self.models.policy, state=state, step=self.num_steps
            )
            self.batcher.evaluate(
                select_action_fn=action_fn_pre, logger=self.logger, env=env
            )
            self.next_eval += self.eval_freq

    def _register_model(self, name, model):
        """
        Save a torchrl model to the internal memory.

        Parameters
        ----------
        name: str
            Desired name for the model.
        model: torchrl.models
            The model to register.
        """
        setattr(self.models, name, model)
        model.attach_logger(self.logger)
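
    # Illustrative sketch: a subclass would typically register its models
    # right after creating them, e.g.:
    #
    #     self._register_model("policy", policy_model)
    #
    # after which the model is reachable as self.models.policy, which is what
    # select_action and _check_evaluation expect.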

    def train_models(self, batch):
        # for model in self.models.values():
        #     model.train(batch_tensor, step=self.num_steps)
        self.opt.learn_from_batch(batch, step=self.num_steps)

    def train(
        self,
        *,
        max_iters=-1,
        max_episodes=-1,
        max_steps=-1,
        log_freq=1,
        eval_env=None,
        eval_freq=None
    ):
        """
        Defines the training loop of the algorithm, calling :meth:`step` at
        every iteration.

        Parameters
        ----------
        max_iters: int
            Maximum number of iterations (Default is -1, meaning no limit).
        max_episodes: int
            Maximum number of episodes (Default is -1, meaning no limit).
        max_steps: int
            Maximum number of steps (Default is -1, meaning no limit).
        log_freq: int
            How often, in iterations, logs are written (Default is 1).
        eval_env: torchrl.envs
            Environment used for periodic evaluation (Default is None).
        eval_freq: int
            How often, in steps, the agent is evaluated (Default is None,
            meaning no evaluation).
        """
        self.max_iters = max_iters
        self.max_episodes = max_episodes
        self.max_steps = max_steps
        self.eval_freq = eval_freq
        self.next_eval = 0
        self.logger.set_log_freq(log_freq=log_freq)

        while True:
            self.step()
            self.write_logs()
            self.num_iters += 1

            self._check_evaluation(env=eval_env)
            if self._check_termination():
                break
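
    # Illustrative sketch: a typical training call caps the number of
    # environment steps and evaluates periodically on a separate environment.
    # `eval_env` is a placeholder for whatever environment object the batcher
    # understands.
    #
    #     agent.train(max_steps=1_000_000, log_freq=10,
    #                 eval_env=eval_env, eval_freq=10_000)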

    def select_action(self, state, step):
        """
        Receive a state and use the model to select an action.

        Parameters
        ----------
        state: numpy.ndarray
            The environment state.
        step: int
            The current training step.

        Returns
        -------
        action: int or numpy.ndarray
            The selected action.
        """
        action_fn = self.select_action_fn or self.models.policy.select_action
        return action_fn(model=self.models.policy, state=state, step=step)

    def write_logs(self):
        """
        Use the logger to write general information about the training process.
        """
        self.batcher.write_logs(logger=self.logger)
        self.opt.write_logs(logger=self.logger)

        self.logger.timeit(self.num_steps, max_steps=self.max_steps)
        self.logger.log(
            "Iter {} | Episode {} | Step {}".format(
                self.num_iters, self.num_episodes, self.num_steps
            )
        )

    def generate_batch(self):
        batch = self.batcher.get_batch(select_action_fn=self.select_action)
        return batch

    # TODO: Reimplement method
    # @classmethod
    # def from_config(cls, config, env=None):
    #     '''
    #     Create an agent from a configuration object.
    #
    #     Returns
    #     -------
    #     torchrl.agents
    #         A TorchRL agent.
    #     '''
    #     if env is None:
    #         try:
    #             env = U.get_obj(config.env.obj)
    #         except AttributeError:
    #             raise ValueError('The env must be defined in the config '
    #                              'or passed as an argument')
    #
    #     model = cls._model.from_config(
    #         config.model, env.get_state_info(), env.get_action_info()
    #     )
    #     return cls(env, model, **config.agent.as_dict())

    # TODO: Reimplement method
    # @classmethod
    # def from_file(cls, file_path, env=None):
    #     '''
    #     Create an agent from a configuration file.
    #
    #     Parameters
    #     ----------
    #     file_path: str
    #         Path to the configuration file.
    #
    #     Returns
    #     -------
    #     torchrl.agents
    #         A TorchRL agent.
    #     '''
    #     config = U.Config.load(file_path)
    #     return cls.from_config(config, env=env)
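

# Illustrative sketch (kept as a comment so importing this module has no side
# effects): the smallest concrete agent only needs to implement `step`, for
# example by generating a batch and training the registered models on it.
#
#     class MinimalAgent(BaseAgent):
#         def step(self):
#             batch = self.generate_batch()
#             self.train_models(batch)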