Source code for torchrl.models.ppo_clip_model

import torch
import torchrl.utils as U
from torchrl.models import SurrogatePGModel


class PPOClipModel(SurrogatePGModel):
    """
    Proximal Policy Optimization as described in https://arxiv.org/pdf/1707.06347.pdf.

    Parameters
    ----------
    ppo_clip_range: float
        Clipping value for the probability ratio (Default is 0.2).
    num_epochs: int
        How many times to train over the entire dataset (Default is 10).
    """

    def __init__(self, model, batcher, ppo_clip_range=0.2, **kwargs):
        super().__init__(model=model, batcher=batcher, **kwargs)
        self.ppo_clip_range_fn = U.make_callable(ppo_clip_range)

    @property
    def ppo_clip_range(self):
        return self.ppo_clip_range_fn(self.num_steps)

    def register_losses(self):
        self.register_loss(self.ppo_clip_loss)
        self.register_loss(self.entropy_loss)

    def ppo_clip_loss(self, batch):
        """
        Calculate the PPO Clip loss as described in the paper.

        Parameters
        ----------
        batch: Batch
        """
        clipped_prob_ratio = self.memory.prob_ratio.clamp(
            min=1 - self.ppo_clip_range, max=1 + self.ppo_clip_range
        )

        surrogate = self.memory.prob_ratio * batch.advantage
        clipped_surrogate = clipped_prob_ratio * batch.advantage

        losses = torch.min(surrogate, clipped_surrogate)
        assert len(losses.shape) == 1

        loss = -losses.mean()
        return loss
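
    # For reference, ppo_clip_loss above implements the clipped surrogate
    # objective from the paper linked in the class docstring:
    #
    #     L_CLIP = E_t[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ]
    #
    # where r_t = pi_new(a_t | s_t) / pi_old(a_t | s_t) is memory.prob_ratio,
    # A_t is batch.advantage and eps is ppo_clip_range; the sign is flipped so
    # the objective can be minimized by the optimizer.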

    def write_logs(self, batch):
        super().write_logs(batch)

        clip_frac = (
            ((1 - self.memory.prob_ratio).abs() > self.ppo_clip_range).float().mean()
        )

        self.add_log("PPO Clip Range", self.ppo_clip_range)
        self.add_log("PPO Clip Fraction", clip_frac)
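
The clipping arithmetic above can be checked in isolation. The following is a
minimal, self-contained sketch on dummy tensors; the variable names are local
to the example and are not part of torchrl.

    import torch

    eps = 0.2                                   # same role as ppo_clip_range
    prob_ratio = torch.tensor([0.7, 1.0, 1.4])  # pi_new / pi_old per sample
    advantage = torch.tensor([1.0, -0.5, 2.0])

    clipped_ratio = prob_ratio.clamp(min=1 - eps, max=1 + eps)
    surrogate = prob_ratio * advantage
    clipped_surrogate = clipped_ratio * advantage

    # Pessimistic (min) objective, negated because optimizers minimize
    loss = -torch.min(surrogate, clipped_surrogate).mean()

    # Fraction of samples whose ratio fell outside the clip range,
    # mirroring the "PPO Clip Fraction" log above
    clip_frac = ((1 - prob_ratio).abs() > eps).float().mean()

Note also that ppo_clip_range is wrapped with U.make_callable and read back
through a property evaluated at self.num_steps, so the effective clip range is
recomputed as training progresses; presumably a callable schedule could be
passed in place of the constant, though that depends on U.make_callable
accepting callables, which this module alone does not confirm.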