Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Apr 8, 2024
1 parent cfb3347 commit 1c4fb6f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
15 changes: 14 additions & 1 deletion benchmarl/models/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,20 @@ def _number_conv_outputs(
class Cnn(Model):
"""Convolutional Neural Network (CNN) model.
Args:
The BenchMARL CNN accepts multiple inputs of 2 types:
- images: Tensors of shape (*batch, X,Y,C)
- arrays: Tensors of shape (*batch, F)
The CNN model will check that all image inputs have the same shape (excluding the last dimension)
and cat them along that dimension before processing them with :class:`torchrl.modules.ConvNet`.
It will check that all array inputs have the same shape (excluding the last dimension)
and cat them along that dimension.
It will then cat the arrays and processed images and feed them to the MLP together.
Args:
cnn_num_cells (int or Sequence of int): number of cells of
every layer in between the input and output. If an integer is
provided, every layer will have the same number of cells. If an
Expand Down Expand Up @@ -113,6 +125,7 @@ def __init__(
[self.input_spec[key].shape[-1] for key in self.tensor_in_keys]
)
if self.input_has_agent_dim and not self.output_has_agent_dim:
# In this case the tensor features will be centralized
self.input_features_tensors *= self.n_agents

self.output_features = self.output_leaf_spec.shape[-1]
Expand Down
2 changes: 1 addition & 1 deletion benchmarl/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"""
Method to implement for the forward pass of the model.
It should read self.in_key, process it and write self.out_key.
It should read self.in_keys, process it and write self.out_key.
Args:
tensordict (TensorDictBase): the input td
Expand Down

0 comments on commit 1c4fb6f

Please sign in to comment.