torchrl.utils
Config
Configuration object used by other modules. It can be saved to and imported from a YAML file.
class torchrl.utils.config.Config(*args, **kwargs)
Bases: object
Configuration object used for initializing an Agent. It maintains the order in which the attributes were set.
Parameters: configs (Keyword arguments) – Additional parameters that will be stored.
Returns: An object containing all configuration details (possibly with nested Config objects).
Return type: Config object
as_dict()
Returns all object attributes as a nested OrderedDict.
Returns: Nested OrderedDict containing all object attributes.
Return type: dict
new_section(name, **configs)
Creates a new Config object and adds it as an attribute of this instance.
Parameters:
- name (str) – Name of the new section.
- configs (Keyword arguments) – Parameters that will be stored in this section; nested parameters are accepted.
Examples
Simple use case:
config.new_section('new_section_name', attr1=value1, attr2=value2, ...)
Nested parameters:
config.new_section('new_section_name', attr1=Config(attr1=value1, attr2=value2))
Attributes of the new section can then be accessed directly:
config.new_section_name.attr1
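The behaviour described for Config (attributes kept in the order they were set, sections added with new_section, conversion to a nested OrderedDict with as_dict) can be sketched in plain Python. This is a minimal illustration, not torchrl's implementation: the class name SketchConfig is hypothetical, and it relies on the fact that Python 3.7+ keeps keyword arguments and instance attributes in insertion order.

```python
from collections import OrderedDict


class SketchConfig:
    """Hypothetical stand-in for Config, not torchrl's class."""

    def __init__(self, **configs):
        # Keyword arguments arrive in call order; storing them via setattr keeps that order.
        for key, value in configs.items():
            setattr(self, key, value)

    def new_section(self, name, **configs):
        # Create a nested config and attach it under `name`.
        setattr(self, name, SketchConfig(**configs))

    def as_dict(self):
        # Recursively convert nested configs into OrderedDicts, preserving attribute order.
        return OrderedDict(
            (key, value.as_dict() if isinstance(value, SketchConfig) else value)
            for key, value in vars(self).items()
        )


config = SketchConfig(lr=1e-3, gamma=0.99)
config.new_section("net", hidden=64, activation="relu")
print(config.net.hidden)  # 64
print(list(config.as_dict()))  # ['lr', 'gamma', 'net']
```

Nesting a config passed as a keyword argument (the second example above) works the same way, since as_dict recurses into any attribute that is itself a config.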
save(file_path)
Saves the current configuration to a JSON file. The configuration is stored as a nested dictionary, preserving attribute order.
Parameters: file_path (str) – Path of the file to write.
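A minimal sketch of how an order-preserving nested dictionary (such as the output of as_dict) can be written to JSON and read back with the standard library. The helper names save_ordered and load_ordered are hypothetical and not part of torchrl; whether save uses exactly these json calls is an assumption.

```python
import json
import os
import tempfile
from collections import OrderedDict


def save_ordered(config_dict, file_path):
    # Hypothetical helper: json.dump writes keys in the dict's iteration order.
    with open(file_path, "w") as f:
        json.dump(config_dict, f, indent=4)


def load_ordered(file_path):
    # Hypothetical helper: object_pairs_hook rebuilds every JSON object as an OrderedDict.
    with open(file_path) as f:
        return json.load(f, object_pairs_hook=OrderedDict)


data = OrderedDict([
    ("lr", 0.001),
    ("net", OrderedDict([("hidden", 64), ("activation", "relu")])),
])
path = os.path.join(tempfile.mkdtemp(), "config.json")
save_ordered(data, path)
restored = load_ordered(path)
print(list(restored))  # ['lr', 'net']
```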
Memories
SimpleMemory
DefaultMemory
class torchrl.utils.memories.DefaultMemory(*args, **kwargs)
Bases: collections.defaultdict
A defaultdict whose keys can be accessed as attributes.
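A defaultdict with attribute access can be sketched by routing failed attribute lookups to item access. The class name AttrDefaultDict is hypothetical; this shows the general technique, not DefaultMemory's actual source.

```python
from collections import defaultdict


class AttrDefaultDict(defaultdict):
    """Hypothetical sketch: a defaultdict whose keys are also reachable as attributes."""

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails. Dunder names are refused so
        # protocols such as copy/pickle probing still see a plain AttributeError.
        if name.startswith("__"):
            raise AttributeError(name)
        # Item access creates a default value for missing keys, like any defaultdict.
        return self[name]

    def __setattr__(self, name, value):
        # Attribute assignment stores a key instead of an instance attribute.
        self[name] = value


memory = AttrDefaultDict(list)
memory.rewards.append(1.0)  # 'rewards' is created as an empty list on first access
memory["rewards"].append(0.5)
memory.done = False
print(memory.rewards)  # [1.0, 0.5]
print(memory["done"])  # False
```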
Logger
class torchrl.utils.logger.Logger(log_dir=None, *, debug=False, log_freq=1)
Common logger used by all agents. It aggregates logged values and prints them as a formatted table.
Parameters: log_dir (str) – Directory in which the log files are written.
add_log(name, value, precision=2)
Registers a value under a name. This function can be called multiple times; the registered values are averaged when logging.
Parameters:
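The averaging behaviour of add_log can be illustrated with a small stand-alone class. This is a hypothetical sketch (AveragingLogSketch and its log method are not torchrl's Logger); that values are combined with a simple mean and that precision only controls the number of displayed decimals are assumptions.

```python
from collections import defaultdict


class AveragingLogSketch:
    """Hypothetical sketch of add_log-style aggregation, not torchrl's Logger."""

    def __init__(self):
        self.values = defaultdict(list)  # name -> values registered since the last log
        self.precision = {}              # name -> decimal places used for display (assumed)

    def add_log(self, name, value, precision=2):
        # Repeated calls with the same name accumulate values instead of overwriting.
        self.values[name].append(value)
        self.precision[name] = precision

    def log(self):
        # Average each name's values, print one table row per name, then reset.
        means = {
            name: round(sum(vals) / len(vals), self.precision[name])
            for name, vals in self.values.items()
        }
        width = max((len(name) for name in means), default=0)
        for name, mean in means.items():
            print(f"{name:<{width}} | {mean:.{self.precision[name]}f}")
        self.values.clear()
        return means


logger = AveragingLogSketch()
for reward in (1.0, 2.0, 4.0):
    logger.add_log("reward", reward)
logger.add_log("loss", 0.12345, precision=3)
summary = logger.log()  # prints "reward | 2.33" and "loss   | 0.123"
```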
add_tf_only_log(name, value, precision=2)
Registers a value under a name. This function can be called multiple times; the registered values are averaged when logging. Unlike add_log, the values are not displayed on the console; they are only written to the log file.
Parameters:
add_histogram(name, values)
Registers a histogram that can be viewed in TensorBoard.
Parameters:
- name (str) – Name displayed when printing the table.
- values (torch.Tensor) – Values to log.
Net Builder
auto_input_shape
get_module_list
nn_from_config
torchrl.utils.net_builder.nn_from_config(config, state_info, action_info, body=None, head=None)
Creates a PyTorch model following the instructions in config.
Parameters:
- config (Config) – The configuration object that should contain the basic network structure.
- state_info (dict) – Dict containing information about the environment states (e.g. shape).
- action_info (dict) – Dict containing information about the environment actions (e.g. shape).
- body (Module) – If given, it is used instead of creating a new body (default: None).
- head (Module) – If given, it is used instead of creating a new head (default: None).
Returns: A torchrl NN, i.e. a PyTorch network with extended functionality.
Return type: torchrl.SequentialExtended
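The internals of nn_from_config are not shown here, but the shape bookkeeping such a builder needs can be sketched: the first layer's input size comes from the state shape in state_info, each hidden layer feeds the next, and the head's output size comes from the action shape in action_info. The function name plan_mlp_layers and the "list of hidden units" config format are hypothetical. The sketch returns (in_features, out_features) pairs instead of PyTorch modules so that it runs without PyTorch.

```python
from math import prod


def plan_mlp_layers(hidden_units, state_shape, action_shape):
    """Hypothetical sketch: (in_features, out_features) pairs for a body + head MLP."""
    in_features = prod(state_shape)  # flattened state size feeds the first layer
    body = []
    for units in hidden_units:
        body.append((in_features, units))
        in_features = units  # each layer's input is the previous layer's output
    head = [(in_features, prod(action_shape))]  # head width derived from the action shape
    return body, head


body, head = plan_mlp_layers([64, 64], state_shape=(4,), action_shape=(2,))
print(body)  # [(4, 64), (64, 64)]
print(head)  # [(64, 2)]
```

In a PyTorch version, each pair would become an nn.Linear(in_features, out_features) layer, usually with an activation between body layers.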
Utils
get_obj
env_from_config
torchrl.utils.utils.env_from_config(config)
Tries to create an environment from a configuration object.
Parameters: config (Config) – Configuration object containing the environment function.
Returns: env – A torchrl environment.
Return type: torchrl.envs
Raises: AttributeError – If no env is defined in the config object.
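A minimal sketch of the lookup-and-raise pattern described for env_from_config. The layout of the env section is an assumption: here config.env.func is the environment factory and every other attribute of config.env is passed to it as a keyword argument. env_from_config_sketch and make_dummy_env are hypothetical names, and SimpleNamespace stands in for a Config object.

```python
from types import SimpleNamespace


def env_from_config_sketch(config):
    """Hypothetical sketch: build an env from config.env, not torchrl's implementation."""
    try:
        env_config = config.env
    except AttributeError:
        raise AttributeError("No env is defined in the config object") from None
    # Assumed layout: `func` is the factory; remaining attributes are its keyword arguments.
    kwargs = {k: v for k, v in vars(env_config).items() if k != "func"}
    return env_config.func(**kwargs)


def make_dummy_env(name, seed=0):
    # Hypothetical factory standing in for a real environment constructor.
    return {"name": name, "seed": seed}


config = SimpleNamespace(env=SimpleNamespace(func=make_dummy_env, name="CartPole-v1", seed=1))
env = env_from_config_sketch(config)
print(env)  # {'name': 'CartPole-v1', 'seed': 1}
```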
join_transitions
explained_var
torchrl.utils.utils.explained_var(target, preds)
Calculates the explained variance between two datasets. Useful for estimating the quality of the value function.
Parameters:
- target (np.array) – Target dataset.
- preds (np.array) – Predictions array.
Returns: The explained variance.
Return type:
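A common definition of explained variance (used, for example, in OpenAI Baselines) is 1 - Var(target - preds) / Var(target). That torchrl uses exactly this formula, including returning NaN when the target has zero variance, is an assumption. A value of 1 means perfect predictions, 0 means no better than predicting the mean, and negative values mean worse than the mean. The function name explained_variance is hypothetical.

```python
import numpy as np


def explained_variance(target, preds):
    """Hypothetical sketch: 1 - Var(target - preds) / Var(target); NaN if Var(target) == 0."""
    target = np.asarray(target, dtype=float)
    preds = np.asarray(preds, dtype=float)
    var_target = np.var(target)
    if var_target == 0:
        return float("nan")
    return float(1.0 - np.var(target - preds) / var_target)


returns = np.array([1.0, 2.0, 3.0, 4.0])
print(explained_variance(returns, returns))  # 1.0 (perfect predictions)
print(explained_variance(returns, np.full(4, returns.mean())))  # 0.0 (same as the mean)
```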