Agents¶
The agent is the bridge between the model and the environment. It implements high-level functions ready to be used by the user.
BaseAgent¶
class torchrl.agents.BaseAgent(batcher, optimizer, *, gamma=0.99, log_dir='runs')[source]¶
Bases: abc.ABC
Basic TorchRL agent. Encapsulates an environment and a model.
Parameters:
- batcher (torchrl.batchers) – A torchrl batcher, wrapping the environment.
- optimizer (torch.optim.Optimizer) – Optimizer used to update the models.
- gamma (float) – Discount factor on future rewards (Default is 0.99).
- log_dir (string) – Directory where logs will be written (Default is 'runs').
step()[source]¶
This method is called at each iteration of the training loop and defines the training procedure.
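A minimal sketch of a concrete agent overriding step(). Only BaseAgent, step(), select_action(), and the documented constructor arguments come from this page; self.batcher.get_batch and self.model are illustrative assumptions, not documented API:

    from torchrl.agents import BaseAgent

    class MyAgent(BaseAgent):
        def step(self):
            # One training iteration: collect a batch of transitions and
            # update the model. Both `self.batcher.get_batch` and
            # `self.model.train` are hypothetical names used for illustration.
            batch = self.batcher.get_batch(select_action_fn=self.select_action)
            self.model.train(batch)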
_check_termination()[source]¶
Check whether the training loop has reached the end.
Returns: bool – True if done, False otherwise.
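The check could plausibly look like the sketch below, where num_iters and num_steps are assumed internal counters (not documented here) and -1 mirrors the "no limit" defaults of train():

    def _check_termination(self):
        # Assumed counters `self.num_iters` / `self.num_steps`; stop once
        # any limit passed to train() is reached (-1 means no limit).
        return (0 <= self.max_iters <= self.num_iters
                or 0 <= self.max_steps <= self.num_steps)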
_register_model(name, model)[source]¶
Save a torchrl model to the internal memory.
Parameters:
- name (str) – Desired name for the model.
- model (torchrl.models) – The model to register.
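For example, a subclass might register its models at construction time so they can later be checkpointed or logged together; the ActorCriticAgent below is purely illustrative:

    from torchrl.agents import BaseAgent

    class ActorCriticAgent(BaseAgent):
        def __init__(self, batcher, optimizer, *, policy_model, value_model, **kwargs):
            super().__init__(batcher, optimizer, **kwargs)
            # Store each model under a name in the agent's internal memory.
            self._register_model('policy', policy_model)
            self._register_model('value', value_model)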
train(*, max_iters=-1, max_episodes=-1, max_steps=-1, log_freq=1, eval_env=None, eval_freq=None)[source]¶
Defines the training loop of the algorithm, calling step() at every iteration.
Parameters:
- max_iters (int) – Maximum number of training iterations (Default is -1, meaning no limit).
- max_episodes (int) – Maximum number of episodes (Default is -1, meaning no limit).
- max_steps (int) – Maximum number of environment steps (Default is -1, meaning no limit).
- log_freq (int) – Number of iterations between each log (Default is 1).
- eval_env (torchrl.envs) – Environment used for evaluation, if given (Default is None).
- eval_freq (int) – How often to run an evaluation (Default is None).
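A typical call, using only the keyword arguments documented above (agent is any concrete agent instance):

    # Train for at most one million environment steps, logging every
    # 10 iterations; no evaluation runs since eval_env is None.
    agent.train(max_steps=1_000_000, log_freq=10)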
select_action(state, step)[source]¶
Receive a state and use the model to select an action.
Parameters:
- state (numpy.ndarray) – The environment state.
- step (int) – The current step count.
Returns: action – The selected action.
Return type: int or numpy.ndarray
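For instance, an evaluation rollout could call it directly. The gym-style reset()/step() interface assumed for env below is an assumption, not something this page guarantees:

    state = env.reset()
    done = False
    while not done:
        # `step=0` is an arbitrary illustrative value for the step argument.
        action = agent.select_action(state, step=0)
        state, reward, done, info = env.step(action)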
PGAgent¶
class torchrl.agents.PGAgent(batcher, *, policy_model, value_model=None, normalize_advantages=True, advantage=<torchrl.utils.estimators.advantage.estimators.GAE object>, vtarget=<torchrl.utils.estimators.value.estimators.FromAdvantage object>, **kwargs)[source]¶
Bases: torchrl.agents.base_agent.BaseAgent
Policy Gradient Agent, compatible with all PG models.
This agent encapsulates a policy_model and, optionally, a value_model; it defines the steps needed for the training loop (see step()) and calculates all the values necessary to train the model(s).
Parameters:
- batcher (torchrl.batchers) – A torchrl batcher, wrapping the environment.
- policy_model (torchrl.models) – Should be a subclass of torchrl.models.BasePGModel.
- value_model (torchrl.models) – Should be an instance of torchrl.models.ValueModel (Default is None).
- normalize_advantages (bool) – If True, normalize the advantages per batch (Default is True).
- advantage (torchrl.utils.estimators.advantage) – Class used for calculating the advantages (Default is GAE).
- vtarget (torchrl.utils.estimators.value) – Class used for calculating the states' target values (Default is FromAdvantage).
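Putting the pieces together, a minimal sketch of wiring up a PGAgent. GymEnv and RolloutBatcher are assumptions about the rest of the library, and the model placeholders must be filled with real instances; only the PGAgent and train() arguments come from the signatures above:

    from torchrl.agents import PGAgent
    from torchrl.batchers import RolloutBatcher   # assumed batcher class
    from torchrl.envs import GymEnv               # assumed environment wrapper

    env = GymEnv('CartPole-v1')
    batcher = RolloutBatcher(env)
    policy_model = ...  # any torchrl.models.BasePGModel subclass instance
    value_model = ...   # a torchrl.models.ValueModel instance

    agent = PGAgent(
        batcher,
        policy_model=policy_model,
        value_model=value_model,
        normalize_advantages=True,
    )
    agent.train(max_steps=100_000, log_freq=10)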