diff --git a/scripts/task_generator_node.py b/scripts/task_generator_node.py index 2fdbd2a..c93fc82 100644 --- a/scripts/task_generator_node.py +++ b/scripts/task_generator_node.py @@ -46,6 +46,7 @@ def __init__(self) -> None: self.task.set_robot_names_param() self.number_of_resets = 0 + self.desired_resets = rospy.get_param("desired_resets", 2) self.srv_start_model_visualization = rospy.ServiceProxy("start_model_visualization", Empty) self.srv_start_model_visualization(EmptyRequest()) @@ -64,7 +65,6 @@ def __init__(self) -> None: ## Timers rospy.Timer(rospy.Duration(0.5), self.check_task_status) - def check_task_status(self, _): if self.task.is_done(): self.reset_task() @@ -72,10 +72,6 @@ def check_task_status(self, _): def reset_task(self): self.start_time = rospy.get_time() - rospy.loginfo("=============") - rospy.loginfo("Task Reseted!") - rospy.loginfo("=============") - self.env_wrapper.before_reset_task() is_end = self.task.reset() @@ -85,6 +81,10 @@ def reset_task(self): self.env_wrapper.after_reset_task() + rospy.loginfo("=============") + rospy.loginfo("Task Reseted!") + rospy.loginfo("=============") + self.number_of_resets += 1 def reset_task_srv_callback(self, req): @@ -95,12 +95,20 @@ def reset_task_srv_callback(self, req): return EmptyResponse() def _send_end_message_on_end(self, is_end): - if not is_end or self.task_mode != TaskMode.SCENARIO: + if ( + (not is_end and self.task_mode == TaskMode.SCENARIO) + or (self.task_mode != TaskMode.SCENARIO and self.number_of_resets < self.desired_resets) + ): return rospy.loginfo("Shutting down. All tasks completed") - self.pub_scenario_finished.publish(Bool(True)) + # Send this message 10 times to make sure it is received + for _ in range(10): + self.pub_scenario_finished.publish(Bool(True)) + + rospy.sleep(0.1) + rospy.signal_shutdown("Finished all episodes of the current scenario") diff --git a/task_generator/constants.py b/task_generator/constants.py index f24ee46..115a831 100644 --- a/task_generator/constants.py +++ b/task_generator/constants.py @@ -5,8 +5,8 @@ class Constants: MAX_RESET_FAIL_TIMES = 3 class ObstacleManager: - DYNAMIC_OBSTACLES = 15 - STATIC_OBSTACLES = 20 + DYNAMIC_OBSTACLES = 0 + STATIC_OBSTACLES = 0 OBSTACLE_MAX_RADIUS = 0.6 diff --git a/task_generator/environments/base_environment.py b/task_generator/environments/base_environment.py index 762afb0..3623c12 100644 --- a/task_generator/environments/base_environment.py +++ b/task_generator/environments/base_environment.py @@ -63,7 +63,7 @@ def move_robot(self, pos, name=None): """ raise NotImplementedError() - def spawn_robot(self): + def spawn_robot(self, complexity=1): """ Spawn a robot in the environment. A position is not specified because the robot is moved at the diff --git a/task_generator/environments/flatland_environment.py b/task_generator/environments/flatland_environment.py index a18a759..3d0f39c 100644 --- a/task_generator/environments/flatland_environment.py +++ b/task_generator/environments/flatland_environment.py @@ -161,7 +161,7 @@ def _spawn_random_obstacle( self._obstacles_amount += 1 - def spawn_robot(self, name, robot_name, namespace_appendix=None): + def spawn_robot(self, name, robot_name, namespace_appendix=None, complexity=1): base_model_path = os.path.join( rospkg.RosPack().get_path("arena-simulation-setup"), "robot", @@ -331,10 +331,10 @@ def _read_yaml(self, yaml_path): with open(yaml_path, "r") as file: return yaml.safe_load(file) - @abstractmethod + @staticmethod def create_obs_name(number): return "obs_" + str(number) - @abstractmethod + @staticmethod def check_yaml_path(path): return os.path.isfile(path) diff --git a/task_generator/manager/robot_manager.py b/task_generator/manager/robot_manager.py index 17f0924..6a66a8f 100644 --- a/task_generator/manager/robot_manager.py +++ b/task_generator/manager/robot_manager.py @@ -38,6 +38,8 @@ def __init__(self, namespace, map_manager, environment, robot_setup): self.robot_setup = robot_setup self.record_data = rospy.get_param('record_data', False)# and rospy.get_param('task_mode', 'scenario') == 'scenario' + self.position = self.start_pos + def set_up_robot(self): if Utils.get_arena_type() == Constants.ArenaType.TRAINING: self.robot_radius = rospy.get_param("robot_radius") @@ -76,7 +78,7 @@ def _robot_name(self): return self.namespace - def reset(self, forbidden_zones=[], start_pos=None, goal_pos=None): + def reset(self, forbidden_zones=[], start_pos=None, goal_pos=None, move_robot=True): """ The manager creates new start and goal position when a task is reset, publishes the goal to @@ -92,7 +94,9 @@ def reset(self, forbidden_zones=[], start_pos=None, goal_pos=None): rospy.set_param(os.path.join(self.namespace, "start"), str(list(self.start_pos))) self.publish_goal(self.goal_pos) - self.move_robot_to_start() + + if move_robot: + self.move_robot_to_start() self.set_is_goal_reached(self.start_pos, self.goal_pos) @@ -103,7 +107,7 @@ def reset(self, forbidden_zones=[], start_pos=None, goal_pos=None): except: pass - return self.start_pos, self.goal_pos + return self.position, self.goal_pos # self.start_pos, self.goal_pos def publish_goal_periodically(self, _): if self.goal_pos != None: @@ -174,6 +178,7 @@ def launch_robot(self, robot_setup): f"model:={robot_setup['model']}", f"local_planner:={robot_setup['planner']}", f"namespace:={self.namespace}", + f"complexity:={rospy.get_param('complexity', 1)}", f"record_data:={self.record_data}", *([f"agent_name:={robot_setup.get('agent')}"] if robot_setup.get('agent') else []) ] @@ -207,8 +212,10 @@ def launch_robot(self, robot_setup): def robot_pos_callback(self, data): current_position = data.pose.pose.position + self.position = [current_position.x, current_position.y] + self.set_is_goal_reached( - [current_position.x, current_position.y], + self.position, self.goal_pos ) diff --git a/task_generator/tasks/random.py b/task_generator/tasks/random.py index 057f8aa..4b66747 100644 --- a/task_generator/tasks/random.py +++ b/task_generator/tasks/random.py @@ -37,7 +37,7 @@ def _reset_robot_and_obstacles( for manager in self.robot_managers: for pos in manager.reset( forbidden_zones=robot_positions - ): + ): robot_positions.append( [ pos[0], diff --git a/task_generator/tasks/staged.py b/task_generator/tasks/staged.py index 287e0b3..41ec602 100644 --- a/task_generator/tasks/staged.py +++ b/task_generator/tasks/staged.py @@ -59,23 +59,27 @@ def __init__( def next_stage(self, _): if self._curr_stage >= len(self._stages): - rospy.loginfo(f"({self.namespace}) INFO: Tried to trigger next stage but already reached last one") + rospy.loginfo( + f"({self.namespace}) INFO: Tried to trigger next stage but already reached last one" + ) return self._curr_stage = self._curr_stage + 1 - return self._init_stage_and_update_hyperparams(self._curr_stage) + return self._init_stage_and_update_config(self._curr_stage) def previous_stage(self, _): if self._curr_stage <= 1: - rospy.loginfo(f"({self.namespace}) INFO: Tried to trigger previous stage but already reached first one") + rospy.loginfo( + f"({self.namespace}) INFO: Tried to trigger previous stage but already reached first one" + ) return self._curr_stage = self._curr_stage - 1 - return self._init_stage_and_update_hyperparams(self._curr_stage) + return self._init_stage_and_update_config(self._curr_stage) - def _init_stage_and_update_hyperparams(self, stage): + def _init_stage_and_update_config(self, stage): self._init_stage(stage) if self.namespace != "eval_sim": @@ -84,7 +88,7 @@ def _init_stage_and_update_hyperparams(self, stage): rospy.set_param("/curr_stage", stage) rospy.set_param("/last_state_reached", stage == len(self._stages)) - self._update_stage_in_hyperparams(stage) + self._update_stage_in_config(stage) return stage @@ -92,37 +96,32 @@ def _init_debug_mode(self, paths): if self._debug_mode: return - self._hyperparams_file_path = os.path.join( - paths.get("model"), "hyperparameters.json" - ) - self._hyperparams_lock = FileLock(f"{self._hyperparams_file_path}.lock") + self._config_file_path = paths["config"] + self._config_lock = FileLock(f"{self._config_file_path}.lock") assert os.path.isfile( - self._hyperparams_file_path - ), f"Found no 'hyperparameters.json' at {self._hyperparams_file_path}" + self._config_file_path + ), f"Found no 'training_config.yaml' at {self._config_file_path}" - def _update_stage_in_hyperparams(self, stage): + def _update_stage_in_config(self, stage): """ - The current stage is stored inside the hyperparams + The current stage is stored inside the config file for when the training is stopped and later continued, the correct stage can be restored. """ if self._debug_mode: return - self._hyperparams_lock.acquire() - - file = open(self._hyperparams_file_path, "r") - - hyperparams = json.load(file) - hyperparams["curr_stage"] = stage + self._config_lock.acquire() - with open(self._hyperparams_file_path, "w", encoding="utf-8") as target: - json.dump(hyperparams, target, ensure_ascii=False, indent=4) + with open(self._config_file_path, "r", encoding="utf-8") as target: + config = yaml.load(target, Loader=yaml.FullLoader) + config["callbacks"]["training_curriculum"]["curr_stage"] = stage - file.close() + with open(self._config_file_path, "w", encoding="utf-8") as target: + yaml.dump(config, target, ensure_ascii=False, indent=4) - self._hyperparams_lock.release() + self._config_lock.release() def reset(self): super().reset( @@ -131,14 +130,14 @@ def reset(self): ) def _reset_robot_and_obstacles( - self, static_obstacles=None, dynamic_obstacles=None, **kwargs - ): + self, static_obstacles=None, dynamic_obstacles=None, **kwargs + ): super()._reset_robot_and_obstacles( - static_obstacles=static_obstacles, + static_obstacles=static_obstacles, dynamic_obstacles=dynamic_obstacles, - **kwargs + **kwargs, ) - + def _init_stage(self, stage): static_obstacles = self._stages[stage]["static"] dynamic_obstacles = self._stages[stage]["dynamic"]