
The agent is the bridge between the model and the environment.
It implements high-level functions that are ready to be used directly by the user.


class torchrl.agents.BaseAgent(batcher, optimizer, *, gamma=0.99, log_dir='runs')

Bases: abc.ABC

Basic TorchRL agent. Encapsulates an environment and a model.

Parameters:
  • env (torchrl.envs) – A torchrl environment.
  • gamma (float) – Discount factor on future rewards (Default is 0.99).
  • log_dir (str) – Directory where logs will be written (Default is 'runs').
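As a rough illustration of what the gamma parameter controls, the sketch below computes discounted returns G_t = r_t + gamma * G_{t+1} over a batch of transitions, restarting the sum wherever an episode ends. The function name discounted_returns and the per-step dones flags are assumptions made for this example; they are not part of torchrl.agents.

```python
from typing import List


def discounted_returns(rewards: List[float], dones: List[bool], gamma: float = 0.99) -> List[float]:
    """Hypothetical helper: G_t = r_t + gamma * G_{t+1}, reset where an episode ends."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        if dones[t]:
            # The episode ended at step t, so nothing after t contributes to G_t.
            running = 0.0
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


print(discounted_returns([1.0, 1.0, 1.0], [False, False, True], gamma=0.5))
# [1.75, 1.5, 1.0]
```

Values of gamma close to 1 weight distant rewards almost as much as immediate ones; smaller values make the agent more short-sighted.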

step()

This method is called at each iteration of the training loop and defines the training procedure.
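Because BaseAgent derives from abc.ABC, a concrete agent has to supply its own step(). The toy classes below (ToyAgent and CountingAgent are hypothetical names, not torchrl classes) sketch that contract with the standard abc module: the abstract base cannot be instantiated, while a subclass that implements step() can be driven by a loop.

```python
import abc


class ToyAgent(abc.ABC):
    """Hypothetical stand-in for an abstract agent: subclasses must define step()."""

    def __init__(self):
        self.num_iters = 0

    @abc.abstractmethod
    def step(self):
        """One iteration of the training procedure (e.g. collect a batch, update the model)."""


class CountingAgent(ToyAgent):
    def step(self):
        # A real agent would collect data and take a gradient step here.
        self.num_iters += 1


# ToyAgent() would raise TypeError because step() is abstract.
agent = CountingAgent()
for _ in range(3):
    agent.step()
print(agent.num_iters)  # 3
```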


Check if the training loop reached the end.

Returns:True if done, False otherwise.
Return type:bool
_register_model(name, model)

Save a torchrl model to the internal memory.

Parameters:
  • name (str) – Desired name for the model.
  • model (torchrl.models) – The model to register.
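A name-to-model mapping is one simple way to keep track of registered models, for example so that an agent can later iterate over all of them to save checkpoints. The ModelRegistry class below is a hypothetical sketch, not torchrl's implementation; rejecting duplicate names is a design choice of this example only.

```python
from typing import Any, Dict


class ModelRegistry:
    """Hypothetical sketch: keep named models in a dict for later lookup or iteration."""

    def __init__(self):
        self.models: Dict[str, Any] = {}

    def register(self, name: str, model: Any) -> None:
        # Refuse to silently overwrite a model that was registered earlier.
        if name in self.models:
            raise ValueError(f"model {name!r} is already registered")
        self.models[name] = model

    def get(self, name: str) -> Any:
        return self.models[name]


registry = ModelRegistry()
registry.register("policy", "policy-net")
registry.register("value", "value-net")
print(sorted(registry.models))  # ['policy', 'value']
```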
train(*, max_iters=-1, max_episodes=-1, max_steps=-1, log_freq=1, eval_env=None, eval_freq=None)

Defines the training loop of the algorithm, calling step() at every iteration.

  • max_iters (int) – Maximum number of training iterations, i.e. calls to step() (Default is -1, meaning no limit).
  • max_episodes (int) – Maximum number of episodes (Default is -1, meaning no limit).
  • max_steps (int) – Maximum number of steps (Default is -1, meaning no limit).
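The following is a minimal, self-contained sketch of a loop of this kind, including an end-of-training check. It assumes that training stops as soon as any enabled limit is reached and that a negative limit (the -1 default) disables that counter. CoinFlipEnv, train_loop and steps_per_iter are hypothetical names; the documented method also accepts log_freq, eval_env and eval_freq, which are not modelled here.

```python
import random
from typing import Callable, Tuple


class CoinFlipEnv:
    """Hypothetical toy environment: every step ends the episode with probability 0.3."""

    def __init__(self, seed: int = 0):
        self.rng = random.Random(seed)

    def reset(self) -> int:
        return 0

    def step(self, action: int) -> Tuple[int, float, bool]:
        done = self.rng.random() < 0.3
        return 0, 1.0, done  # (next_state, reward, done)


def train_loop(env, step_fn: Callable[[], None], *, max_iters: int = -1,
               max_episodes: int = -1, max_steps: int = -1,
               steps_per_iter: int = 4) -> Tuple[int, int, int]:
    """Call step_fn once per iteration until any non-negative limit is reached."""
    num_iters = num_episodes = num_steps = 0
    state = env.reset()

    def reached_end() -> bool:
        # A negative limit (the -1 default) means "no limit" for that counter.
        limits = [(num_iters, max_iters), (num_episodes, max_episodes), (num_steps, max_steps)]
        return any(limit >= 0 and count >= limit for count, limit in limits)

    while not reached_end():
        # Collect a few environment steps; a real agent would pick actions from state.
        for _ in range(steps_per_iter):
            state, _reward, episode_done = env.step(action=0)
            num_steps += 1
            if episode_done:
                num_episodes += 1
                state = env.reset()
        step_fn()  # e.g. one gradient update on the collected batch
        num_iters += 1
    return num_iters, num_episodes, num_steps


iters, episodes, steps = train_loop(CoinFlipEnv(seed=0), lambda: None, max_steps=10)
print(iters, steps)  # 3 12
```

Because the limits are only checked between iterations, the counters in this sketch can overshoot a step limit by up to steps_per_iter - 1; with all three limits left at -1 the loop never ends.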
select_action(state, step)

Receive a state and use the model to select an action.

Parameters:state (numpy.ndarray) – The environment state.
Returns:action – The selected action.
Return type:int or numpy.ndarray
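For a discrete action space, one common way to turn a model's output into an action is to sample from a softmax over the model's logits, or to take the most likely action when acting greedily. The sketch below assumes the model has already produced the logits and uses the hypothetical names softmax and select_action; it does not model the step argument or continuous (numpy.ndarray) actions.

```python
import math
import random
from typing import List


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(logits: List[float], rng: random.Random, greedy: bool = False) -> int:
    """Hypothetical helper: turn a discrete policy's logits into an action index."""
    probs = softmax(logits)
    if greedy:
        # Deterministic choice: the index with the highest probability.
        return max(range(len(probs)), key=lambda i: probs[i])
    # Stochastic choice: index i is drawn with probability probs[i].
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)
print(select_action([0.1, 2.0, -1.0], rng, greedy=True))  # 1
print(select_action([0.1, 2.0, -1.0], rng))
```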

Use the logger to write general information about the training process.


class torchrl.agents.PGAgent(batcher, *, policy_model, value_model=None, normalize_advantages=True, advantage=<torchrl.utils.estimators.advantage.estimators.GAE object>, vtarget=<torchrl.utils.estimators.value.estimators.FromAdvantage object>, **kwargs)

Bases: torchrl.agents.base_agent.BaseAgent

Policy Gradient Agent, compatible with all PG models.

This agent encapsulates a policy_model and, optionally, a value_model. It defines the steps needed for the training loop (see step()) and calculates all the values necessary to train the model(s).
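To illustrate the kind of values a policy-gradient agent ties together, here is a sketch of the simplest (REINFORCE-style) surrogate loss, plus a mean-squared-error loss for an optional value model. It works on plain floats, so it only evaluates the losses; an autograd library such as PyTorch would be needed to differentiate them. The function names are hypothetical, and the losses actually used are defined by the torchrl.models classes passed to the agent, which may differ.

```python
import math
from typing import List


def policy_loss(log_probs: List[float], advantages: List[float]) -> float:
    """REINFORCE-style surrogate: -mean(log pi(a_t|s_t) * A_t)."""
    if len(log_probs) != len(advantages):
        raise ValueError("log_probs and advantages must have the same length")
    return -sum(lp * a for lp, a in zip(log_probs, advantages)) / len(log_probs)


def value_loss(values: List[float], targets: List[float]) -> float:
    """Mean squared error between value predictions and value targets."""
    return sum((v - t) ** 2 for v, t in zip(values, targets)) / len(values)


log_probs = [math.log(0.5), math.log(0.25)]
advantages = [1.0, -1.0]
print(policy_loss(log_probs, advantages))
print(value_loss([0.5, 1.0], [1.0, 1.0]))  # 0.125
```

Minimizing the surrogate raises the log-probability of actions with positive advantage and lowers it for actions with negative advantage.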

Parameters:
  • env (torchrl.envs) – A torchrl environment.
  • policy_model (torchrl.models) – Should be an instance of a subclass of torchrl.models.BasePGModel.
  • value_model (torchrl.models) – Should be an instance of torchrl.models.ValueModel (Default is None).
  • normalize_advantages (bool) – If True, normalize the advantages per batch (Default is True).
  • advantage (torchrl.utils.estimators.advantage) – Estimator used to calculate the advantages (Default is a GAE instance).
  • vtarget (torchrl.utils.estimators.value) – Estimator used to calculate the target values of the states (Default is a FromAdvantage instance).
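The three estimation-related parameters can be illustrated together. The sketch below assumes that the GAE default performs generalized advantage estimation with the usual gamma/lambda recursion, that FromAdvantage forms value targets as advantage plus the current value estimate, and that normalize_advantages standardizes advantages within a batch. The function names, the lam parameter and its 0.95 default are hypothetical choices for this example, not torchrl's API.

```python
import statistics
from typing import List


def gae(rewards: List[float], values: List[float], next_values: List[float],
        dones: List[bool], gamma: float = 0.99, lam: float = 0.95) -> List[float]:
    """Generalized advantage estimation over one batch of transitions."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), no bootstrap past a terminal step.
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        # A_t = delta_t + gamma * lam * A_{t+1}, reset at episode boundaries.
        running = delta + gamma * lam * not_done * running
        advantages[t] = running
    return advantages


def value_targets(advantages: List[float], values: List[float]) -> List[float]:
    """Assumed reading of a "from advantage" target: V_target(s_t) = A_t + V(s_t)."""
    return [a + v for a, v in zip(advantages, values)]


def normalize(advantages: List[float], eps: float = 1e-8) -> List[float]:
    """Shift to zero mean and scale to unit standard deviation within the batch."""
    mean = statistics.mean(advantages)
    std = statistics.pstdev(advantages)
    return [(a - mean) / (std + eps) for a in advantages]


rewards, dones = [1.0, 1.0, 1.0], [False, False, True]
values, next_values = [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]
adv = gae(rewards, values, next_values, dones, gamma=0.5, lam=1.0)
print(adv)  # [1.75, 1.5, 1.0]
print(value_targets(adv, values))  # [1.75, 1.5, 1.0]
print(normalize(adv))
```

With lam = 1 and all value estimates at zero, as in the usage lines, the advantages reduce to plain discounted returns; smaller lam values trade some bias for lower variance.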

step()

This method is called at each iteration of the training loop and defines the training procedure.