diff --git a/enjoy.py b/enjoy.py index bc6b3d4..19b95f0 100644 --- a/enjoy.py +++ b/enjoy.py @@ -71,7 +71,10 @@ should_render=not args.no_render, hyperparams=hyperparams) -model = ALGOS[algo].load(model_path, env=env) +# ACER raises errors because the environment passed must have +# the same number of environments as the model was trained on. +load_env = None if algo == 'acer' else env +model = ALGOS[algo].load(model_path, env=load_env) obs = env.reset()