# Source code for torchrl.models.base_pg_model

from abc import abstractmethod
from abc import abstractproperty

import torch

from torchrl.distributions import Categorical, Normal
import torchrl.utils as U
from torchrl.models import BaseModel
from torchrl.nn import ActionLinear

class BasePGModel(BaseModel):
    """
    Base class for all Policy Gradient Models.

    Subclasses must implement the ``entropy`` property. This base class
    provides the entropy regularization term, policy-distribution creation
    based on the batcher's action space, logging, and action sampling.

    Parameters
    ----------
    model:
        The underlying network; forwarded unchanged to ``BaseModel``.
    batcher:
        Forwarded to ``BaseModel``; also queried here through
        ``get_action_info().space`` to decide which distribution to build.
    entropy_coef:
        Weight of the entropy term (default ``0``). It is wrapped with
        ``U.make_callable`` and evaluated with ``self.num_steps``, so
        presumably either a constant or a step-dependent schedule is
        accepted -- confirm against ``U.make_callable``.
    **kwargs:
        Extra keyword arguments forwarded to ``BaseModel``.
    """

    def __init__(self, model, batcher, *, entropy_coef=0, **kwargs):
        super().__init__(model=model, batcher=batcher, **kwargs)
        self.entropy_coef_fn = U.make_callable(entropy_coef)

    # ``property`` + ``abstractmethod`` replaces the deprecated
    # ``abc.abstractproperty``; both mark the property as abstract.
    @property
    @abstractmethod
    def entropy(self):
        """
        Entropy term that subclasses must provide.

        It is used by ``entropy_loss`` and logged by ``write_logs``.
        """

    @property
    def entropy_coef(self):
        """Current entropy coefficient, evaluated at ``self.num_steps``."""
        return self.entropy_coef_fn(self.num_steps)

    def entropy_loss(self, batch):
        """
        Add an entropy cost to the loss function to encourage exploration.

        Parameters
        ----------
        batch: Batch
            The batch should contain all the information necessary to
            compute the gradients. This implementation does not read it;
            the argument is kept for interface compatibility.

        Returns
        -------
        The negated entropy scaled by ``self.entropy_coef``. Minimizing this
        value increases the entropy.
        """
        return -self.entropy * self.entropy_coef

    def create_dist(self, parameters):
        """
        Build the policy distribution for the environment's action space.

        Parameters
        ----------
        parameters: tensor
            Output of the network. It must support ``.exp()`` and ellipsis
            indexing (e.g. a ``torch.Tensor``), not a plain ``np.array``.
            For ``"discrete"`` action spaces it is used directly as the
            logits of a ``Categorical``. For ``"continuous"`` action spaces
            the last axis holds the mean at index 0 and the log standard
            deviation at index 1.

        Returns
        -------
        A ``Categorical`` (discrete) or ``Normal`` (continuous) distribution.

        Raises
        ------
        ValueError
            If the action space is neither ``"discrete"`` nor
            ``"continuous"``.
        """
        # Query the action info once instead of once per branch.
        space = self.batcher.get_action_info().space

        if space == "discrete":
            return Categorical(logits=parameters)

        if space == "continuous":
            means = parameters[..., 0]
            # Exponentiate so the network can output an unconstrained value
            # while the scale passed to Normal stays positive.
            std_devs = parameters[..., 1].exp()
            return Normal(loc=means, scale=std_devs)

        raise ValueError("No distribution is defined for {} actions".format(space))

    def write_logs(self, batch):
        """Extend the base logs with the policy entropy and ``batch.log_prob``."""
        super().write_logs(batch)
        self.add_log("Entropy", self.entropy)
        self.add_log("Policy/log_prob", batch.log_prob)

    @staticmethod
    def output_layer(input_shape, action_info):
        """Return an ``ActionLinear`` head with ``input_shape`` input features."""
        return ActionLinear(in_features=input_shape, action_info=action_info)

    @staticmethod
    def select_action(model, state, step):
        """
        Define how the actions are selected.

        Actions are sampled from a distribution whose parameters are given
        by the network.

        Parameters
        ----------
        model: BasePGModel
            Model providing ``forward`` and ``create_dist``.
        state: array-like
            The state of the environment (can be a batch of states).
        step:
            Not used by this implementation; kept for interface
            compatibility.

        Returns
        -------
        The sampled action, converted with ``U.to_np`` (presumably a numpy
        array -- confirm against ``U.to_np``).
        """
        parameters = model.forward(state)
        dist = model.create_dist(parameters)
        action = dist.sample()
        return U.to_np(action)