Source code for torchrl.models.value_model

import torch
import torch.nn.functional as F

import torchrl.utils as U
from torchrl.models import BaseModel
from torchrl.nn import FlattenLinear


class ValueModel(BaseModel):
    """
    A standard regression model, can be used to estimate the value of
    states or Q values.

    The model is trained with a mean-squared-error loss between its
    flattened prediction for ``batch.state_t`` and ``batch.vtarget``.

    Parameters
    ----------
    model:
        Forwarded unchanged to ``BaseModel.__init__``.
    batcher:
        Forwarded unchanged to ``BaseModel.__init__``.
    **kwargs:
        Any extra keyword arguments, forwarded unchanged to
        ``BaseModel.__init__``.

    Notes
    -----
    NOTE(review): earlier documentation listed a ``clip_range`` parameter
    ("similar to PPOClip, limits the change between the new and old value
    function"). Nothing in this class reads it; it can only have an effect
    if ``BaseModel`` consumes it from ``**kwargs`` -- confirm before relying
    on it.
    """

    def __init__(self, model, batcher, **kwargs):
        super().__init__(model=model, batcher=batcher, **kwargs)

    @property
    def batch_keys(self):
        """Names of the batch fields this model reads."""
        return ["state_t", "vtarget"]

    def register_losses(self):
        """Register :meth:`mse_loss` as the loss of this model."""
        self.register_loss(self.mse_loss)

    def register_callbacks(self):
        """Register the parent callbacks plus :meth:`add_old_pred`.

        ``add_old_pred`` is hooked on epoch start so that ``batch.old_pred``
        is available when :meth:`write_logs` runs.
        """
        super().register_callbacks()
        self.callbacks.register_on_epoch_start(self.add_old_pred)

    def mse_loss(self, batch):
        """Return the MSE between the flattened prediction and the target.

        Parameters
        ----------
        batch:
            Object exposing ``state_t`` (model input) and ``vtarget``
            (regression target).
        """
        # Flatten so the prediction is 1-D before comparing with vtarget.
        pred = self.forward(batch.state_t).view(-1)
        loss = F.mse_loss(pred, batch.vtarget)
        return loss

    def add_old_pred(self, batch):
        """Store the current flattened prediction on ``batch.old_pred``.

        Computed under ``torch.no_grad()`` because the value is only a
        reference snapshot, not part of the optimized graph.
        """
        with torch.no_grad():
            batch.old_pred = self.forward(batch.state_t).view(-1)

    def write_logs(self, batch):
        """Log explained variance (old and new predictions) and means.

        Stores the fresh prediction on ``self.memory.new_pred`` and logs
        "Old Explained Var", "New Explained Var", "Target_mean" and
        "Pred_mean".
        """
        super().write_logs(batch)

        # Logging only: skip autograd so no graph is kept alive through
        # ``self.memory``, and flatten exactly like ``mse_loss`` and
        # ``add_old_pred`` so old and new predictions share one shape.
        with torch.no_grad():
            self.memory.new_pred = self.forward(batch.state_t).view(-1)

        self.add_log(
            "Old Explained Var", U.explained_var(batch.vtarget, batch.old_pred)
        )
        self.add_log(
            "New Explained Var", U.explained_var(batch.vtarget, self.memory.new_pred)
        )
        self.add_log("Target_mean", batch.vtarget.mean())
        self.add_log("Pred_mean", self.memory.new_pred.mean())

    @staticmethod
    def output_layer(input_shape, action_info):
        """Build the output layer: a ``FlattenLinear`` with one output unit.

        Parameters
        ----------
        input_shape:
            Passed as ``in_features`` to ``FlattenLinear``.
        action_info:
            Unused here; kept in the signature for callers.
        """
        return FlattenLinear(in_features=input_shape, out_features=1)