Skip to content
This repository has been archived by the owner on Sep 1, 2024. It is now read-only.

Commit

Permalink
Make create_one_dim_tr_model recognize subclasses of BasicEnsemble (
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankTianTT authored Jul 26, 2023
1 parent 9d22445 commit 3f93ccc
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mbrl/util/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def create_one_dim_tr_model(
# This first part takes care of the case where model is BasicEnsemble and in/out sizes
# are handled by member_cfg
model_cfg = cfg.dynamics_model
if model_cfg._target_ == "mbrl.models.BasicEnsemble":
if issubclass(hydra.utils._locate(model_cfg._target_), mbrl.models.BasicEnsemble):
model_cfg = model_cfg.member_cfg
if model_cfg.get("in_size", None) is None:
model_cfg.in_size = obs_shape[0] + (act_shape[0] if act_shape else 1)
Expand Down
39 changes: 38 additions & 1 deletion tests/core/test_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,43 @@ def test_create_one_dim_tr_model():
assert dynamics_model.input_normalizer.mean.dtype == dtype


class CustomEnsemble(models.BasicEnsemble):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


def test_create_custom_ensemble_dynamics():
cfg_dict = {
"dynamics_model": {
"_target_": "tests.core.test_common_utils.CustomEnsemble",
"ensemble_size": 5,
"device": "cpu",
"propagation_method": "fixed_model",
"member_cfg": {
"_target_": "mbrl.models.GaussianMLP",
"device": "cpu",
"in_size": "???",
"out_size": "???",
},
},
"algorithm": {
"learned_rewards": True,
"target_is_delta": True,
"normalize": True,
},
"overrides": {},
}
obs_shape = (10,)
act_shape = (1,)

cfg = omegaconf.OmegaConf.create(cfg_dict)
dynamics_model = utils.create_one_dim_tr_model(cfg, obs_shape, act_shape)

assert isinstance(dynamics_model.model, CustomEnsemble)
assert dynamics_model.model.in_size == obs_shape[0] + act_shape[0]
assert dynamics_model.model.out_size == obs_shape[0] + 1


def test_create_replay_buffer():
trial_length = 20
num_trials = 10
Expand Down Expand Up @@ -198,7 +235,7 @@ def __init__(self):
self.traj = 0
self.val = 0

def reset(self, from_zero=False, seed: Optional[int]=None):
def reset(self, from_zero=False, seed: Optional[int] = None):
if from_zero:
self.traj = 0
self.val = 100 * self.traj
Expand Down

0 comments on commit 3f93ccc

Please sign in to comment.