Skip to content

Commit

Permalink
update
Browse files: browse the repository at this point in the history
  • Loading branch information
johnjim0816 committed Dec 1, 2023
1 parent 0e7c77f commit 2b6a788
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 17 deletions.
4 changes: 2 additions & 2 deletions framework/interactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def _sample_data(self,*args, **kwargs):
self.ep_step += 1
if terminated or truncated or self.ep_step >= self.cfg.max_step:
run_episode += 1
tracker.pub_msg(Msg(MsgType.DATASERVER_INCREASE_EPISODE))
global_episode = tracker.pub_msg(Msg(MsgType.DATASERVER_GET_EPISODE))
tracker.pub_msg(Msg(MsgType.TRACKER_INCREASE_EPISODE))
global_episode = tracker.pub_msg(Msg(MsgType.TRACKER_GET_EPISODE))
if global_episode % self.cfg.interact_summary_fre == 0 and global_episode <= self.cfg.max_episode:
logger.info(f"Interactor {self.id} finished episode {global_episode} with reward {self.ep_reward:.3f} in {self.ep_step} steps")
interact_summary = {'reward':self.ep_reward,'step':self.ep_step}
Expand Down
4 changes: 2 additions & 2 deletions framework/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def run(self, *args, **kwargs):
training_data = collector.pub_msg(Msg(type = MsgType.COLLECTOR_GET_TRAINING_DATA)) # get training data
if training_data is None: return
tracker = kwargs['tracker']
curr_update_step = tracker.pub_msg(Msg(type = MsgType.DATASERVER_GET_UPDATE_STEP))
curr_update_step = tracker.pub_msg(Msg(type = MsgType.TRACKER_GET_UPDATE_STEP))
self.policy.learn(**training_data,update_step = curr_update_step)
tracker.pub_msg(Msg(type = MsgType.DATASERVER_INCREASE_UPDATE_STEP))
tracker.pub_msg(Msg(type = MsgType.TRACKER_INCREASE_UPDATE_STEP))
# put updated model params to model_mgr
model_params = self.policy.get_model_params()
model_mgr.pub_msg(Msg(type = MsgType.MODEL_MGR_PUT_MODEL_PARAMS, data = (curr_update_step, model_params)))
Expand Down
10 changes: 5 additions & 5 deletions framework/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
@unique
class MsgType(Enum):
# tracker
DATASERVER_GET_EPISODE = 0
DATASERVER_INCREASE_EPISODE = 1
DATASERVER_INCREASE_UPDATE_STEP = 2
DATASERVER_GET_UPDATE_STEP = 3
DATASERVER_CHECK_TASK_END = 4
TRACKER_GET_EPISODE = 0
TRACKER_INCREASE_EPISODE = 1
TRACKER_INCREASE_UPDATE_STEP = 2
TRACKER_GET_UPDATE_STEP = 3
TRACKER_CHECK_TASK_END = 4

# interactor
INTERACTOR_SAMPLE = 10
Expand Down
2 changes: 1 addition & 1 deletion framework/model_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _save_policy(self):
while not self._saved_policy_queue.empty():
update_step, model_params = self._saved_policy_queue.get()
torch.save(model_params, f"{self.cfg.model_dir}/{update_step}")
global_episode = self.tracker.pub_msg(Msg(type = MsgType.DATASERVER_GET_EPISODE))
global_episode = self.tracker.pub_msg(Msg(type = MsgType.TRACKER_GET_EPISODE))
if global_episode >= self.cfg.max_episode:
break
time.sleep(0.1)
Expand Down
10 changes: 5 additions & 5 deletions framework/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,24 @@ def __init__(self,cfg) -> None:

def pub_msg(self, msg: Msg):
msg_type, msg_data = msg.type, msg.data
if msg_type == MsgType.DATASERVER_GET_EPISODE:
if msg_type == MsgType.TRACKER_GET_EPISODE:
return self._get_episode()
elif msg_type == MsgType.DATASERVER_INCREASE_EPISODE:
elif msg_type == MsgType.TRACKER_INCREASE_EPISODE:
episode_delta = 1 if msg_data is None else msg_data
self._increase_episode(i = episode_delta)
# elif msg_type == MsgType.GET_SAMPLE_COUNT:
# self._get_sample_count(msg_data)
elif msg_type == MsgType.DATASERVER_GET_UPDATE_STEP:
elif msg_type == MsgType.TRACKER_GET_UPDATE_STEP:
return self._get_update_step()
# elif msg_type == MsgType.CHECK_TASK_END:
# self._check_task_end(msg_data)
# elif msg_type == MsgType.INCREASE_SAMPLE_COUNT:
# self._increase_sample_count(msg_data)
elif msg_type == MsgType.DATASERVER_INCREASE_UPDATE_STEP:
elif msg_type == MsgType.TRACKER_INCREASE_UPDATE_STEP:
update_step_delta = 1 if msg_data is None else msg_data
self._increase_update_step(i = update_step_delta)

elif msg_type == MsgType.DATASERVER_CHECK_TASK_END:
elif msg_type == MsgType.TRACKER_CHECK_TASK_END:
return self._check_task_end()
else:
raise NotImplementedError
Expand Down
2 changes: 1 addition & 1 deletion framework/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def run(self):
collector = self.collector,
recorder = self.recorder
)
if self.tracker.pub_msg(Msg(type = MsgType.DATASERVER_CHECK_TASK_END)):
if self.tracker.pub_msg(Msg(type = MsgType.TRACKER_CHECK_TASK_END)):
break
e_t = time.time() # end time
self.logger.info(f"Finish {self.cfg.mode}ing! Time cost: {e_t - s_t:.3f} s") # print info
Empty file added framework/worker_mgr.py
Empty file.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ pygame==2.1.2
glfw==2.5.5
imageio==2.22.4
tensorboard==2.11.2
ray==2.3.0
ray==2.6.3
gymnasium==0.28.1

0 comments on commit 2b6a788

Please sign in to comment.