Added naive support and fixed bugs in the code #12

Merged 1 commit on Apr 29, 2024
@@ -154,10 +154,16 @@ def __init__(self, robot_name=None, ckpt_path=None, render=False, video_path=Non
config, _ = FileUtils.config_from_checkpoint(ckpt_dict=self.ckpt_dict)
self.rollout_horizon = config.experiment.rollout.horizon

- self.num_robots = len(self.env.env.robots)
+ self.core_env = self.env.env
+ self.is_diffusion = False
+ while hasattr(self.core_env, "env"):
+     self.is_diffusion = True
+     self.core_env = self.core_env.env
+
+ self.num_robots = len(self.core_env.robots)
self.eef_site_name = []
for i in range(self.num_robots):
- self.eef_site_name.append(self.env.env.robots[i].controller.eef_name)
+ self.eef_site_name.append(self.core_env.robots[i].controller.eef_name)


# maybe open hdf5 to write rollouts
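Note on the hunk above: diffusion-policy checkpoints wrap the robosuite environment in extra wrapper layers, each exposing the next layer as `.env`, so the constructor now walks `.env` attributes until it reaches the core environment, flagging `is_diffusion` when at least one layer was peeled off. A minimal sketch of the unwrapping idiom, with illustrative class names:

```python
class CoreEnv:
    robots = ["robot0"]

class Wrapper:
    def __init__(self, env):
        self.env = env  # each wrapper exposes the next layer as .env

def unwrap(env):
    """Peel .env layers until the innermost environment is reached."""
    core = env
    while hasattr(core, "env"):
        core = core.env
    return core

core = unwrap(Wrapper(Wrapper(CoreEnv())))
assert isinstance(core, CoreEnv) and core.robots == ["robot0"]
```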
@@ -184,7 +190,7 @@ def initialize_robot(self):
self.obs = self.env.reset()
state_dict_source = self.env.get_state()
self.obs = self.env.reset_to(state_dict_source) # necessary for robosuite tasks for deterministic action playback

def compute_pose_error(self, target_pose):
starting_pose = self.compute_eef_pose()
if self.num_robots == 1:
@@ -198,7 +204,7 @@ def compute_pose_error(self, target_pose):

def drive_robot_to_target_pose(self, target_pose=None, tracking_error_threshold=0.003, num_iter_max=100):
for i in range(self.num_robots):
- self.env.env.robots[i].controller.use_delta = False # change to absolute pose for setting the initial state
+ self.core_env.robots[i].controller.use_delta = False # change to absolute pose for setting the initial state

error, starting_pose = self.compute_pose_error(target_pose)
num_iters = 0
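`drive_robot_to_target_pose` follows a toggle-servo-restore pattern: the controllers are switched to absolute-pose mode, the arm is stepped until the tracking error falls below the threshold or the iteration budget runs out, and delta mode is restored at the end (see the next hunk). A hedged sketch of that control flow, where `compute_error` and `step_toward` are hypothetical stand-ins for the real pose-error and stepping code:

```python
def drive_to_pose(robots, compute_error, step_toward, target,
                  tracking_error_threshold=0.003, num_iter_max=100):
    # Sketch only: compute_error and step_toward are hypothetical stand-ins.
    for robot in robots:
        robot.controller.use_delta = False  # absolute targets while servoing
    try:
        error = compute_error(target)
        num_iters = 0
        while error > tracking_error_threshold and num_iters < num_iter_max:
            step_toward(target)
            error = compute_error(target)
            num_iters += 1
    finally:
        for robot in robots:
            robot.controller.use_delta = True  # restore delta mode
    return error
```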
@@ -234,15 +240,15 @@ def drive_robot_to_target_pose(self, target_pose=None, tracking_error_threshold=

# change back to delta pose
for i in range(self.num_robots):
- self.env.env.robots[i].controller.use_delta = True
+ self.core_env.robots[i].controller.use_delta = True


def compute_eef_pose(self):
"""return a 7D or 14D pose vector"""
pose = []
for i in range(self.num_robots):
- pos = np.array(self.env.env.sim.data.site_xpos[self.env.env.sim.model.site_name2id(self.eef_site_name[i])])
- rot = np.array(T.mat2quat(self.env.env.sim.data.site_xmat[self.env.env.sim.model.site_name2id(self.eef_site_name[i])].reshape([3, 3])))
+ pos = np.array(self.core_env.sim.data.site_xpos[self.core_env.sim.model.site_name2id(self.eef_site_name[i])])
+ rot = np.array(T.mat2quat(self.core_env.sim.data.site_xmat[self.core_env.sim.model.site_name2id(self.eef_site_name[i])].reshape([3, 3])))
pose.append(np.concatenate((pos, rot)))
pose = np.concatenate(pose)
return pose
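`compute_eef_pose` reads the end-effector pose straight from the simulator state rather than from observations: the controller's EEF site name is resolved to an index, `site_xpos` gives the world-frame position, and `site_xmat` gives the world-frame rotation matrix. A condensed sketch, assuming a mujoco_py-style `sim` handle and assuming `T` is robosuite's transform_utils, as the `T.mat2quat` call above suggests:

```python
import numpy as np
import robosuite.utils.transform_utils as T

def site_pose(sim, site_name):
    """Return a 7D [x, y, z, qx, qy, qz, qw] pose for a named MuJoCo site."""
    sid = sim.model.site_name2id(site_name)
    pos = np.array(sim.data.site_xpos[sid])                   # world position
    quat = T.mat2quat(sim.data.site_xmat[sid].reshape(3, 3))  # matrix -> quaternion
    return np.concatenate((pos, quat))
```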
@@ -251,17 +257,19 @@ def compute_eef_pose(self):
def get_object_state(self):
object_state = dict()
for obj_name in TASK_OBJECT_DICT[self.task]:
- object_state[obj_name] = self.env.env.sim.data.get_joint_qpos(obj_name)
+ object_state[obj_name] = self.core_env.sim.data.get_joint_qpos(obj_name)
return object_state

def set_object_state(self, set_to_target_object_state=None):
if set_to_target_object_state is not None:
# set target object to target object state
for obj_name in TASK_OBJECT_DICT[self.task]:
- self.env.env.sim.data.set_joint_qpos(obj_name, set_to_target_object_state[obj_name])
- # self.env.env.sim.data.set_joint_qpos("cube_joint0", set_to_target_object_state)
- self.env.env.sim.forward()
+ self.core_env.sim.data.set_joint_qpos(obj_name, set_to_target_object_state[obj_name])
+ # self.core_env.sim.data.set_joint_qpos("cube_joint0", set_to_target_object_state)
+ self.core_env.sim.forward()
self.obs = self.env.get_observation()
+ if hasattr(self.env, "_get_stacked_obs_from_history"):
+     self.obs = self.env._get_stacked_obs_from_history()


def step(self, action, use_delta=True, blocking=False, tracking_error_threshold=0.003, num_iter_max=100, goal_pose=None, name="Source Robot"):
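The `set_object_state` hunk is a teleport-then-refresh pattern: write the object's free-joint qpos (3D position plus quaternion), call `sim.forward()` so kinematics update without advancing time, then re-query observations; the new `hasattr` branch rebuilds the stacked observation when a robomimic-style frame-stacking wrapper is in play. A sketch under those assumptions:

```python
def teleport_object(env, core_env, joint_name, qpos):
    """Set a free joint's 7D qpos (xyz + wxyz quaternion) and refresh obs."""
    core_env.sim.data.set_joint_qpos(joint_name, qpos)
    core_env.sim.forward()  # recompute derived state without stepping time
    obs = env.get_observation()
    if hasattr(env, "_get_stacked_obs_from_history"):
        obs = env._get_stacked_obs_from_history()  # rebuild stacked frames
    return obs
```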
@@ -277,7 +285,7 @@ def step(self, action, use_delta=True, blocking=False, tracking_error_threshold=
if not blocking:
assert (len(action) == 7 and self.num_robots == 1) or (len(action) == 14 and self.num_robots == 2), "Action should be 7DOF"
for i in range(self.num_robots):
- self.env.env.robots[i].controller.use_delta = use_delta
+ self.core_env.robots[i].controller.use_delta = use_delta
next_obs, r, done, _ = self.env.step(action) # just execute action
# if action[-1] != self.prev_action[-1]:
# print("after", self.compute_eef_pose())
@@ -292,20 +300,20 @@ def step(self, action, use_delta=True, blocking=False, tracking_error_threshold=
# single_arm.py #L247
# osc.py #L264
if use_delta:
- self.env.env.robots[0].controller.use_delta = True
+ self.core_env.robots[0].controller.use_delta = True
# convert to equivalent absolute actions
- self.env.env.robots[0].controller.set_goal(action[:self.env.env.robots[0].controller.control_dim])
- action_goal_pos = self.env.env.robots[0].controller.goal_pos
- action_goal_ori = Rotation.from_matrix(self.env.env.robots[0].controller.goal_ori).as_rotvec()
+ self.core_env.robots[0].controller.set_goal(action[:self.core_env.robots[0].controller.control_dim])
+ action_goal_pos = self.core_env.robots[0].controller.goal_pos
+ action_goal_ori = Rotation.from_matrix(self.core_env.robots[0].controller.goal_ori).as_rotvec()
raise NotImplementedError
next_obs, r, done, _ = self.env.step(action)
else:
assert (len(action) == 8 and self.num_robots == 1) or (len(action) == 16 and self.num_robots == 2), "Action should be 8DOF"
for i in range(self.num_robots):
- self.env.env.robots[i].controller.use_delta = False
- self.env.env.robots[i].controller.kp = np.array([150, 150, 150, 150, 150, 150]) # control gain
- # self.env.env.robots[i].controller.kp = np.array([100, 100, 500, 10, 10, 50]) # control gain
- print("Robot {} controller kp: {}".format(i, self.env.env.robots[i].controller.kp))
+ self.core_env.robots[i].controller.use_delta = False
+ self.core_env.robots[i].controller.kp = np.array([150, 150, 150, 150, 150, 150]) # control gain
+ # self.core_env.robots[i].controller.kp = np.array([100, 100, 500, 10, 10, 50]) # control gain
+ print("Robot {} controller kp: {}".format(i, self.core_env.robots[i].controller.kp))
if self.num_robots == 1:
action_target = np.zeros(7)
action_target[:3] = action[:3]
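In the blocking branch above, a delta action is converted to an absolute target by pushing it through the controller itself: `set_goal` composes the delta with the current end-effector state, after which `goal_pos` and `goal_ori` hold the absolute goal (attribute names per robosuite's OSC controller, as the `osc.py` comment earlier in this hunk indicates). A sketch of that round-trip:

```python
from scipy.spatial.transform import Rotation

def delta_to_absolute(controller, action):
    """Let an OSC-style controller compose a delta action into an absolute goal."""
    controller.use_delta = True
    controller.set_goal(action[:controller.control_dim])  # interpreted as a delta
    goal_pos = controller.goal_pos                         # absolute position target
    goal_rotvec = Rotation.from_matrix(controller.goal_ori).as_rotvec()
    return goal_pos, goal_rotvec
```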
@@ -337,12 +345,12 @@ def step(self, action, use_delta=True, blocking=False, tracking_error_threshold=
if self.num_robots == 1:
if action[-1] != self.prev_action[-1]:
# print("after", self.compute_eef_pose())
- # self.env.env.robots[0].controller.use_delta = True
+ # self.core_env.robots[0].controller.use_delta = True
# action_target = np.zeros(7)
action_target[-1] = action[-1]
# print("action_target", action_target)
next_obs, r, done, _ = self.env.step(action_target)
- # self.env.env.robots[0].controller.use_delta = False
+ # self.core_env.robots[0].controller.use_delta = False
# print("after", self.compute_eef_pose())
elif self.num_robots == 2:
action_target[6] = action[7]
@@ -412,6 +420,7 @@ def run_experiments(self, seeds, rollout_num_episodes=1, video_skip=5, camera_na
for k in dict_rollout_stats:
avg_rollout_stats[k].append(np.mean(dict_rollout_stats[k]))
avg_rollout_stats["Num_Success"].append(np.sum(dict_rollout_stats["Success_Rate"]))
avg_rollout_stats["Num Rollouts"] = i + 1
avg_rollout_stats["Seeds"].append(seed)
avg_rollout_stats["Robot"] = self.robot_name
print("Average Rollout Stats:")
@@ -488,8 +497,6 @@ def __init__(self, robot_name=None, ckpt_path=None, render=False, video_path=Non
self.forward_dynamics_model = ForwardDynamicsModel(model_path=forward_dynamics_model_path)

self.naive = naive
- if self.naive:
-     self.interpolator = GripperInterpolator('Panda', robot_name, [f'{self.save_stats_path}/gripper_interpolation_results_no_task_diff.pkl'])

def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_object_state=False, set_robot_pose=False, tracking_error_threshold=0.003, num_iter_max=100, target_robot_delta_action=False, demo_index=0):
"""
@@ -513,7 +520,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o
stats (dict): some statistics for the rollout - such as return, horizon, and task success
traj (dict): dictionary that corresponds to the rollout trajectory
"""
- assert isinstance(self.env, EnvBase)
+ # assert isinstance(self.env, EnvBase)
# assert isinstance(self.policy, RolloutPolicy)

if self.save_paired_images:
@@ -555,7 +562,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o
# Pickle the object and send it to the server
data_string = pickle.dumps(variable)
message_length = struct.pack("!I", len(data_string))
- self.conn.send(message_length)
+ self.conn.sendall(message_length)
self.conn.send(data_string)
# confirm that the target robot is ready
pickled_message_size = self._receive_all_bytes(4)
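The `send` → `sendall` fix matters because `socket.send` may transmit only part of the buffer; a truncated 4-byte length prefix desynchronizes the framing and the receiver's `struct.unpack` reads garbage. A minimal sketch of the length-prefixed protocol used here (`recv_all` mirrors the `_receive_all_bytes` helper later in this file):

```python
import pickle
import struct

def send_msg(conn, obj):
    payload = pickle.dumps(obj)
    conn.sendall(struct.pack("!I", len(payload)))  # 4-byte big-endian length
    conn.sendall(payload)                          # sendall loops until all bytes go out

def recv_all(conn, num_bytes):
    buf = b""
    while len(buf) < num_bytes:
        chunk = conn.recv(num_bytes - len(buf))
        if not chunk:
            raise ConnectionError("socket closed mid-message")
        buf += chunk
    return buf

def recv_msg(conn):
    (length,) = struct.unpack("!I", recv_all(conn, 4))
    return pickle.loads(recv_all(conn, length))
```

Note that the diff still uses plain `send` for the payload itself; the same partial-write risk applies there.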
@@ -634,7 +641,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o
segmentation_mask = cv2.flip(segmentation_mask, 0)
depth_normalized = obs['agentview_depth']
depth_normalized = cv2.flip(depth_normalized, 0)
- depth_img = camera_utils.get_real_depth_map(self.env.env.sim, depth_normalized)
+ depth_img = camera_utils.get_real_depth_map(self.core_env.sim, depth_normalized)
# save the rgb image
cv2.imwrite(os.path.join(self.save_paired_images_folder_path, "franka_rgb", str(demo_index), "{}.jpg".format(step_i)), cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR) * 255)
# save the segmentation mask
@@ -663,11 +670,6 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o
if np.isscalar(v):
obs_copy[k] = np.array([v])

- # If using the naive policy, need to use gripper_interpolator
- # if self.naive:
- #     gripper_angles = self.interpolator.interpolate_gripper(obs_copy['robot0_gripper_qpos'])
- #     obs_copy['robot0_gripper_qpos'] = gripper_angles
-
action = self.policy(ob=obs_copy) # get action from policy
gt_action = action.copy()

@@ -702,7 +704,23 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o
# predicted_state[2] -= 0.015 # the z coordinate is too high
action = inpainted_img_action.copy()
# action = gt_action.copy()


+ if self.naive:
+     target_img = np.load(f"{self.save_stats_path}/naive_input.npy", allow_pickle=True)
+     obs_copy = deepcopy(obs)
+     obs_copy["agentview_image"] = target_img
+
+     agentview_img_new = obs_copy["agentview_image"][0]
+     agentview_img_new = agentview_img_new.transpose(1, 2, 0)
+     cv2.imwrite(f"{self.save_stats_path}/naive_input_1.png", cv2.cvtColor(agentview_img_new, cv2.COLOR_RGB2BGR) * 255)
+
+     agentview_img_new = obs_copy["agentview_image"][1]
+     agentview_img_new = agentview_img_new.transpose(1, 2, 0)
+     cv2.imwrite(f"{self.save_stats_path}/naive_input_2.png", cv2.cvtColor(agentview_img_new, cv2.COLOR_RGB2BGR) * 255)
+
+     action = self.policy(ob=obs_copy)
+     print("Action_Diff:", action - gt_action, np.linalg.norm(action - gt_action))

action, r, done, success = self.step(action, use_delta=self.control_delta, blocking=False, name="Source Robot")
if success:
has_succeeded = True
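The new naive branch swaps the live `agentview_image` for a pre-rendered one loaded from `naive_input.npy`, re-runs the policy on the modified observation, and logs how far the resulting action drifts from the ground-truth action. The images are stored as channel-first (CHW) floats in [0, 1], so the debug writes above transpose to HWC, convert RGB to BGR, and rescale. A small sketch of that save path (function name illustrative):

```python
import cv2
import numpy as np

def save_chw_rgb(img_chw, path):
    """Write a channel-first float RGB image in [0, 1] as an 8-bit BGR file."""
    img_hwc = np.asarray(img_chw).transpose(1, 2, 0)  # CHW -> HWC
    bgr = cv2.cvtColor(img_hwc.astype(np.float32), cv2.COLOR_RGB2BGR)
    cv2.imwrite(path, np.clip(bgr * 255, 0, 255).astype(np.uint8))
```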
@@ -762,7 +780,7 @@ def rollout_robot(self, video_skip=5, return_obs=False, camera_names=None, set_o
# Pickle the object and send it to the server
data_string = pickle.dumps(variable)
message_length = struct.pack("!I", len(data_string))
- self.conn.send(message_length)
+ self.conn.sendall(message_length)
self.conn.send(data_string)

# visualization
@@ -1065,6 +1083,6 @@ def _receive_all_bytes(self, num_bytes: int) -> bytes:
)
args = parser.parse_args()

- source_robot = SourceRobot(robot_name=args.robot_name, ckpt_path=args.agent, render=args.render, video_path=args.video_path, rollout_horizon=args.horizon, seed=None, dataset_path=args.dataset_path, passive=args.passive, port=args.port, connection=args.connection, demo_path=args.demo_path, inpaint_enabled=args.inpaint_enabled, save_paired_images=args.save_paired_images, save_paired_images_folder_path=args.save_paired_images_folder_path, forward_dynamics_model_path=args.forward_dynamics_model_path, device=args.device, save_failed_demos=args.save_failed_demos, save_stats_path=args.save_stats_path)
+ source_robot = SourceRobot(robot_name=args.robot_name, ckpt_path=args.agent, render=args.render, video_path=args.video_path, rollout_horizon=args.horizon, seed=None, dataset_path=args.dataset_path, passive=args.passive, port=args.port, connection=args.connection, demo_path=args.demo_path, inpaint_enabled=args.inpaint_enabled, save_paired_images=args.save_paired_images, save_paired_images_folder_path=args.save_paired_images_folder_path, forward_dynamics_model_path=args.forward_dynamics_model_path, device=args.device, save_failed_demos=args.save_failed_demos, save_stats_path=args.save_stats_path, naive=args.naive)
source_robot.run_experiments(seeds=args.seeds, rollout_num_episodes=args.n_rollouts, video_skip=args.video_skip, camera_names=args.camera_names, dataset_obs=args.dataset_obs, save_stats_path=args.save_stats_path, tracking_error_threshold=args.tracking_error_threshold, num_iter_max=args.num_iter_max, inpaint_online_eval=args.inpaint_enabled)
