From b3b73a46548f2788595e57b29ac39133a222f1e8 Mon Sep 17 00:00:00 2001 From: Karthik Dharmarajan Date: Sun, 28 Apr 2024 17:43:01 -0700 Subject: [PATCH] Added naive support and fixed bugs in the code (#12) --- ...valuate_policy_demo_source_robot_server.py | 88 +++++++++++-------- ...valuate_policy_demo_target_robot_client.py | 60 ++++++------- .../robosuite/robosuite_experiment.py | 14 ++- .../robosuite/robosuite_experiment_config.py | 11 ++- .../robosuite/run_robosuite_benchmark.py | 11 +-- 5 files changed, 108 insertions(+), 76 deletions(-) diff --git a/mirage/mirage/benchmark/robosuite/evaluate_policy_demo_source_robot_server.py b/mirage/mirage/benchmark/robosuite/evaluate_policy_demo_source_robot_server.py index b003e0d..026a1df 100644 --- a/mirage/mirage/benchmark/robosuite/evaluate_policy_demo_source_robot_server.py +++ b/mirage/mirage/benchmark/robosuite/evaluate_policy_demo_source_robot_server.py @@ -154,10 +154,16 @@ def __init__(self, robot_name=None, ckpt_path=None, render=False, video_path=Non config, _ = FileUtils.config_from_checkpoint(ckpt_dict=self.ckpt_dict) self.rollout_horizon = config.experiment.rollout.horizon - self.num_robots = len(self.env.env.robots) + self.core_env = self.env.env + self.is_diffusion = False + while hasattr(self.core_env, "env"): + self.is_diffusion = True + self.core_env = self.core_env.env + + self.num_robots = len(self.core_env.robots) self.eef_site_name = [] for i in range(self.num_robots): - self.eef_site_name.append(self.env.env.robots[i].controller.eef_name) + self.eef_site_name.append(self.core_env.robots[i].controller.eef_name) # maybe open hdf5 to write rollouts @@ -184,7 +190,7 @@ def initialize_robot(self): self.obs = self.env.reset() state_dict_source = self.env.get_state() self.obs = self.env.reset_to(state_dict_source) # necessary for robosuite tasks for deterministic action playback - + def compute_pose_error(self, target_pose): starting_pose = self.compute_eef_pose() if self.num_robots == 1: @@ -198,7 +204,7 @@ def compute_pose_error(self, target_pose): def drive_robot_to_target_pose(self, target_pose=None, tracking_error_threshold=0.003, num_iter_max=100): for i in range(self.num_robots): - self.env.env.robots[i].controller.use_delta = False # change to absolute pose for setting the initial state + self.core_env.robots[i].controller.use_delta = False # change to absolute pose for setting the initial state error, starting_pose = self.compute_pose_error(target_pose) num_iters = 0 @@ -234,15 +240,15 @@ def drive_robot_to_target_pose(self, target_pose=None, tracking_error_threshold= # change back to delta pose for i in range(self.num_robots): - self.env.env.robots[i].controller.use_delta = True + self.core_env.robots[i].controller.use_delta = True def compute_eef_pose(self): """return a 7D or 14D pose vector""" pose = [] for i in range(self.num_robots): - pos = np.array(self.env.env.sim.data.site_xpos[self.env.env.sim.model.site_name2id(self.eef_site_name[i])]) - rot = np.array(T.mat2quat(self.env.env.sim.data.site_xmat[self.env.env.sim.model.site_name2id(self.eef_site_name[i])].reshape([3, 3]))) + pos = np.array(self.core_env.sim.data.site_xpos[self.core_env.sim.model.site_name2id(self.eef_site_name[i])]) + rot = np.array(T.mat2quat(self.core_env.sim.data.site_xmat[self.core_env.sim.model.site_name2id(self.eef_site_name[i])].reshape([3, 3]))) pose.append(np.concatenate((pos, rot))) pose = np.concatenate(pose) return pose @@ -251,17 +257,19 @@ def compute_eef_pose(self): def get_object_state(self): object_state = dict() for obj_name in TASK_OBJECT_DICT[self.task]: - object_state[obj_name] = self.env.env.sim.data.get_joint_qpos(obj_name) + object_state[obj_name] = self.core_env.sim.data.get_joint_qpos(obj_name) return object_state def set_object_state(self, set_to_target_object_state=None): if set_to_target_object_state is not None: # set target object to target object state for obj_name in TASK_OBJECT_DICT[self.task]: - self.env.env.sim.data.set_joint_qpos(obj_name, set_to_target_object_state[obj_name]) - # self.env.env.sim.data.set_joint_qpos("cube_joint0", set_to_target_object_state) - self.env.env.sim.forward() + self.core_env.sim.data.set_joint_qpos(obj_name, set_to_target_object_state[obj_name]) + # self.core_env.sim.data.set_joint_qpos("cube_joint0", set_to_target_object_state) + self.core_env.sim.forward() self.obs = self.env.get_observation() + if hasattr(self.env, "_get_stacked_obs_from_history"): + self.obs = self.env._get_stacked_obs_from_history() def step(self, action, use_delta=True, blocking=False, tracking_error_threshold=0.003, num_iter_max=100, goal_pose=None, name="Source Robot"): @@ -277,7 +285,7 @@ def step(self, action, use_delta=True, blocking=False, tracking_error_threshold= if not blocking: assert (len(action) == 7 and self.num_robots == 1) or (len(action) == 14 and self.num_robots == 2), "Action should be 7DOF" for i in range(self.num_robots): - self.env.env.robots[i].controller.use_delta = use_delta + self.core_env.robots[i].controller.use_delta = use_delta next_obs, r, done, _ = self.env.step(action) # just execute action # if action[-1] != self.prev_action[-1]: # print("after", self.compute_eef_pose()) @@ -292,20 +300,20 @@ def step(self, action, use_delta=True, blocking=False, tracking_error_threshold= # single_arm.py #L247 # osc.py #L264 if use_delta: - self.env.env.robots[0].controller.use_delta = True + self.core_env.robots[0].controller.use_delta = True # convert to equivalent absolute actions - self.env.env.robots[0].controller.set_goal(action[:self.env.env.robots[0].controller.control_dim]) - action_goal_pos = self.env.env.robots[0].controller.goal_pos - action_goal_ori = Rotation.from_matrix(self.env.env.robots[0].controller.goal_ori).as_rotvec() + self.core_env.robots[0].controller.set_goal(action[:self.core_env.robots[0].controller.control_dim]) + action_goal_pos = self.core_env.robots[0].controller.goal_pos + action_goal_ori = Rotation.from_matrix(self.core_env.robots[0].controller.goal_ori).as_rotvec() raise NotImplementedError next_obs, r, done, _ = self.env.step(action) else: assert (len(action) == 8 and self.num_robots == 1) or (len(action) == 16 and self.num_robots == 2), "Action should be 8DOF" for i in range(self.num_robots): - self.env.env.robots[i].controller.use_delta = False - self.env.env.robots[i].controller.kp = np.array([150, 150, 150, 150, 150, 150]) # control gain - # self.env.env.robots[i].controller.kp = np.array([100, 100, 500, 10, 10, 50]) # control gain - print("Robot {} controller kp: {}".format(i, self.env.env.robots[i].controller.kp)) + self.core_env.robots[i].controller.use_delta = False + self.core_env.robots[i].controller.kp = np.array([150, 150, 150, 150, 150, 150]) # control gain + # self.core_env.robots[i].controller.kp = np.array([100, 100, 500, 10, 10, 50]) # control gain + print("Robot {} controller kp: {}".format(i, self.core_env.robots[i].controller.kp)) if self.num_robots == 1: action_target = np.zeros(7) action_target[:3] = action[:3] @@ -337,12 +345,12 @@ def step(self, action, use_delta=True, blocking=False, tracking_error_threshold= if self.num_robots == 1: if action[-1] != self.prev_action[-1]: # print("after", self.compute_eef_pose()) - # self.env.env.robots[0].controller.use_delta = True + # self.core_env.robots[0].controller.use_delta = True # action_target = np.zeros(7) action_target[-1] = action[-1] # print("action_target", action_target) next_obs, r, done, _ = self.env.step(action_target) - # self.env.env.robots[0].controller.use_delta = False + # self.core_env.robots[0].controller.use_delta = False # print("after", self.compute_eef_pose()) elif self.num_robots == 2: action_target[6] = action[7] @@ -412,6 +420,7 @@ def run_experiments(self, seeds, rollout_num_episodes=1, video_skip=5, camera_na for k in dict_rollout_stats: avg_rollout_stats[k].append(np.mean(dict_rollout_stats[k])) avg_rollout_stats["Num_Success"].append(np.sum(dict_rollout_stats["Success_Rate"])) + avg_rollout_stats["Num Rollouts"] = i + 1 avg_rollout_stats["Seeds"].append(seed) avg_rollout_stats["Robot"] = self.robot_name print("Average Rollout Stats:") @@ -488,8 +497,6 @@ def __init__(self, robot_name=None, ckpt_path=None, render=False, video_path=Non self.forward_dynamics_model = ForwardDynamicsModel(model_path=forward_dynamics_model_path) self.naive = naive - if self.naive: - self.interpolator = GripperInterpolator('Panda', robot_name, [f'{self.save_stats_path}/gripper_interpolation_results_no_task_diff.pkl']) def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_object_state=False, set_robot_pose=False, tracking_error_threshold=0.003, num_iter_max=100, target_robot_delta_action=False, demo_index=0): """ @@ -513,7 +520,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o stats (dict): some statistics for the rollout - such as return, horizon, and task success traj (dict): dictionary that corresponds to the rollout trajectory """ - assert isinstance(self.env, EnvBase) + # assert isinstance(self.env, EnvBase) # assert isinstance(self.policy, RolloutPolicy) if self.save_paired_images: @@ -555,7 +562,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o # Pickle the object and send it to the server data_string = pickle.dumps(variable) message_length = struct.pack("!I", len(data_string)) - self.conn.send(message_length) + self.conn.sendall(message_length) self.conn.send(data_string) # confirm that the target robot is ready pickled_message_size = self._receive_all_bytes(4) @@ -634,7 +641,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o segmentation_mask = cv2.flip(segmentation_mask, 0) depth_normalized = obs['agentview_depth'] depth_normalized = cv2.flip(depth_normalized, 0) - depth_img = camera_utils.get_real_depth_map(self.env.env.sim, depth_normalized) + depth_img = camera_utils.get_real_depth_map(self.core_env.sim, depth_normalized) # save the rgb image cv2.imwrite(os.path.join(self.save_paired_images_folder_path, "franka_rgb", str(demo_index), "{}.jpg".format(step_i)), cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR) * 255) # save the segmentation mask @@ -663,11 +670,6 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o if np.isscalar(v): obs_copy[k] = np.array([v]) - # If using the naive policy, need to use gripper_interpolator - # if self.naive: - # gripper_angles = self.interpolator.interpolate_gripper(obs_copy['robot0_gripper_qpos']) - # obs_copy['robot0_gripper_qpos'] = gripper_angles - action = self.policy(ob=obs_copy) # get action from policy gt_action = action.copy() @@ -702,7 +704,23 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o # predicted_state[2] -= 0.015 # the z coordinate is too high action = inpainted_img_action.copy() # action = gt_action.copy() - + + if self.naive: + target_img = np.load(f"{self.save_stats_path}/naive_input.npy", allow_pickle=True) + obs_copy = deepcopy(obs) + obs_copy["agentview_image"] = target_img + + agentview_img_new = obs_copy["agentview_image"][0] + agentview_img_new = agentview_img_new.transpose(1, 2, 0) + cv2.imwrite(f"{self.save_stats_path}/naive_input_1.png", cv2.cvtColor(agentview_img_new, cv2.COLOR_RGB2BGR) * 255) + + agentview_img_new = obs_copy["agentview_image"][1] + agentview_img_new = agentview_img_new.transpose(1, 2, 0) + cv2.imwrite(f"{self.save_stats_path}/naive_input_2.png", cv2.cvtColor(agentview_img_new, cv2.COLOR_RGB2BGR) * 255) + + action = self.policy(ob=obs_copy) + print("Action_Diff:", action - gt_action, np.linalg.norm(action - gt_action)) + action, r, done, success = self.step(action, use_delta=self.control_delta, blocking=False, name="Source Robot") if success: has_succeeded = True @@ -762,7 +780,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o # Pickle the object and send it to the server data_string = pickle.dumps(variable) message_length = struct.pack("!I", len(data_string)) - self.conn.send(message_length) + self.conn.sendall(message_length) self.conn.send(data_string) # visualization @@ -1065,6 +1083,6 @@ def _receive_all_bytes(self, num_bytes: int) -> bytes: ) args = parser.parse_args() - source_robot = SourceRobot(robot_name=args.robot_name, ckpt_path=args.agent, render=args.render, video_path=args.video_path, rollout_horizon=args.horizon, seed=None, dataset_path=args.dataset_path, passive=args.passive, port=args.port, connection=args.connection, demo_path=args.demo_path, inpaint_enabled=args.inpaint_enabled, save_paired_images=args.save_paired_images, save_paired_images_folder_path=args.save_paired_images_folder_path, forward_dynamics_model_path=args.forward_dynamics_model_path, device=args.device, save_failed_demos=args.save_failed_demos, save_stats_path=args.save_stats_path) + source_robot = SourceRobot(robot_name=args.robot_name, ckpt_path=args.agent, render=args.render, video_path=args.video_path, rollout_horizon=args.horizon, seed=None, dataset_path=args.dataset_path, passive=args.passive, port=args.port, connection=args.connection, demo_path=args.demo_path, inpaint_enabled=args.inpaint_enabled, save_paired_images=args.save_paired_images, save_paired_images_folder_path=args.save_paired_images_folder_path, forward_dynamics_model_path=args.forward_dynamics_model_path, device=args.device, save_failed_demos=args.save_failed_demos, save_stats_path=args.save_stats_path, naive=args.naive) source_robot.run_experiments(seeds=args.seeds, rollout_num_episodes=args.n_rollouts, video_skip=args.video_skip, camera_names=args.camera_names, dataset_obs=args.dataset_obs, save_stats_path=args.save_stats_path, tracking_error_threshold=args.tracking_error_threshold, num_iter_max=args.num_iter_max, inpaint_online_eval=args.inpaint_enabled) diff --git a/mirage/mirage/benchmark/robosuite/evaluate_policy_demo_target_robot_client.py b/mirage/mirage/benchmark/robosuite/evaluate_policy_demo_target_robot_client.py index 86a4320..ebd5d45 100644 --- a/mirage/mirage/benchmark/robosuite/evaluate_policy_demo_target_robot_client.py +++ b/mirage/mirage/benchmark/robosuite/evaluate_policy_demo_target_robot_client.py @@ -56,10 +56,10 @@ def image_to_pointcloud(self, depth_map, camera_name, camera_height=84, camera_w """ Convert depth image to point cloud """ - real_depth_map = camera_utils.get_real_depth_map(self.env.env.sim, depth_map) + real_depth_map = camera_utils.get_real_depth_map(self.core_env.sim, depth_map) # Camera transform matrix to project from camera coordinates to world coordinates. - extrinsic_matrix = camera_utils.get_camera_extrinsic_matrix(self.env.env.sim, camera_name=camera_name) - intrinsic_matrix = camera_utils.get_camera_intrinsic_matrix(self.env.env.sim, camera_name=camera_name, camera_height=camera_height, camera_width=camera_width) + extrinsic_matrix = camera_utils.get_camera_extrinsic_matrix(self.core_env.sim, camera_name=camera_name) + intrinsic_matrix = camera_utils.get_camera_intrinsic_matrix(self.core_env.sim, camera_name=camera_name, camera_height=camera_height, camera_width=camera_width) # Convert depth image to point cloud points = [] # 3D points in robot frame of shape […, 3] @@ -74,8 +74,8 @@ def image_to_pointcloud(self, depth_map, camera_name, camera_height=84, camera_w return points def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_object_state=False, set_robot_pose=False, tracking_error_threshold=0.003, num_iter_max=100, target_robot_delta_action=False, demo_index=0): - - assert isinstance(self.env, EnvBase) + print(type(self.env)) + # assert isinstance(self.env, EnvBase) self.initialize_robot() @@ -105,7 +105,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o # Pickle the object and send it to the server data_string = pickle.dumps(variable) message_length = struct.pack("!I", len(data_string)) - self.s.send(message_length) + self.s.sendall(message_length) self.s.send(data_string) video_count = 0 # video frame counter @@ -144,7 +144,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o segmentation_mask = cv2.flip(segmentation_mask, 0) depth_normalized = obs['agentview_depth'] depth_normalized = cv2.flip(depth_normalized, 0) - depth_img = camera_utils.get_real_depth_map(self.env.env.sim, depth_normalized) + depth_img = camera_utils.get_real_depth_map(self.core_env.sim, depth_normalized) # save the rgb image cv2.imwrite(os.path.join(self.save_paired_images_folder_path, "{}_rgb".format(self.robot_name.lower()), str(demo_index), "{}.jpg".format(step_i)), cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR) * 255) # save the segmentation mask @@ -170,7 +170,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o depth_normalized = obs['agentview_depth'] depth_normalized = cv2.flip(depth_normalized, 0) - depth_img = camera_utils.get_real_depth_map(self.env.env.sim, depth_normalized) + depth_img = camera_utils.get_real_depth_map(self.core_env.sim, depth_normalized) points = self.image_to_pointcloud(depth_normalized, "agentview", 84, 84, segmask=None) # may need to change to 256, 256 timestep_info_dict = { @@ -180,8 +180,8 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o "seg": segmentation_mask, "real_depth_map": depth_img, "points": points, - "extrinsic_matrix": camera_utils.get_camera_extrinsic_matrix(self.env.env.sim, camera_name="agentview"), - "intrinsic_matrix": camera_utils.get_camera_intrinsic_matrix(self.env.env.sim, camera_name="agentview", camera_height=84, camera_width=84), + "extrinsic_matrix": camera_utils.get_camera_extrinsic_matrix(self.core_env.sim, camera_name="agentview"), + "intrinsic_matrix": camera_utils.get_camera_intrinsic_matrix(self.core_env.sim, camera_name="agentview", camera_height=84, camera_width=84), }, "low_dim": { "joint_angles": joint_angles, @@ -211,9 +211,6 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o inpainted_image = self.ros_inpaint_publisher.get_inpainted_image(True) inpainted_image = inpainted_image.astype(np.float32) / 255.0 print("Received inpainted image") - - if self.naive: - inpainted_image = rgb_img if self.use_diffusion: if self.diffusion_input == "target_robot": @@ -244,12 +241,15 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o if self.inpaint_writer is not None: self.inpaint_writer.append_data((inpainted_image * 255).astype(np.uint8)) + if self.naive: + rgb_img = obs['agentview_image'] + np.save(f"{self.save_stats_path}/naive_input.npy", rgb_img, allow_pickle=True) # Pickle the object and send it to the server data_string = pickle.dumps(variable) message_length = struct.pack("!I", len(data_string)) - self.s.send(message_length) - self.s.send(data_string) + self.s.sendall(message_length) + self.s.sendall(data_string) # receive target object state and target robot pose from target robot @@ -290,13 +290,13 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o # append gripper action action = np.concatenate([action, source_env_robot_state.action[-1:]]) if self.inpaint_enabled: - inpaint_action = timestep_info_dict["source_robot"]["inpainting"]["predicted_state"][:7] # only get position and quarternion of the target state and not the gripper part - inpaint_action = np.concatenate([inpaint_action, timestep_info_dict["source_robot"]["inpainting"]["predicted_action"][-1:]]) + # inpaint_action = timestep_info_dict["source_robot"]["inpainting"]["predicted_state"][:7] # only get position and quarternion of the target state and not the gripper part + # inpaint_action = np.concatenate([inpaint_action, timestep_info_dict["source_robot"]["inpainting"]["predicted_action"][-1:]]) inpaint_action = timestep_info_dict["source_robot"]["ground_truth"]["target_state"][:7] # only get position and quarternion of the target state and not the gripper part inpaint_action = np.concatenate([inpaint_action, timestep_info_dict["source_robot"]["inpainting"]["predicted_action"][-1:]]) - predicted_state_with_gt_action = np.concatenate([timestep_info_dict["source_robot"]["inpainting"]["predicted_state_from_gt"][:7], action[-1:]]) + # predicted_state_with_gt_action = np.concatenate([timestep_info_dict["source_robot"]["inpainting"]["predicted_state_from_gt"][:7], action[-1:]]) timestep_info_dict["ground_truth_action"] = action timestep_info_dict["inpaint_action"] = inpaint_action @@ -307,7 +307,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o action = inpaint_action print("Use predicted_state_with_gt_action action") # print("Predicted from gt action: ", predicted_state_with_gt_action) - action = predicted_state_with_gt_action + # action = predicted_state_with_gt_action else: action_0, action_1 = source_env_robot_state.robot_pose[:7], source_env_robot_state.robot_pose[7:] # append gripper action @@ -321,15 +321,15 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o trajectory_timestep_infos.append(timestep_info_dict) # if the file does not exist: - if not os.path.exists(self.inpaint_data_for_analysis_path_temp): - # create the file and save the trajectory_timestep_infos - np.save(self.inpaint_data_for_analysis_path_temp, [trajectory_timestep_infos], allow_pickle=True) - # else, append the timestep_info_dict to the last trajectory in the file - else: - trajectory_timestep_infos_temp = np.load(self.inpaint_data_for_analysis_path_temp, allow_pickle=True) - trajectory_timestep_infos_temp = trajectory_timestep_infos_temp.tolist() - trajectory_timestep_infos_temp[-1].append(timestep_info_dict) - np.save(self.inpaint_data_for_analysis_path_temp, trajectory_timestep_infos_temp, allow_pickle=True) + # if not os.path.exists(self.inpaint_data_for_analysis_path_temp): + # # create the file and save the trajectory_timestep_infos + # np.save(self.inpaint_data_for_analysis_path_temp, [trajectory_timestep_infos], allow_pickle=True) + # # else, append the timestep_info_dict to the last trajectory in the file + # else: + # trajectory_timestep_infos_temp = np.load(self.inpaint_data_for_analysis_path_temp, allow_pickle=True) + # trajectory_timestep_infos_temp = trajectory_timestep_infos_temp.tolist() + # trajectory_timestep_infos_temp[-1].append(timestep_info_dict) + # np.save(self.inpaint_data_for_analysis_path_temp, trajectory_timestep_infos_temp, allow_pickle=True) if success: has_succeeded = True next_obs = deepcopy(self.obs) @@ -351,11 +351,9 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o # Pickle the object and send it to the server data_string = pickle.dumps(variable) message_length = struct.pack("!I", len(data_string)) - self.s.send(message_length) + self.s.sendall(message_length) self.s.send(data_string) - - # visualization if self.render: self.env.render(mode="human", camera_name=camera_names[0]) # on-screen rendering can only support one camera diff --git a/mirage/mirage/benchmark/robosuite/robosuite_experiment.py b/mirage/mirage/benchmark/robosuite/robosuite_experiment.py index dc75007..a7a6389 100644 --- a/mirage/mirage/benchmark/robosuite/robosuite_experiment.py +++ b/mirage/mirage/benchmark/robosuite/robosuite_experiment.py @@ -43,7 +43,8 @@ def launch(self, override=False) -> None: "--num_iter_max", str(self._config.source_num_iter_max), "--horizon", str(self._config.horizon), "--robot_name", self._config.source_robot_name, - "--save_stats_path", os.path.join(self._config.results_folder, "source.txt") + "--save_stats_path", os.path.join(self._config.results_folder, "source.txt"), + "--device", self._config.device ] target_agent_args = ["python3", "evaluate_policy_demo_target_robot_client.py", @@ -54,15 +55,20 @@ def launch(self, override=False) -> None: "--num_iter_max", str(self._config.target_num_iter_max), "--horizon", str(self._config.horizon), "--robot_name", self._config.target_robot_name, - "--save_stats_path", os.path.join(self._config.results_folder, "target.txt") + "--save_stats_path", os.path.join(self._config.results_folder, "target.txt"), + "--device", self._config.device ] + if self._config.naive: + source_agent_args.append("--naive") + target_agent_args.append("--naive") + if self._config.source_gripper_type: - source_agent_args.append("--gripper_type") + source_agent_args.append("--gripper") source_agent_args.append(self._config.source_gripper_type) if self._config.target_gripper_type: - target_agent_args.append("--gripper_type") + target_agent_args.append("--gripper") target_agent_args.append(self._config.target_gripper_type) if self._config.connection: diff --git a/mirage/mirage/benchmark/robosuite/robosuite_experiment_config.py b/mirage/mirage/benchmark/robosuite/robosuite_experiment_config.py index 260c11f..90174f2 100644 --- a/mirage/mirage/benchmark/robosuite/robosuite_experiment_config.py +++ b/mirage/mirage/benchmark/robosuite/robosuite_experiment_config.py @@ -18,6 +18,7 @@ class ExperimentRobotsuiteConfig(ExperimentConfig): seed: int passive: bool connection: bool + naive: bool source_robot_name: str target_robot_name: str @@ -50,6 +51,9 @@ class ExperimentRobotsuiteConfig(ExperimentConfig): source_gripper_type: Optional[str] = None target_gripper_type: Optional[str] = None + # Optional device for evaluation + device: Optional[str] = None + def validate_config(self): """ Validates the configuration to see if the values are feasible. @@ -88,6 +92,7 @@ def __str__(self): table.add_row(["Seed", self.seed]) table.add_row(["Passive", self.passive]) table.add_row(["Connection", self.connection]) + table.add_row(["Naive", self.naive]) table.add_row(["Source Robot Name", self.source_robot_name]) table.add_row(["Target Robot Name", self.target_robot_name]) table.add_row(["Source Tracking Error Threshold", self.source_tracking_error_threshold]) @@ -104,6 +109,8 @@ def __str__(self): table.add_row(["Target Video Path", self.target_video_path]) table.add_row(["Source Gripper Type", self.source_gripper_type]) table.add_row(["Target Gripper Type", self.target_gripper_type]) + table.add_row(["Results Folder", self.results_folder]) + table.add_row(["Device", self.device]) return table.get_formatted_string() @staticmethod @@ -122,6 +129,7 @@ def from_yaml(yaml_file: str): seed=config["seed"], passive=config["passive"], connection=config["connection"], + naive=config["naive"], source_robot_name=config["source_robot_name"], target_robot_name=config["target_robot_name"], source_tracking_error_threshold=config["source_tracking_error_threshold"], @@ -138,5 +146,6 @@ def from_yaml(yaml_file: str): source_video_path=config.get("source_video_path"), target_video_path=config.get("target_video_path"), source_gripper_type=config.get("source_gripper_type"), - target_gripper_type=config.get("target_gripper_type") + target_gripper_type=config.get("target_gripper_type"), + device=config.get("device", "cuda") ) diff --git a/mirage/mirage/benchmark/robosuite/run_robosuite_benchmark.py b/mirage/mirage/benchmark/robosuite/run_robosuite_benchmark.py index 55ea7a3..c31a3b2 100644 --- a/mirage/mirage/benchmark/robosuite/run_robosuite_benchmark.py +++ b/mirage/mirage/benchmark/robosuite/run_robosuite_benchmark.py @@ -5,13 +5,14 @@ def main(): parser = argparse.ArgumentParser(description="Mirage Robosuite Benchmark") - parser.add_argument("--config_file", type=str) + parser.add_argument("--config", type=str) + parser.add_argument("-y", action="store_true") args = parser.parse_args() - print("Loading config from: ", args.config_file) - config = ExperimentRobotsuiteConfig.from_yaml(args.config_file) + print("Loading config from: ", args.config) + config = ExperimentRobotsuiteConfig.from_yaml(args.config) print(config) - should_launch = input("Launch the experiment? [Y/n] ") + should_launch = "y" if args.y else input("Launch the experiment? [Y/n] ") if should_launch.lower() != "y": print("Exiting...") return @@ -21,7 +22,7 @@ def main(): try: new_experiment.launch() except ValueError as e: - should_override = input("Results folder already exists. Override? [Y/n] ") + should_override = "y" if args.y else input("Results folder already exists. Override? [Y/n] ") if should_override.lower() != "y": print("Exiting...") return