Skip to content

Commit

Permalink
Merge pull request #16 from Arena-Rosnav/dev
Browse files Browse the repository at this point in the history
deploy
  • Loading branch information
ReykCS committed Jan 29, 2023
2 parents 8663ce7 + f8aa8f1 commit 4d3250f
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 47 deletions.
22 changes: 15 additions & 7 deletions scripts/task_generator_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self) -> None:
self.task.set_robot_names_param()

self.number_of_resets = 0
self.desired_resets = rospy.get_param("desired_resets", 2)

self.srv_start_model_visualization = rospy.ServiceProxy("start_model_visualization", Empty)
self.srv_start_model_visualization(EmptyRequest())
Expand All @@ -64,18 +65,13 @@ def __init__(self) -> None:
## Timers
rospy.Timer(rospy.Duration(0.5), self.check_task_status)


def check_task_status(self, _):
if self.task.is_done():
self.reset_task()

def reset_task(self):
self.start_time = rospy.get_time()

rospy.loginfo("=============")
rospy.loginfo("Task Reseted!")
rospy.loginfo("=============")

self.env_wrapper.before_reset_task()

is_end = self.task.reset()
Expand All @@ -85,6 +81,10 @@ def reset_task(self):

self.env_wrapper.after_reset_task()

rospy.loginfo("=============")
rospy.loginfo("Task Reseted!")
rospy.loginfo("=============")

self.number_of_resets += 1

def reset_task_srv_callback(self, req):
Expand All @@ -95,12 +95,20 @@ def reset_task_srv_callback(self, req):
return EmptyResponse()

def _send_end_message_on_end(self, is_end):
if not is_end or self.task_mode != TaskMode.SCENARIO:
if (
(not is_end and self.task_mode == TaskMode.SCENARIO)
or (self.task_mode != TaskMode.SCENARIO and self.number_of_resets < self.desired_resets)
):
return

rospy.loginfo("Shutting down. All tasks completed")

self.pub_scenario_finished.publish(Bool(True))
# Send this message 10 times to make sure it is received
for _ in range(10):
self.pub_scenario_finished.publish(Bool(True))

rospy.sleep(0.1)

rospy.signal_shutdown("Finished all episodes of the current scenario")


Expand Down
4 changes: 2 additions & 2 deletions task_generator/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ class Constants:
MAX_RESET_FAIL_TIMES = 3

class ObstacleManager:
DYNAMIC_OBSTACLES = 15
STATIC_OBSTACLES = 20
DYNAMIC_OBSTACLES = 0
STATIC_OBSTACLES = 0

OBSTACLE_MAX_RADIUS = 0.6

Expand Down
2 changes: 1 addition & 1 deletion task_generator/environments/base_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def move_robot(self, pos, name=None):
"""
raise NotImplementedError()

def spawn_robot(self):
def spawn_robot(self, complexity=1):
"""
Spawn a robot in the environment.
A position is not specified because the robot is moved at the
Expand Down
6 changes: 3 additions & 3 deletions task_generator/environments/flatland_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _spawn_random_obstacle(

self._obstacles_amount += 1

def spawn_robot(self, name, robot_name, namespace_appendix=None):
def spawn_robot(self, name, robot_name, namespace_appendix=None, complexity=1):
base_model_path = os.path.join(
rospkg.RosPack().get_path("arena-simulation-setup"),
"robot",
Expand Down Expand Up @@ -331,10 +331,10 @@ def _read_yaml(self, yaml_path):
with open(yaml_path, "r") as file:
return yaml.safe_load(file)

@abstractmethod
@staticmethod
def create_obs_name(number):
return "obs_" + str(number)

@abstractmethod
@staticmethod
def check_yaml_path(path):
return os.path.isfile(path)
15 changes: 11 additions & 4 deletions task_generator/manager/robot_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(self, namespace, map_manager, environment, robot_setup):
self.robot_setup = robot_setup
self.record_data = rospy.get_param('record_data', False)# and rospy.get_param('task_mode', 'scenario') == 'scenario'

self.position = self.start_pos

def set_up_robot(self):
if Utils.get_arena_type() == Constants.ArenaType.TRAINING:
self.robot_radius = rospy.get_param("robot_radius")
Expand Down Expand Up @@ -76,7 +78,7 @@ def _robot_name(self):

return self.namespace

def reset(self, forbidden_zones=[], start_pos=None, goal_pos=None):
def reset(self, forbidden_zones=[], start_pos=None, goal_pos=None, move_robot=True):
"""
The manager creates new start and goal position
when a task is reset, publishes the goal to
Expand All @@ -92,7 +94,9 @@ def reset(self, forbidden_zones=[], start_pos=None, goal_pos=None):
rospy.set_param(os.path.join(self.namespace, "start"), str(list(self.start_pos)))

self.publish_goal(self.goal_pos)
self.move_robot_to_start()

if move_robot:
self.move_robot_to_start()

self.set_is_goal_reached(self.start_pos, self.goal_pos)

Expand All @@ -103,7 +107,7 @@ def reset(self, forbidden_zones=[], start_pos=None, goal_pos=None):
except:
pass

return self.start_pos, self.goal_pos
return self.position, self.goal_pos # self.start_pos, self.goal_pos

def publish_goal_periodically(self, _):
if self.goal_pos != None:
Expand Down Expand Up @@ -174,6 +178,7 @@ def launch_robot(self, robot_setup):
f"model:={robot_setup['model']}",
f"local_planner:={robot_setup['planner']}",
f"namespace:={self.namespace}",
f"complexity:={rospy.get_param('complexity', 1)}",
f"record_data:={self.record_data}",
*([f"agent_name:={robot_setup.get('agent')}"] if robot_setup.get('agent') else [])
]
Expand Down Expand Up @@ -207,8 +212,10 @@ def launch_robot(self, robot_setup):
def robot_pos_callback(self, data):
current_position = data.pose.pose.position

self.position = [current_position.x, current_position.y]

self.set_is_goal_reached(
[current_position.x, current_position.y],
self.position,
self.goal_pos
)

Expand Down
2 changes: 1 addition & 1 deletion task_generator/tasks/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _reset_robot_and_obstacles(
for manager in self.robot_managers:
for pos in manager.reset(
forbidden_zones=robot_positions
):
):
robot_positions.append(
[
pos[0],
Expand Down
57 changes: 28 additions & 29 deletions task_generator/tasks/staged.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,23 +59,27 @@ def __init__(

def next_stage(self, _):
if self._curr_stage >= len(self._stages):
rospy.loginfo(f"({self.namespace}) INFO: Tried to trigger next stage but already reached last one")
rospy.loginfo(
f"({self.namespace}) INFO: Tried to trigger next stage but already reached last one"
)
return

self._curr_stage = self._curr_stage + 1

return self._init_stage_and_update_hyperparams(self._curr_stage)
return self._init_stage_and_update_config(self._curr_stage)

def previous_stage(self, _):
if self._curr_stage <= 1:
rospy.loginfo(f"({self.namespace}) INFO: Tried to trigger previous stage but already reached first one")
rospy.loginfo(
f"({self.namespace}) INFO: Tried to trigger previous stage but already reached first one"
)
return

self._curr_stage = self._curr_stage - 1

return self._init_stage_and_update_hyperparams(self._curr_stage)
return self._init_stage_and_update_config(self._curr_stage)

def _init_stage_and_update_hyperparams(self, stage):
def _init_stage_and_update_config(self, stage):
self._init_stage(stage)

if self.namespace != "eval_sim":
Expand All @@ -84,45 +88,40 @@ def _init_stage_and_update_hyperparams(self, stage):
rospy.set_param("/curr_stage", stage)
rospy.set_param("/last_state_reached", stage == len(self._stages))

self._update_stage_in_hyperparams(stage)
self._update_stage_in_config(stage)

return stage

def _init_debug_mode(self, paths):
if self._debug_mode:
return

self._hyperparams_file_path = os.path.join(
paths.get("model"), "hyperparameters.json"
)
self._hyperparams_lock = FileLock(f"{self._hyperparams_file_path}.lock")
self._config_file_path = paths["config"]
self._config_lock = FileLock(f"{self._config_file_path}.lock")

assert os.path.isfile(
self._hyperparams_file_path
), f"Found no 'hyperparameters.json' at {self._hyperparams_file_path}"
self._config_file_path
), f"Found no 'training_config.yaml' at {self._config_file_path}"

def _update_stage_in_hyperparams(self, stage):
def _update_stage_in_config(self, stage):
"""
The current stage is stored inside the hyperparams
The current stage is stored inside the config
file for when the training is stopped and later
continued, the correct stage can be restored.
"""
if self._debug_mode:
return

self._hyperparams_lock.acquire()

file = open(self._hyperparams_file_path, "r")

hyperparams = json.load(file)
hyperparams["curr_stage"] = stage
self._config_lock.acquire()

with open(self._hyperparams_file_path, "w", encoding="utf-8") as target:
json.dump(hyperparams, target, ensure_ascii=False, indent=4)
with open(self._config_file_path, "r", encoding="utf-8") as target:
config = yaml.load(target, Loader=yaml.FullLoader)
config["callbacks"]["training_curriculum"]["curr_stage"] = stage

file.close()
with open(self._config_file_path, "w", encoding="utf-8") as target:
yaml.dump(config, target, ensure_ascii=False, indent=4)

self._hyperparams_lock.release()
self._config_lock.release()

def reset(self):
super().reset(
Expand All @@ -131,14 +130,14 @@ def reset(self):
)

def _reset_robot_and_obstacles(
self, static_obstacles=None, dynamic_obstacles=None, **kwargs
):
self, static_obstacles=None, dynamic_obstacles=None, **kwargs
):
super()._reset_robot_and_obstacles(
static_obstacles=static_obstacles,
static_obstacles=static_obstacles,
dynamic_obstacles=dynamic_obstacles,
**kwargs
**kwargs,
)

def _init_stage(self, stage):
static_obstacles = self._stages[stage]["static"]
dynamic_obstacles = self._stages[stage]["dynamic"]
Expand Down

0 comments on commit 4d3250f

Please sign in to comment.