Source code for torchrl.models.vanilla_pg_model

import torch
import torch.nn.functional as F

import torchrl.utils as U
from torchrl.models import BasePGModel


class VanillaPGModel(BasePGModel):
    """
    The classical Policy Gradient algorithm.
    """

    @property
    def batch_keys(self):
        # Fields every batch must provide for this model.
        return ["state_t", "action", "advantage", "log_prob"]

    @property
    def entropy(self):
        # Mean entropy of the action distributions stored by add_dist.
        return self.memory.dists.entropy().mean()
    def register_losses(self):
        self.register_loss(self.pg_loss)
        self.register_loss(self.entropy_loss)
    def register_callbacks(self):
        super().register_callbacks()
        self.callbacks.register_on_mini_batch_start(self.add_dist)
        self.callbacks.register_on_train_end(self.add_dist)

    def add_dist(self, batch):
        # Recompute the action distributions under the current policy and
        # store the log-probabilities of the actions actually taken.
        parameters = self.forward(batch.state_t)
        self.memory.dists = self.create_dist(parameters)
        # Summing over the last dimension treats the action components as
        # independent, so their per-dimension log-probabilities add up.
        batch.log_prob = self.memory.dists.log_prob(batch.action).sum(-1)
    def pg_loss(self, batch):
        """
        Compute the loss based on the policy gradient theorem: the
        negative of the surrogate objective E[log pi(a|s) * A(s, a)].

        Parameters
        ----------
        batch: Batch
            The batch should contain all the information necessary
            to compute the gradients.
        """
        objective = batch.log_prob * batch.advantage
        # One scalar objective per transition, so the result must be flat.
        assert len(objective.shape) == 1
        loss = -objective.mean()

        return loss
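
A minimal, self-contained sketch of the same computation in plain PyTorch, outside the torchrl class machinery. The tensor names (states, actions, advantages), the Gaussian policy head, and the entropy coefficient are illustrative assumptions, not part of the library; the point is only the surrogate loss that pg_loss implements and the log-prob bookkeeping that add_dist performs.

import torch
from torch.distributions import Normal

# Illustrative stand-ins: a batch of 32 four-dimensional states
# and two-dimensional continuous actions.
states = torch.randn(32, 4)
actions = torch.randn(32, 2)
advantages = torch.randn(32)  # e.g. from Monte Carlo returns or GAE

# Hypothetical Gaussian policy head; torchrl's forward/create_dist
# pair plays this role in the class above.
policy = torch.nn.Linear(4, 2)
log_std = torch.zeros(2, requires_grad=True)

dist = Normal(loc=policy(states), scale=log_std.exp())

# Independent action dimensions: per-dimension log-probs sum,
# exactly as in add_dist.
log_prob = dist.log_prob(actions).sum(-1)  # shape (32,)

# Policy gradient surrogate: loss = -E[log pi(a|s) * A].
pg_loss = -(log_prob * advantages.detach()).mean()

# Entropy bonus encourages exploration, mirroring the registered
# entropy_loss; 0.01 is an arbitrary illustrative coefficient.
entropy = dist.entropy().mean()
loss = pg_loss - 0.01 * entropy

loss.backward()  # gradients flow into policy and log_std

Differentiating pg_loss with respect to the policy parameters recovers the policy gradient estimator, since the advantages are treated as constants (hence the detach).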