diff --git a/rl_baselines/rl_algorithm/sac.py b/rl_baselines/rl_algorithm/sac.py index 04f221e7e..0577d10f5 100644 --- a/rl_baselines/rl_algorithm/sac.py +++ b/rl_baselines/rl_algorithm/sac.py @@ -1,10 +1,9 @@ -import pickle import os +import pickle -import numpy as np from stable_baselines import SAC -from stable_baselines.sac.policies import MlpPolicy, CnnPolicy from stable_baselines.common.vec_env import VecNormalize, DummyVecEnv +from stable_baselines.sac.policies import MlpPolicy, CnnPolicy from environments.utils import makeEnv from rl_baselines.base_classes import StableBaselinesRLObject diff --git a/rl_baselines/train.py b/rl_baselines/train.py index 31bf66ab2..fa910c85f 100644 --- a/rl_baselines/train.py +++ b/rl_baselines/train.py @@ -61,8 +61,10 @@ def latestPath(path): :param path: path to the log folder (defined in srl_model.yaml) (str) :return: path to latest learned model in the same dataset folder (str) """ - return max([path + "/" + d for d in os.listdir(path) if not d.startswith('baselines') and os.path.isdir(path + "/" + d)], - key=os.path.getmtime) + '/srl_model.pth' + return max( + [path + "/" + d for d in os.listdir(path) if not d.startswith('baselines') and os.path.isdir(path + "/" + d)], + key=os.path.getmtime) + '/srl_model.pth' + def configureEnvAndLogFolder(args, env_kwargs, all_models): """ @@ -199,7 +201,8 @@ def main(): parser.add_argument('--srl-config-file', type=str, default="config/srl_models.yaml", help='Set the location of the SRL model path configuration.') parser.add_argument('--hyperparam', type=str, nargs='+', default=[]) - parser.add_argument('--min-episodes-save', type=int, default=100, help="Min number of episodes before saving best model") + parser.add_argument('--min-episodes-save', type=int, default=100, + help="Min number of episodes before saving best model") parser.add_argument('--latest', action='store_true', default=False, help='load the latest learned model (location:srl_zoo/logs/DatasetName/)') diff --git a/rl_baselines/utils.py b/rl_baselines/utils.py index 6f0c47bb1..336214a79 100644 --- a/rl_baselines/utils.py +++ b/rl_baselines/utils.py @@ -140,7 +140,6 @@ def get_original_obs(self): """ return self.venv.get_original_obs() - def saveRunningAverage(self, path): """ Hack to use VecNormalize