from collections import OrderedDict
from torchrl.utils import get_obj, Config
from torchrl.nn import SequentialExtended
def get_module_list(config, input_shape, action_info):
    """
    Build the list of layers described by ``config``.

    Each entry of ``config`` is a layer configuration that is handed to
    ``get_obj`` to instantiate the layer. Note that the entries are modified
    in place: the first entry is passed to ``auto_input_shape`` together with
    ``input_shape``, and every ``ActionLinear`` entry gets an
    ``"action_info"`` key.

    Parameters
    ----------
    config: Config
        Iterable of layer configurations. Each one must have a ``"func"`` key
        holding the layer class/callable.
    input_shape: list
        The input dimensions of the first layer.
    action_info: dict
        Information about the environment actions (e.g. shape), given to
        every ``ActionLinear`` layer.

    Returns
    -------
    list of layers
        The instantiated layers, in the same order as ``config``.
    """
    modules = []
    for position, layer_cfg in enumerate(config):
        if not position:
            # Only the first layer is given the network's input shape here;
            # later layers receive no automatic input size from this function.
            # NOTE(review): `auto_input_shape` is not imported in the visible
            # part of this file -- confirm it is defined/imported elsewhere.
            auto_input_shape(layer_cfg, input_shape)
        layer_name = layer_cfg["func"].__name__
        if "ActionLinear" in layer_name:
            # An action layer's output shape equals the action shape, so it
            # needs the action description to size itself.
            layer_cfg["action_info"] = action_info
        modules.append(get_obj(layer_cfg))
    return modules
def nn_from_config(config, state_info, action_info, body=None, head=None):
    """
    Create a pytorch model made of a ``body`` followed by a ``head``.

    Any part that is not supplied is built from the matching section of
    ``config`` (``config.body`` / ``config.head``) via ``get_module_list``.
    The head's input shape is obtained from
    ``body.get_output_shape(state_info.shape)``. A user-supplied ``body``
    must therefore also provide ``get_output_shape`` when ``head`` is not
    given.

    Parameters
    ----------
    config: Config
        The configuration object that should contain the basic network
        structure under its ``body`` and ``head`` attributes.
    state_info: Config
        Information about the environment states. It is read through
        attribute access (``state_info.shape``), so it must expose a
        ``shape`` attribute. It is only read when ``body`` or ``head`` has
        to be built.
    action_info: dict
        Information about the environment actions (e.g. shape), forwarded
        unchanged to ``get_module_list``.
    body: Module
        Pre-built body. If given, it is used as-is and ``config.body`` is
        not read (Default is None).
    head: Module
        Pre-built head. If given, it is used as-is and ``config.head`` is
        not read (Default is None).

    Returns
    -------
    torchrl.SequentialExtended
        A torchrl NN (basically a pytorch NN with extended functionalities)
        whose two children are named ``"body"`` and ``"head"``.
    """
    if body is None:
        body = SequentialExtended(
            *get_module_list(
                config=config.body,
                input_shape=state_info.shape,
                action_info=action_info,
            )
        )

    if head is None:
        # The head consumes whatever the body produces for a state input.
        head = SequentialExtended(
            *get_module_list(
                config=config.head,
                input_shape=body.get_output_shape(state_info.shape),
                action_info=action_info,
            )
        )

    named_parts = OrderedDict()
    named_parts["body"] = body
    named_parts["head"] = head
    return SequentialExtended(named_parts)