Commit

pass entire experiment to algo
matteobettini committed Oct 20, 2023
1 parent 6dcf639 commit 0b8436c
Showing 3 changed files with 20 additions and 78 deletions.
benchmarl/algorithms/common.py: 18 additions & 62 deletions
@@ -7,12 +7,11 @@
 import pathlib
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
+from typing import Any, Dict, Iterable, Optional, Tuple, Type

 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from torchrl.data import (
-    CompositeSpec,
     DiscreteTensorSpec,
     LazyTensorStorage,
     OneHotDiscreteTensorSpec,
@@ -34,40 +33,22 @@ class Algorithm(ABC):
     and all abstract methods should be implemented.

     Args:
-        experiment_config (ExperimentConfig): the configuration dataclass for the experiment
-        model_config (ModelConfig): the configuration dataclass for the policy
-        critic_model_config (ModelConfig): the configuration dataclass for the (eventual) critic
-        observation_spec (CompositeSpec): the observation spec of the task
-        action_spec (CompositeSpec): the action spec of the task
-        state_spec (CompositeSpec): the state spec of the task
-        action_mask_spec (CompositeSpec): the action_mask spec of the task
-        group_map (Dictionary): the group map of the task
-        on_policy (bool): whether the algorithm has to be trained on policy
+        experiment (Experiment): the experiment class
     """

-    def __init__(
-        self,
-        experiment_config: "DictConfig",  # noqa: F821
-        model_config: ModelConfig,
-        critic_model_config: ModelConfig,
-        observation_spec: CompositeSpec,
-        action_spec: CompositeSpec,
-        state_spec: Optional[CompositeSpec],
-        action_mask_spec: Optional[CompositeSpec],
-        group_map: Dict[str, List[str]],
-        on_policy: bool,
-    ):
-        self.device: DEVICE_TYPING = experiment_config.train_device
-
-        self.experiment_config = experiment_config
-        self.model_config = model_config
-        self.critic_model_config = critic_model_config
-        self.on_policy = on_policy
-        self.group_map = group_map
-        self.observation_spec = observation_spec
-        self.action_spec = action_spec
-        self.state_spec = state_spec
-        self.action_mask_spec = action_mask_spec
+    def __init__(self, experiment):
+        self.experiment = experiment
+
+        self.device: DEVICE_TYPING = experiment.config.train_device
+        self.experiment_config = experiment.config
+        self.model_config = experiment.model_config
+        self.critic_model_config = experiment.critic_model_config
+        self.on_policy = experiment.on_policy
+        self.group_map = experiment.group_map
+        self.observation_spec = experiment.observation_spec
+        self.action_spec = experiment.action_spec
+        self.state_spec = experiment.state_spec
+        self.action_mask_spec = experiment.action_mask_spec

         # Cached values that will be instantiated only once and then remain fixed
         self._losses_and_updaters = {}
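
The practical effect of the hunk above: a custom algorithm no longer threads nine constructor arguments through to its base class. A minimal sketch of the new contract, assuming a hypothetical MyAlgorithm with one illustrative field (share_params is not part of this commit):

    from benchmarl.algorithms.common import Algorithm

    # Hypothetical subclass under the new one-argument contract. Custom
    # config fields still arrive as keyword arguments (via **self.__dict__
    # in AlgorithmConfig.get_algorithm below); everything else is read
    # off the experiment by Algorithm.__init__.
    class MyAlgorithm(Algorithm):
        def __init__(self, share_params: bool, **kwargs):
            super().__init__(**kwargs)  # kwargs now carries only experiment
            self.share_params = share_params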
@@ -346,43 +327,18 @@ class AlgorithmConfig:
     2. implement all abstract methods
     """

-    def get_algorithm(
-        self,
-        experiment_config,
-        model_config: ModelConfig,
-        critic_model_config: ModelConfig,
-        observation_spec: CompositeSpec,
-        action_spec: CompositeSpec,
-        state_spec: CompositeSpec,
-        action_mask_spec: Optional[CompositeSpec],
-        group_map: Dict[str, List[str]],
-    ) -> Algorithm:
+    def get_algorithm(self, experiment) -> Algorithm:
         """
         Main function to turn the config into the associated algorithm

         Args:
-            experiment_config (ExperimentConfig): the configuration dataclass for the experiment
-            model_config (ModelConfig): the configuration dataclass for the policy
-            critic_model_config (ModelConfig): the configuration dataclass for the (eventual) critic
-            observation_spec (CompositeSpec): the observation spec of the task
-            action_spec (CompositeSpec): the action spec of the task
-            state_spec (CompositeSpec): the state spec of the task
-            action_mask_spec (CompositeSpec): the action_mask spec of the task
-            group_map (Dictionary): the group map of the task
+            experiment (Experiment): the experiment class

         Returns: the Algorithm
         """
         return self.associated_class()(
             **self.__dict__,  # Passes all the custom config parameters
-            experiment_config=experiment_config,
-            model_config=model_config,
-            critic_model_config=critic_model_config,
-            observation_spec=observation_spec,
-            action_spec=action_spec,
-            state_spec=state_spec,
-            action_mask_spec=action_mask_spec,
-            group_map=group_map,
-            on_policy=self.on_policy(),
+            experiment=experiment,
         )
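
Because **self.__dict__ unpacks every field of the config dataclass into keyword arguments, the config and its algorithm stay paired: each config field must match a constructor parameter, and experiment is now the only extra one. A sketch of that pairing, reusing the hypothetical MyAlgorithm from above:

    from dataclasses import dataclass

    from benchmarl.algorithms.common import AlgorithmConfig

    @dataclass
    class MyAlgorithmConfig(AlgorithmConfig):
        share_params: bool = True  # hypothetical custom field

        @staticmethod
        def associated_class():
            return MyAlgorithm

    # config.get_algorithm(experiment) then effectively calls
    # MyAlgorithm(share_params=True, experiment=experiment).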

     @staticmethod
benchmarl/experiment/experiment.py: 1 addition & 15 deletions
@@ -8,7 +8,6 @@
 import importlib
 import os
-
 import time
 from collections import OrderedDict
 from dataclasses import dataclass, MISSING
@@ -27,7 +26,6 @@
 from benchmarl.algorithms.common import AlgorithmConfig
 from benchmarl.environments import Task
-
 from benchmarl.experiment.callback import Callback, CallbackNotifier
 from benchmarl.experiment.logger import Logger
 from benchmarl.models.common import ModelConfig
@@ -406,16 +404,7 @@ def _setup_task(self):
         self.test_env = test_env.to(self.config.sampling_device)

     def _setup_algorithm(self):
-        self.algorithm = self.algorithm_config.get_algorithm(
-            experiment_config=self.config,
-            model_config=self.model_config,
-            critic_model_config=self.critic_model_config,
-            observation_spec=self.observation_spec,
-            action_spec=self.action_spec,
-            state_spec=self.state_spec,
-            action_mask_spec=self.action_mask_spec,
-            group_map=self.group_map,
-        )
+        self.algorithm = self.algorithm_config.get_algorithm(experiment=self)
         self.replay_buffers = {
             group: self.algorithm.get_replay_buffer(
                 group=group,
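
One consequence of passing self here: the algorithm receives a back-reference to an Experiment that is still mid-setup, so anything created after this call (the replay buffers just below, for instance) should be read from self.experiment lazily rather than cached in Algorithm.__init__. A hedged sketch of that pattern inside a custom algorithm (the helper name is hypothetical):

    # Read experiment state at call time, not construction time:
    # self.experiment.replay_buffers does not exist yet while
    # Algorithm.__init__ runs, but it does once training starts.
    def _log_buffer_sizes(self):
        for group, buffer in self.experiment.replay_buffers.items():
            print(f"{group}: {len(buffer)} frames stored")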
@@ -493,7 +482,6 @@ def _setup_name(self):
         self.name = self.folder_name.name

     def _setup_logger(self):
-
         self.logger = Logger(
             experiment_name=self.name,
             folder_name=str(self.folder_name),
@@ -528,7 +516,6 @@ def run(self):
             raise err

     def _collection_loop(self):
-
         pbar = tqdm(
             initial=self.n_iters_performed,
             total=self.config.get_max_n_iters(self.on_policy),
@@ -537,7 +524,6 @@ def _collection_loop(self):

         # Training/collection iterations
         for batch in self.collector:
-
             # Logging collection
             collection_time = time.time() - sampling_start
             current_frames = batch.numel()
examples/extending/algorithm/custom_algorithm.py: 1 addition & 1 deletion
@@ -31,6 +31,7 @@ def __init__(

         # In all the class you have access to a lot of extra things like
         self.my_custom_method()  # Custom methods
+        _ = self.experiment  # Experiment class
         _ = self.experiment_config  # Experiment config
         _ = self.model_config  # Policy config
         _ = self.critic_model_config  # Eventual critic config
@@ -156,7 +157,6 @@ def _get_policy_for_loss(
     def _get_policy_for_collection(
         self, policy_for_loss: TensorDictModule, group: str, continuous: bool
     ) -> TensorDictModule:
-
         if self.action_mask_spec is not None:
             action_mask_key = (group, "action_mask")
         else:
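
The added _ = self.experiment line advertises the new back-reference in the extension example: a custom algorithm can now reach experiment state that was never a constructor argument. A short sketch under the same assumptions (the method is hypothetical; on_policy and config.train_device are attributes shown in the diffs above):

    # Hypothetical helper for a custom algorithm: consult the live
    # experiment instead of values frozen at construction.
    def _describe_run(self):
        exp = self.experiment
        print(f"on_policy={exp.on_policy}, train_device={exp.config.train_device}")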
