Skip to content

Commit

Permalink
garden time steps param
Browse files Browse the repository at this point in the history
  • Loading branch information
w07wong committed Dec 6, 2019
1 parent 03e26c7 commit 7953c71
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 9 deletions.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
21 changes: 12 additions & 9 deletions RL_Framework/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Pipeline:
def __init__(self):
pass

def create_config(self, rl_time_steps=3000000, garden_time_steps=40, garden_x=10, garden_y=10, num_plant_types=2, num_plants_per_type=1, step=2, action_low=0.0, action_high=1.0, obs_low=0, obs_high=1000, ent_coef=0.01, n_steps=40000, nminibatches=4, noptepochs=4, learning_rate=1e-8, cnn_args=None):
def create_config(self, rl_time_steps=3000000, garden_time_steps=50, garden_x=10, garden_y=10, num_plant_types=2, num_plants_per_type=1, step=1, action_low=0.0, action_high=1.0, obs_low=0, obs_high=1000, ent_coef=0.01, n_steps=40000, nminibatches=4, noptepochs=4, learning_rate=1e-8, cnn_args=None):
config = configparser.ConfigParser()
config.add_section('rl')
config['rl']['time_steps'] = str(rl_time_steps)
Expand Down Expand Up @@ -113,19 +113,19 @@ def plot_average_reward(self, folder_path, reward, days, x_range, y_range, ticks
pathlib.Path(folder_path + '/Graphs').mkdir(parents=True, exist_ok=True)
plt.savefig('./' + folder_path + '/Graphs/avg_reward.png')

def plot_stddev_reward(self, folder_path, reward, reward_stddev, days, x_range, y_range, ticks):
def plot_stddev_reward(self, folder_path, garden_time_steps, reward, reward_stddev, days, x_range, y_range, ticks):
fig = plt.figure(figsize=(28, 10))
plt.xticks(np.arange(0, days, 10))
plt.yticks(np.arange(x_range, y_range, ticks))
plt.title('Std Dev of Reward Over ' + str(days) + ' Days', fontsize=18)
plt.xlabel('Day', fontsize=16)
plt.ylabel('Reward', fontsize=16)

plt.errorbar([i for i in range(40)], reward, reward_stddev, linestyle='None', marker='o', color='g')
plt.errorbar([i for i in range(garden_time_steps)], reward, reward_stddev, linestyle='None', marker='o', color='g')
pathlib.Path(folder_path + '/Graphs').mkdir(parents=True, exist_ok=True)
plt.savefig('./' + folder_path + '/Graphs/std_reward.png')

def graph_evaluations(self, folder_path, garden_x, garden_y, time_steps, step, num_evals, num_plant_types):
def graph_evaluations(self, folder_path, garden_time_steps, garden_x, garden_y, time_steps, step, num_evals, num_plant_types):
obs = [0] * time_steps
r = [0] * time_steps
for i in range(num_evals):
Expand All @@ -152,7 +152,7 @@ def graph_evaluations(self, folder_path, garden_x, garden_y, time_steps, step, n
min_r = min(r) - 10
max_r = max(r) + 10
self.plot_average_reward(folder_path, r, time_steps, min_r, max_r, abs(min_r - max_r) / 10)
self.plot_stddev_reward(folder_path, rewards, rewards_stddev, time_steps, min_r, max_r, abs(min_r - max_r) / 10)
self.plot_stddev_reward(folder_path, garden_time_steps, rewards, rewards_stddev, time_steps, min_r, max_r, abs(min_r - max_r) / 10)

def evaluate_policy(self, folder_path, num_evals, env, is_baseline=False, baseline_policy=None, step=1):
model = None
Expand Down Expand Up @@ -181,6 +181,8 @@ def evaluate_policy(self, folder_path, num_evals, env, is_baseline=False, baseli
env.render()
done = False

# env.env_method('show_animation')

pathlib.Path(folder_path + '/Returns').mkdir(parents=True, exist_ok=True)
filename = folder_path + '/Returns' + '/predict_' + str(i) + '.json'
f = open(filename, 'w')
Expand All @@ -205,6 +207,7 @@ def single_run(self, folder_path, num_evals, policy_kwargs=None, is_baseline=Fal
step = config.getint('garden', 'step')
num_plants_per_type = config.getint('garden', 'num_plants_per_type')
num_plant_types = config.getint('garden', 'num_plant_types')
garden_time_steps = config.getint('garden', 'time_steps')
garden_x = config.getint('garden', 'X')
garden_y = config.getint('garden', 'Y')
# Z axis contains a matrix for every plant type plus one for water levels.
Expand Down Expand Up @@ -234,7 +237,7 @@ def single_run(self, folder_path, num_evals, policy_kwargs=None, is_baseline=Fal
self.evaluate_policy(folder_path=folder_path, num_evals=num_evals, env=env, is_baseline=True, baseline_policy=baseline_policy, step=1)

# Graph evaluations
self.graph_evaluations(folder_path, garden_x, garden_y, time_steps, step, num_evals, num_plant_types)
self.graph_evaluations(folder_path, garden_time_steps, garden_x, garden_y, time_steps, step, num_evals, num_plant_types)
else:
pathlib.Path(folder_path + '/ppo_v2_tensorboard').mkdir(parents=True, exist_ok=True)
# Instantiate the agent
Expand All @@ -252,7 +255,7 @@ def single_run(self, folder_path, num_evals, policy_kwargs=None, is_baseline=Fal
self.evaluate_policy(folder_path=folder_path, num_evals=num_evals, env=env, is_baseline=False)

# Graph evaluations
self.graph_evaluations(folder_path, garden_x, garden_y, time_steps, step, num_evals, num_plant_types)
self.graph_evaluations(folder_path, garden_time_steps, garden_x, garden_y, time_steps, step, num_evals, num_plant_types)

profiler_object.disable()

Expand Down Expand Up @@ -296,7 +299,7 @@ def createBaselineSingleRunFolder(self, garden_x, garden_y, num_plant_types, num
filename_time = str(filename_time)


def batch_run(self, n, rl_config, garden_x, garden_y, num_plant_types, num_plants_per_type, policy_kwargs=[], num_evals=50, is_baseline=[], baseline_policy=None):
def batch_run(self, n, rl_config, garden_x, garden_y, num_plant_types, num_plants_per_type, policy_kwargs=[], num_evals=1, is_baseline=[], baseline_policy=None):
assert(len(rl_config) == n)
assert(len(garden_x) == n)
assert(len(garden_y) == n)
Expand Down Expand Up @@ -336,7 +339,7 @@ def batch_run(self, n, rl_config, garden_x, garden_y, num_plant_types, num_plant
rl_config = [
{
'rl_algorithm': 'MLP',
'time_steps': 7000000,
'time_steps': 200,
'ent_coef': 0.0,
'n_steps': 40000,
'nminibatches': 4,
Expand Down

0 comments on commit 7953c71

Please sign in to comment.