Source code for torchrl.agents.pg_agent

import numpy as np
import torchrl.utils as U
from torchrl.agents import BaseAgent


# TODO: docstring
class PGAgent(BaseAgent):
    """
    Policy Gradient Agent, compatible with all PG models.

    This agent encapsulates a policy_model and optionally a value_model;
    it defines the steps needed for the training loop (see :meth:`step`)
    and calculates all the values necessary to train the model(s).

    Parameters
    ----------
    batcher: torchrl.batchers
        A torchrl batcher, used to generate batches of transitions.
    policy_model: torchrl.models
        Should be a subclass of ``torchrl.models.BasePGModel``.
    value_model: torchrl.models
        Should be an instance of ``torchrl.models.ValueModel``
        (Default is None).
    normalize_advantages: bool
        If True, normalize the advantages per batch.
    advantage: torchrl.utils.estimators.advantage
        Class used for calculating the advantages.
    vtarget: torchrl.utils.estimators.value
        Class used for calculating the states target values.
    """

    def __init__(
        self,
        batcher,
        *,
        policy_model,
        value_model=None,
        normalize_advantages=True,
        advantage=U.estimators.advantage.GAE(gamma=0.99, gae_lambda=0.95),
        vtarget=U.estimators.value.FromAdvantage(),
        **kwargs
    ):
        super().__init__(batcher=batcher, **kwargs)
        self.normalize_advantages = normalize_advantages
        self.advantage = advantage
        self.vtarget = vtarget

        self._register_model("policy", policy_model)
        if value_model is not None:
            self._register_model("value", value_model)
    def step(self):
        """
        Perform a single training step: generate a batch, add state
        values, advantages, and value targets to it, then train the
        registered model(s).
        """
        batch = self.generate_batch()

        self.add_state_value(batch)
        self.add_advantage(batch)
        self.add_vtarget(batch)
        batch = batch.concat_batch()

        if self.normalize_advantages:
            batch.advantage = U.normalize(batch.advantage)

        self.train_models(batch)
    def add_state_value(self, batch):
        if self.models.value is not None:
            # Evaluate V(s) for every state and its successor in a single
            # forward pass: flatten the two leading batch dimensions, run
            # the value model, then restore the original shape.
            s = batch.state_t_and_tp1
            v = self.models.value(s.reshape(-1, *s.shape[2:]))
            v = U.to_np(v).reshape(s.shape[:2])

            batch.state_value_t_and_tp1 = v
            # v[:-1] holds V(s_t), v[1:] holds V(s_{t+1})
            batch.state_value_t = v[:-1]
            batch.state_value_tp1 = v[1:]

    def add_advantage(self, batch):
        batch.advantage = self.advantage(batch)

    def add_vtarget(self, batch):
        batch.vtarget = self.vtarget(batch)
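
    # With the default estimators, these hooks compute Generalized
    # Advantage Estimation (Schulman et al., 2015) from the values
    # filled in by add_state_value:
    #     delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    #     A_t = sum_{l>=0} (gamma * lambda)^l * delta_{t+l}
    # FromAdvantage then (presumably, judging by its name) recovers the
    # value target as vtarget_t = A_t + V(s_t).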
    # TODO: Reimplement this method
    # @classmethod
    # def from_config(cls, config, env=None, policy_model_class=None, **kwargs):
    #     if env is None:
    #         env = U.env_from_config(config)
    #
    #     # If the policy_model_class is given it should overwrite key from config
    #     if policy_model_class is not None:
    #         config.pop('policy_model_class')
    #     else:
    #         policy_model_class = config.pop('policy_model_class')
    #
    #     policy_nn_config = config.pop('policy_nn_config')
    #     value_nn_config = config.pop('value_nn_config', None)
    #
    #     policy_nn = U.nn_from_config(policy_nn_config, env.get_state_info(), env.get_action_info())
    #
    #     if value_nn_config is not None:
    #         if value_nn_config.get('body') is None:
    #             print('Policy NN and Value NN are sharing bodies')
    #             value_nn_body = policy_nn.layers[0]
    #         else:
    #             print('Policy NN and Value NN are using different bodies')
    #             value_nn_body = None
    #         value_nn = U.nn_from_config(
    #             value_nn_config, env.get_state_info(), env.get_action_info(), body=value_nn_body)
    #     else:
    #         value_nn = None
    #
    #     return cls(
    #         env=env,
    #         policy_model_class=policy_model_class,
    #         policy_nn=policy_nn,
    #         value_nn=value_nn,
    #         **config.as_dict(),
    #         **kwargs)
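

# Usage sketch (illustrative only): how a PGAgent might be wired up and
# trained. The batcher and model constructors below (`RolloutBatcher`,
# `PGModel`, `ValueModel`) and their arguments are assumptions made for
# this example, not verified torchrl signatures; consult torchrl.batchers
# and torchrl.models for the real APIs.
#
#     import torchrl.utils as U
#     from torchrl.agents import PGAgent
#     from torchrl.batchers import RolloutBatcher
#     from torchrl.models import PGModel, ValueModel
#
#     batcher = RolloutBatcher(env, batch_size=512)  # env: a torchrl env
#     agent = PGAgent(
#         batcher,
#         policy_model=policy,   # a BasePGModel subclass instance
#         value_model=value,     # a ValueModel instance
#         normalize_advantages=True,
#     )
#
#     for _ in range(num_updates):
#         agent.step()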