# regression.yaml — configuration for the regression experiments
---
# Experiment-level settings.
# `model` and `enet` pull in component configs defined later in this file via
# `${...}` interpolation (OmegaConf/Hydra syntax — assumes an OmegaConf-based
# loader; a plain YAML loader would keep them as literal strings).
# Other model configs defined in this file: unroll_gd, unroll_gn, dcem, leo_cem.
# Other energy-net config defined in this file: enet_basic.
model: ${leo_gn}
enet: ${enet_rff}
n_update: 5000
n_inner_update: 1
seed: 0
n_hidden: 128
n_samples: 100
clip_norm: true
# Written with a dot and a signed exponent so YAML 1.1 loaders (e.g. PyYAML)
# also read it as a float; a bare `1e-4` is a string under YAML 1.1.
lr: 1.0e-4
n_disp_step: 200
show_plot: true
save_plot: false
save_model: false
# Energy-network component backed by regression.EnergyNetBasic.
# `params` is presumably forwarded to the class constructor — confirm in the
# config loader.
enet_basic:
  class: regression.EnergyNetBasic
  tag: enet_basic
  params:
    n_hidden: 128
# Energy-network component backed by regression.EnergyNetRFF (selected by the
# top-level `enet` key). "RFF" presumably means random Fourier features, with
# `sigma` / `encoded_size` configuring that encoding — confirm in the class.
enet_rff:
  class: regression.EnergyNetRFF
  tag: enet_rff
  params:
    n_hidden: 128
    sigma: 1.0
    encoded_size: 128
# Model component backed by regression.UnrollGD (unrolled inner optimization;
# "GD" read from the class name — presumably gradient descent).
unroll_gd:
  class: regression.UnrollGD
  tag: unroll_gd
  params:
    n_inner_iter: 10
    # Dot + signed exponent so YAML 1.1 loaders parse a float, not a string.
    inner_lr: 1.0e-3
    init_scheme: zero  # options: zero, gt
# Model component backed by regression.UnrollGN (unrolled inner optimization;
# "GN" read from the class name — presumably Gauss-Newton).
unroll_gn:
  class: regression.UnrollGN
  tag: unroll_gn
  params:
    n_inner_iter: 10
    inner_lr: 1.0  # standard GN step is 1.0
    init_scheme: zero  # options: zero, gt
# Model component backed by regression.LEOGN (selected by the top-level
# `model` key).
leo_gn:
  class: regression.LEOGN
  tag: leo_gn
  params:
    n_sample: 100
    # Floats use a dot and a signed exponent so YAML 1.1 loaders (PyYAML)
    # parse them as floats; bare `1e9` / `1e-3` are strings under YAML 1.1.
    temp: 1.0e+9
    min_cov: 1.0e-3
    max_cov: 10.0
    n_inner_iter: 10
    init_scheme: zero  # options: zero, gt
# Model component backed by regression.UnrollCEM ("CEM" read from the class
# name — presumably the cross-entropy method; confirm in the class).
dcem:
  class: regression.UnrollCEM
  tag: dcem
  params:
    n_sample: 100
    n_elite: 10
    n_iter: 10
    init_sigma: 1.0
    temp: 1.0
    normalize: true
# Model component backed by regression.LEOCEM. The `cem_*` keys mirror the
# parameters of the `dcem` component above, presumably configuring an inner
# CEM run — confirm in regression.LEOCEM.
leo_cem:
  class: regression.LEOCEM
  tag: leo_cem
  params:
    n_sample: 100
    temp: 1.0
    # Dot + signed exponent so YAML 1.1 loaders parse a float, not a string.
    min_cov: 1.0e-3
    max_cov: 10.0
    cem_n_sample: 100
    cem_n_elite: 10
    cem_n_iter: 10
    cem_init_sigma: 7.0
    cem_temp: 1.0
    cem_normalize: true