Skip to content

Commit

Permalink
Added naive support and fixed bugs in the code (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
KDharmarajanDev authored Apr 29, 2024
1 parent 6c7599d commit b3b73a4
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,16 @@ def __init__(self, robot_name=None, ckpt_path=None, render=False, video_path=Non
config, _ = FileUtils.config_from_checkpoint(ckpt_dict=self.ckpt_dict)
self.rollout_horizon = config.experiment.rollout.horizon

self.num_robots = len(self.env.env.robots)
self.core_env = self.env.env
self.is_diffusion = False
while hasattr(self.core_env, "env"):
self.is_diffusion = True
self.core_env = self.core_env.env

self.num_robots = len(self.core_env.robots)
self.eef_site_name = []
for i in range(self.num_robots):
self.eef_site_name.append(self.env.env.robots[i].controller.eef_name)
self.eef_site_name.append(self.core_env.robots[i].controller.eef_name)


# maybe open hdf5 to write rollouts
Expand All @@ -184,7 +190,7 @@ def initialize_robot(self):
self.obs = self.env.reset()
state_dict_source = self.env.get_state()
self.obs = self.env.reset_to(state_dict_source) # necessary for robosuite tasks for deterministic action playback

def compute_pose_error(self, target_pose):
starting_pose = self.compute_eef_pose()
if self.num_robots == 1:
Expand All @@ -198,7 +204,7 @@ def compute_pose_error(self, target_pose):

def drive_robot_to_target_pose(self, target_pose=None, tracking_error_threshold=0.003, num_iter_max=100):
for i in range(self.num_robots):
self.env.env.robots[i].controller.use_delta = False # change to absolute pose for setting the initial state
self.core_env.robots[i].controller.use_delta = False # change to absolute pose for setting the initial state

error, starting_pose = self.compute_pose_error(target_pose)
num_iters = 0
Expand Down Expand Up @@ -234,15 +240,15 @@ def drive_robot_to_target_pose(self, target_pose=None, tracking_error_threshold=

# change back to delta pose
for i in range(self.num_robots):
self.env.env.robots[i].controller.use_delta = True
self.core_env.robots[i].controller.use_delta = True


def compute_eef_pose(self):
"""return a 7D or 14D pose vector"""
pose = []
for i in range(self.num_robots):
pos = np.array(self.env.env.sim.data.site_xpos[self.env.env.sim.model.site_name2id(self.eef_site_name[i])])
rot = np.array(T.mat2quat(self.env.env.sim.data.site_xmat[self.env.env.sim.model.site_name2id(self.eef_site_name[i])].reshape([3, 3])))
pos = np.array(self.core_env.sim.data.site_xpos[self.core_env.sim.model.site_name2id(self.eef_site_name[i])])
rot = np.array(T.mat2quat(self.core_env.sim.data.site_xmat[self.core_env.sim.model.site_name2id(self.eef_site_name[i])].reshape([3, 3])))
pose.append(np.concatenate((pos, rot)))
pose = np.concatenate(pose)
return pose
Expand All @@ -251,17 +257,19 @@ def compute_eef_pose(self):
def get_object_state(self):
object_state = dict()
for obj_name in TASK_OBJECT_DICT[self.task]:
object_state[obj_name] = self.env.env.sim.data.get_joint_qpos(obj_name)
object_state[obj_name] = self.core_env.sim.data.get_joint_qpos(obj_name)
return object_state

def set_object_state(self, set_to_target_object_state=None):
if set_to_target_object_state is not None:
# set target object to target object state
for obj_name in TASK_OBJECT_DICT[self.task]:
self.env.env.sim.data.set_joint_qpos(obj_name, set_to_target_object_state[obj_name])
# self.env.env.sim.data.set_joint_qpos("cube_joint0", set_to_target_object_state)
self.env.env.sim.forward()
self.core_env.sim.data.set_joint_qpos(obj_name, set_to_target_object_state[obj_name])
# self.core_env.sim.data.set_joint_qpos("cube_joint0", set_to_target_object_state)
self.core_env.sim.forward()
self.obs = self.env.get_observation()
if hasattr(self.env, "_get_stacked_obs_from_history"):
self.obs = self.env._get_stacked_obs_from_history()


def step(self, action, use_delta=True, blocking=False, tracking_error_threshold=0.003, num_iter_max=100, goal_pose=None, name="Source Robot"):
Expand All @@ -277,7 +285,7 @@ def step(self, action, use_delta=True, blocking=False, tracking_error_threshold=
if not blocking:
assert (len(action) == 7 and self.num_robots == 1) or (len(action) == 14 and self.num_robots == 2), "Action should be 7DOF"
for i in range(self.num_robots):
self.env.env.robots[i].controller.use_delta = use_delta
self.core_env.robots[i].controller.use_delta = use_delta
next_obs, r, done, _ = self.env.step(action) # just execute action
# if action[-1] != self.prev_action[-1]:
# print("after", self.compute_eef_pose())
Expand All @@ -292,20 +300,20 @@ def step(self, action, use_delta=True, blocking=False, tracking_error_threshold=
# single_arm.py #L247
# osc.py #L264
if use_delta:
self.env.env.robots[0].controller.use_delta = True
self.core_env.robots[0].controller.use_delta = True
# convert to equivalent absolute actions
self.env.env.robots[0].controller.set_goal(action[:self.env.env.robots[0].controller.control_dim])
action_goal_pos = self.env.env.robots[0].controller.goal_pos
action_goal_ori = Rotation.from_matrix(self.env.env.robots[0].controller.goal_ori).as_rotvec()
self.core_env.robots[0].controller.set_goal(action[:self.core_env.robots[0].controller.control_dim])
action_goal_pos = self.core_env.robots[0].controller.goal_pos
action_goal_ori = Rotation.from_matrix(self.core_env.robots[0].controller.goal_ori).as_rotvec()
raise NotImplementedError
next_obs, r, done, _ = self.env.step(action)
else:
assert (len(action) == 8 and self.num_robots == 1) or (len(action) == 16 and self.num_robots == 2), "Action should be 8DOF"
for i in range(self.num_robots):
self.env.env.robots[i].controller.use_delta = False
self.env.env.robots[i].controller.kp = np.array([150, 150, 150, 150, 150, 150]) # control gain
# self.env.env.robots[i].controller.kp = np.array([100, 100, 500, 10, 10, 50]) # control gain
print("Robot {} controller kp: {}".format(i, self.env.env.robots[i].controller.kp))
self.core_env.robots[i].controller.use_delta = False
self.core_env.robots[i].controller.kp = np.array([150, 150, 150, 150, 150, 150]) # control gain
# self.core_env.robots[i].controller.kp = np.array([100, 100, 500, 10, 10, 50]) # control gain
print("Robot {} controller kp: {}".format(i, self.core_env.robots[i].controller.kp))
if self.num_robots == 1:
action_target = np.zeros(7)
action_target[:3] = action[:3]
Expand Down Expand Up @@ -337,12 +345,12 @@ def step(self, action, use_delta=True, blocking=False, tracking_error_threshold=
if self.num_robots == 1:
if action[-1] != self.prev_action[-1]:
# print("after", self.compute_eef_pose())
# self.env.env.robots[0].controller.use_delta = True
# self.core_env.robots[0].controller.use_delta = True
# action_target = np.zeros(7)
action_target[-1] = action[-1]
# print("action_target", action_target)
next_obs, r, done, _ = self.env.step(action_target)
# self.env.env.robots[0].controller.use_delta = False
# self.core_env.robots[0].controller.use_delta = False
# print("after", self.compute_eef_pose())
elif self.num_robots == 2:
action_target[6] = action[7]
Expand Down Expand Up @@ -412,6 +420,7 @@ def run_experiments(self, seeds, rollout_num_episodes=1, video_skip=5, camera_na
for k in dict_rollout_stats:
avg_rollout_stats[k].append(np.mean(dict_rollout_stats[k]))
avg_rollout_stats["Num_Success"].append(np.sum(dict_rollout_stats["Success_Rate"]))
avg_rollout_stats["Num Rollouts"] = i + 1
avg_rollout_stats["Seeds"].append(seed)
avg_rollout_stats["Robot"] = self.robot_name
print("Average Rollout Stats:")
Expand Down Expand Up @@ -488,8 +497,6 @@ def __init__(self, robot_name=None, ckpt_path=None, render=False, video_path=Non
self.forward_dynamics_model = ForwardDynamicsModel(model_path=forward_dynamics_model_path)

self.naive = naive
if self.naive:
self.interpolator = GripperInterpolator('Panda', robot_name, [f'{self.save_stats_path}/gripper_interpolation_results_no_task_diff.pkl'])

def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_object_state=False, set_robot_pose=False, tracking_error_threshold=0.003, num_iter_max=100, target_robot_delta_action=False, demo_index=0):
"""
Expand All @@ -513,7 +520,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o
stats (dict): some statistics for the rollout - such as return, horizon, and task success
traj (dict): dictionary that corresponds to the rollout trajectory
"""
assert isinstance(self.env, EnvBase)
# assert isinstance(self.env, EnvBase)
# assert isinstance(self.policy, RolloutPolicy)

if self.save_paired_images:
Expand Down Expand Up @@ -555,7 +562,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o
# Pickle the object and send it to the server
data_string = pickle.dumps(variable)
message_length = struct.pack("!I", len(data_string))
self.conn.send(message_length)
self.conn.sendall(message_length)
self.conn.send(data_string)
# confirm that the target robot is ready
pickled_message_size = self._receive_all_bytes(4)
Expand Down Expand Up @@ -634,7 +641,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o
segmentation_mask = cv2.flip(segmentation_mask, 0)
depth_normalized = obs['agentview_depth']
depth_normalized = cv2.flip(depth_normalized, 0)
depth_img = camera_utils.get_real_depth_map(self.env.env.sim, depth_normalized)
depth_img = camera_utils.get_real_depth_map(self.core_env.sim, depth_normalized)
# save the rgb image
cv2.imwrite(os.path.join(self.save_paired_images_folder_path, "franka_rgb", str(demo_index), "{}.jpg".format(step_i)), cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR) * 255)
# save the segmentation mask
Expand Down Expand Up @@ -663,11 +670,6 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o
if np.isscalar(v):
obs_copy[k] = np.array([v])

# If using the naive policy, need to use gripper_interpolator
# if self.naive:
# gripper_angles = self.interpolator.interpolate_gripper(obs_copy['robot0_gripper_qpos'])
# obs_copy['robot0_gripper_qpos'] = gripper_angles

action = self.policy(ob=obs_copy) # get action from policy
gt_action = action.copy()

Expand Down Expand Up @@ -702,7 +704,23 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o
# predicted_state[2] -= 0.015 # the z coordinate is too high
action = inpainted_img_action.copy()
# action = gt_action.copy()


if self.naive:
target_img = np.load(f"{self.save_stats_path}/naive_input.npy", allow_pickle=True)
obs_copy = deepcopy(obs)
obs_copy["agentview_image"] = target_img

agentview_img_new = obs_copy["agentview_image"][0]
agentview_img_new = agentview_img_new.transpose(1, 2, 0)
cv2.imwrite(f"{self.save_stats_path}/naive_input_1.png", cv2.cvtColor(agentview_img_new, cv2.COLOR_RGB2BGR) * 255)

agentview_img_new = obs_copy["agentview_image"][1]
agentview_img_new = agentview_img_new.transpose(1, 2, 0)
cv2.imwrite(f"{self.save_stats_path}/naive_input_2.png", cv2.cvtColor(agentview_img_new, cv2.COLOR_RGB2BGR) * 255)

action = self.policy(ob=obs_copy)
print("Action_Diff:", action - gt_action, np.linalg.norm(action - gt_action))

action, r, done, success = self.step(action, use_delta=self.control_delta, blocking=False, name="Source Robot")
if success:
has_succeeded = True
Expand Down Expand Up @@ -762,7 +780,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o
# Pickle the object and send it to the server
data_string = pickle.dumps(variable)
message_length = struct.pack("!I", len(data_string))
self.conn.send(message_length)
self.conn.sendall(message_length)
self.conn.send(data_string)

# visualization
Expand Down Expand Up @@ -1065,6 +1083,6 @@ def _receive_all_bytes(self, num_bytes: int) -> bytes:
)
args = parser.parse_args()

source_robot = SourceRobot(robot_name=args.robot_name, ckpt_path=args.agent, render=args.render, video_path=args.video_path, rollout_horizon=args.horizon, seed=None, dataset_path=args.dataset_path, passive=args.passive, port=args.port, connection=args.connection, demo_path=args.demo_path, inpaint_enabled=args.inpaint_enabled, save_paired_images=args.save_paired_images, save_paired_images_folder_path=args.save_paired_images_folder_path, forward_dynamics_model_path=args.forward_dynamics_model_path, device=args.device, save_failed_demos=args.save_failed_demos, save_stats_path=args.save_stats_path)
source_robot = SourceRobot(robot_name=args.robot_name, ckpt_path=args.agent, render=args.render, video_path=args.video_path, rollout_horizon=args.horizon, seed=None, dataset_path=args.dataset_path, passive=args.passive, port=args.port, connection=args.connection, demo_path=args.demo_path, inpaint_enabled=args.inpaint_enabled, save_paired_images=args.save_paired_images, save_paired_images_folder_path=args.save_paired_images_folder_path, forward_dynamics_model_path=args.forward_dynamics_model_path, device=args.device, save_failed_demos=args.save_failed_demos, save_stats_path=args.save_stats_path, naive=args.naive)
source_robot.run_experiments(seeds=args.seeds, rollout_num_episodes=args.n_rollouts, video_skip=args.video_skip, camera_names=args.camera_names, dataset_obs=args.dataset_obs, save_stats_path=args.save_stats_path, tracking_error_threshold=args.tracking_error_threshold, num_iter_max=args.num_iter_max, inpaint_online_eval=args.inpaint_enabled)

Loading

0 comments on commit b3b73a4

Please sign in to comment.