train_sac.py
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_ataripy
import random
import time

import gymnasium as gym
import hydra
import numpy as np
import torch
import torch.optim as optim
from omegaconf import OmegaConf
from torch.utils.tensorboard import SummaryWriter

from src.models import Actor, DoubleQNetwork, PixelEncoder
from src.sac import sac_train_loop
from src.utils import make_envs
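
# Hydra entry point: the config is composed from cfg/minigrid_grayscale.yaml,
# can be overridden from the command line, and is passed to train() as a DictConfig.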
@hydra.main(version_base=None, config_path="cfg", config_name="minigrid_grayscale")
def train(config):
    run_name = f"{config.env.id}__{config.meta.exp_name}__{config.meta.seed}__{int(time.time())}"
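    # Optional experiment tracking: mirror the TensorBoard logs to Weights & Biases.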
    if config.meta.track:
        import wandb

        wandb.init(
            project=config.wandb_project_name,
            entity=config.wandb_entity,
            sync_tensorboard=True,
            config=OmegaConf.to_container(config, resolve=True),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
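    # Write metrics to TensorBoard and record the full config as a markdown table.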
    logger = SummaryWriter(f"runs/{run_name}")
    logger.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s"
        % (
            "\n".join(
                f"|{key}|{value}|"
                for key, value in OmegaConf.to_container(config, resolve=True).items()
            )
        ),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(config.meta.seed)
    np.random.seed(config.meta.seed)
    torch.manual_seed(config.meta.seed)
    torch.backends.cudnn.deterministic = config.meta.torch_deterministic
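
    # Use the GPU only when CUDA is available and config.meta.cuda is enabled.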
    device = torch.device(
        "cuda" if torch.cuda.is_available() and config.meta.cuda else "cpu"
    )
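
    # Build the environment(s); the SAC variant below only supports
    # gym.spaces.Discrete action spaces.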
    envs = make_envs(config, run_name)
    assert isinstance(
        envs.action_space, gym.spaces.Discrete
    ), "only discrete action space is supported"
    obs_shape = envs.observation_space.shape
    num_actions = envs.action_space.n
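
    # A single PixelEncoder instance is passed to the actor, the critic, and the
    # target critic, so all three share the same learned image features.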
    encoder = PixelEncoder(
        channels=obs_shape[0],
        img_size=config.encoder.img_size,
        crop=config.encoder.crop,
        num_filters=config.encoder.num_filters,
        out_features=config.encoder.out_features,
    ).to(device)
    actor = Actor(
        encoder=encoder,
        num_actions=num_actions,
        num_features=config.actor.hidden_features,
    ).to(device)
    critic = DoubleQNetwork(
        encoder=encoder,
        num_actions=num_actions,
        num_features=config.critic.hidden_features,
    ).to(device)
    critic_target = DoubleQNetwork(
        encoder=encoder,
        num_actions=num_actions,
        num_features=config.critic.hidden_features,
    ).to(device)
    critic_target.load_state_dict(critic.state_dict())
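    # The target critic above is initialised as an exact copy of the online critic.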

    # TRY NOT TO MODIFY: eps=1e-4 increases numerical stability
    critic_optimizer = optim.Adam(
        list(critic.parameters()), lr=config.optim.q_lr, eps=1e-4
    )
    actor_optimizer = optim.Adam(
        list(actor.parameters()), lr=config.optim.policy_lr, eps=1e-4
    )
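
    # Run the SAC training loop; it returns a sequence of results that are
    # averaged into a single score below.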
    results = sac_train_loop(
        device,
        config,
        envs,
        actor,
        critic,
        critic_target,
        actor_optimizer,
        critic_optimizer,
        logger,
    )

    score = sum(results) / len(results)
    envs.close()
    logger.close()
    print(f"SCORE: {score}")
    return score


if __name__ == "__main__":
    train()
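
# A typical invocation overrides config values on the Hydra command line, e.g.
# (exact keys depend on cfg/minigrid_grayscale.yaml, so treat this as a sketch):
#   python train_sac.py meta.seed=1 meta.track=false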