From 2b6a788a43c6515392508dca0328d49aedcaa9db Mon Sep 17 00:00:00 2001 From: johnjim0816 Date: Fri, 1 Dec 2023 23:17:17 +0800 Subject: [PATCH] update --- framework/interactor.py | 4 ++-- framework/learner.py | 4 ++-- framework/message.py | 10 +++++----- framework/model_mgr.py | 2 +- framework/tracker.py | 10 +++++----- framework/trainer.py | 2 +- framework/worker_mgr.py | 0 requirements.txt | 2 +- 8 files changed, 17 insertions(+), 17 deletions(-) create mode 100644 framework/worker_mgr.py diff --git a/framework/interactor.py b/framework/interactor.py index d82c331..c291c17 100644 --- a/framework/interactor.py +++ b/framework/interactor.py @@ -48,8 +48,8 @@ def _sample_data(self,*args, **kwargs): self.ep_step += 1 if terminated or truncated or self.ep_step >= self.cfg.max_step: run_episode += 1 - tracker.pub_msg(Msg(MsgType.DATASERVER_INCREASE_EPISODE)) - global_episode = tracker.pub_msg(Msg(MsgType.DATASERVER_GET_EPISODE)) + tracker.pub_msg(Msg(MsgType.TRACKER_INCREASE_EPISODE)) + global_episode = tracker.pub_msg(Msg(MsgType.TRACKER_GET_EPISODE)) if global_episode % self.cfg.interact_summary_fre == 0 and global_episode <= self.cfg.max_episode: logger.info(f"Interactor {self.id} finished episode {global_episode} with reward {self.ep_reward:.3f} in {self.ep_step} steps") interact_summary = {'reward':self.ep_reward,'step':self.ep_step} diff --git a/framework/learner.py b/framework/learner.py index 1161d3a..1488888 100644 --- a/framework/learner.py +++ b/framework/learner.py @@ -62,9 +62,9 @@ def run(self, *args, **kwargs): training_data = collector.pub_msg(Msg(type = MsgType.COLLECTOR_GET_TRAINING_DATA)) # get training data if training_data is None: return tracker = kwargs['tracker'] - curr_update_step = tracker.pub_msg(Msg(type = MsgType.DATASERVER_GET_UPDATE_STEP)) + curr_update_step = tracker.pub_msg(Msg(type = MsgType.TRACKER_GET_UPDATE_STEP)) self.policy.learn(**training_data,update_step = curr_update_step) - tracker.pub_msg(Msg(type = MsgType.DATASERVER_INCREASE_UPDATE_STEP)) + tracker.pub_msg(Msg(type = MsgType.TRACKER_INCREASE_UPDATE_STEP)) # put updated model params to model_mgr model_params = self.policy.get_model_params() model_mgr.pub_msg(Msg(type = MsgType.MODEL_MGR_PUT_MODEL_PARAMS, data = (curr_update_step, model_params))) diff --git a/framework/message.py b/framework/message.py index 3561fc5..d6e8d36 100644 --- a/framework/message.py +++ b/framework/message.py @@ -5,11 +5,11 @@ @unique class MsgType(Enum): # tracker - DATASERVER_GET_EPISODE = 0 - DATASERVER_INCREASE_EPISODE = 1 - DATASERVER_INCREASE_UPDATE_STEP = 2 - DATASERVER_GET_UPDATE_STEP = 3 - DATASERVER_CHECK_TASK_END = 4 + TRACKER_GET_EPISODE = 0 + TRACKER_INCREASE_EPISODE = 1 + TRACKER_INCREASE_UPDATE_STEP = 2 + TRACKER_GET_UPDATE_STEP = 3 + TRACKER_CHECK_TASK_END = 4 # interactor INTERACTOR_SAMPLE = 10 diff --git a/framework/model_mgr.py b/framework/model_mgr.py index 36b3d76..a2ba0b7 100644 --- a/framework/model_mgr.py +++ b/framework/model_mgr.py @@ -58,7 +58,7 @@ def _save_policy(self): while not self._saved_policy_queue.empty(): update_step, model_params = self._saved_policy_queue.get() torch.save(model_params, f"{self.cfg.model_dir}/{update_step}") - global_episode = self.tracker.pub_msg(Msg(type = MsgType.DATASERVER_GET_EPISODE)) + global_episode = self.tracker.pub_msg(Msg(type = MsgType.TRACKER_GET_EPISODE)) if global_episode >= self.cfg.max_episode: break time.sleep(0.1) diff --git a/framework/tracker.py b/framework/tracker.py index 063b365..df95948 100644 --- a/framework/tracker.py +++ b/framework/tracker.py @@ -20,24 +20,24 @@ def __init__(self,cfg) -> None: def pub_msg(self, msg: Msg): msg_type, msg_data = msg.type, msg.data - if msg_type == MsgType.DATASERVER_GET_EPISODE: + if msg_type == MsgType.TRACKER_GET_EPISODE: return self._get_episode() - elif msg_type == MsgType.DATASERVER_INCREASE_EPISODE: + elif msg_type == MsgType.TRACKER_INCREASE_EPISODE: episode_delta = 1 if msg_data is None else msg_data self._increase_episode(i = episode_delta) # elif msg_type == MsgType.GET_SAMPLE_COUNT: # self._get_sample_count(msg_data) - elif msg_type == MsgType.DATASERVER_GET_UPDATE_STEP: + elif msg_type == MsgType.TRACKER_GET_UPDATE_STEP: return self._get_update_step() # elif msg_type == MsgType.CHECK_TASK_END: # self._check_task_end(msg_data) # elif msg_type == MsgType.INCREASE_SAMPLE_COUNT: # self._increase_sample_count(msg_data) - elif msg_type == MsgType.DATASERVER_INCREASE_UPDATE_STEP: + elif msg_type == MsgType.TRACKER_INCREASE_UPDATE_STEP: update_step_delta = 1 if msg_data is None else msg_data self._increase_update_step(i = update_step_delta) - elif msg_type == MsgType.DATASERVER_CHECK_TASK_END: + elif msg_type == MsgType.TRACKER_CHECK_TASK_END: return self._check_task_end() else: raise NotImplementedError diff --git a/framework/trainer.py b/framework/trainer.py index f2e97a6..abde145 100644 --- a/framework/trainer.py +++ b/framework/trainer.py @@ -50,7 +50,7 @@ def run(self): collector = self.collector, recorder = self.recorder ) - if self.tracker.pub_msg(Msg(type = MsgType.DATASERVER_CHECK_TASK_END)): + if self.tracker.pub_msg(Msg(type = MsgType.TRACKER_CHECK_TASK_END)): break e_t = time.time() # end time self.logger.info(f"Finish {self.cfg.mode}ing! Time cost: {e_t - s_t:.3f} s") # print info diff --git a/framework/worker_mgr.py b/framework/worker_mgr.py new file mode 100644 index 0000000..e69de29 diff --git a/requirements.txt b/requirements.txt index 6dc8690..39cc524 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,5 +12,5 @@ pygame==2.1.2 glfw==2.5.5 imageio==2.22.4 tensorboard==2.11.2 -ray==2.3.0 +ray==2.6.3 gymnasium==0.28.1 \ No newline at end of file