-
Notifications
You must be signed in to change notification settings - Fork 4
/
generate_plots.py
51 lines (42 loc) · 1.59 KB
/
generate_plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import matplotlib.pyplot as plt
import numpy as np
import json
from scipy.ndimage.filters import gaussian_filter1d
# Global matplotlib text defaults applied to every figure created below.
plt.rcParams.update({
    "font.family": "Times New Roman",
    "font.size": 22,
})
# Gather the 'actor_eval' value from every JSON line of
# final_plots/demo_eval1.log .. demo_eval5.log (relative to the current
# working directory). Values from all runs are appended, in run order,
# to one flat list.
demo_data = []
for run_idx in range(1, 6):
    try:
        demo_file = os.path.join(os.getcwd(), 'final_plots', 'demo_eval{}.log'.format(run_idx))
        # The context manager closes the log even if a line fails to parse.
        with open(demo_file, "r", encoding="utf-8") as log_file:
            for line in log_file:
                if not line.strip():
                    # Tolerate blank lines instead of dropping the rest of the file.
                    continue
                iter_data = json.loads(line)
                demo_data.append(iter_data['actor_eval'])
    except (OSError, ValueError, KeyError, TypeError) as e:
        # Best effort: a missing or malformed run is reported and skipped so
        # the remaining runs still load. Entries appended before the failure
        # are kept. json.JSONDecodeError is a ValueError subclass.
        print(e)
# Gather the 'actor_eval' value from every JSON line of
# final_plots/sac_eval1.log .. sac_eval5.log (relative to the current
# working directory). Values from all runs are appended, in run order,
# to one flat list.
sac_data = []
for run_idx in range(1, 6):
    try:
        sac_file = os.path.join(os.getcwd(), 'final_plots', 'sac_eval{}.log'.format(run_idx))
        # The context manager closes the log even if a line fails to parse.
        with open(sac_file, "r", encoding="utf-8") as log_file:
            for line in log_file:
                if not line.strip():
                    # Tolerate blank lines instead of dropping the rest of the file.
                    continue
                iter_data = json.loads(line)
                sac_data.append(iter_data['actor_eval'])
    except (OSError, ValueError, KeyError, TypeError) as e:
        # Best effort: a missing or malformed run is reported and skipped so
        # the remaining runs still load. Entries appended before the failure
        # are kept. json.JSONDecodeError is a ValueError subclass.
        print(e)
# Smooth both series with a 1-D Gaussian filter (sigma=2) and rebind the
# original names to the filtered results, demo first and then SAC.
demo_data, sac_data = (
    gaussian_filter1d(series, sigma=2)
    for series in (demo_data, sac_data)
)
# Draw both smoothed series on a single axes and save the comparison image
# to final_plots/base_comparison.png under the current working directory.
fig = plt.figure()
ax = fig.gca()

# Outline the axes background patch with a black border of width 2.
ax.patch.set_edgecolor('black')
ax.patch.set_linewidth(2)  # numeric width; the setter takes a float, not a str

# The x coordinate of each point is its index within the series.
handleD, = ax.plot(np.arange(len(demo_data)), demo_data, label='M-DeMoRL', linewidth=2.0)
handleS, = ax.plot(np.arange(len(sac_data)), sac_data, label='SAC', linewidth=2.0)

ax.set_xlabel('Environment Steps')
ax.set_ylabel('Episode Performance')
ax.set_xlim(0, 75)  # only the x range 0..75 is displayed

# The legend is anchored outside the axes, centred on its right edge.
lgd = ax.legend([handleD, handleS], ['M-DeMoRL', 'SAC'], loc='center left', bbox_to_anchor=(1, 0.5))
lgd.get_frame().set_linewidth(2.0)

# Passing the legend as an extra artist with a tight bounding box keeps the
# outside-the-axes legend from being cropped in the saved image.
fig.savefig(os.path.join(os.getcwd(), 'final_plots', 'base_comparison.png'), bbox_extra_artists=(lgd,), bbox_inches='tight')