From d72e85b0e8a4b4291d7be3de5a6f1f994b5c3764 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 24 Sep 2024 23:27:16 -0700 Subject: [PATCH] Skinned mesh + SMPL example adjustments (#291) * Fix equality check, tweak SMPL visualizer * SMPL example adjustments * ruff --- examples/08_smpl_visualizer.py | 52 ++++++++------ examples/25_smpl_visualizer_skinned.py | 86 +++++++++++++---------- src/viser/_messages.py | 11 +-- src/viser/_scene_api.py | 15 +--- src/viser/_scene_handles.py | 9 ++- src/viser/client/src/MessageHandler.tsx | 30 ++++++-- src/viser/client/src/ThreeAssets.tsx | 38 ++++++++-- src/viser/client/src/WebsocketMessages.ts | 4 +- 8 files changed, 155 insertions(+), 90 deletions(-) diff --git a/examples/08_smpl_visualizer.py b/examples/08_smpl_visualizer.py index 8954f6d1..66584605 100644 --- a/examples/08_smpl_visualizer.py +++ b/examples/08_smpl_visualizer.py @@ -34,21 +34,21 @@ def __init__(self, model_path: Path) -> None: assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!" body_dict = dict(**np.load(model_path, allow_pickle=True)) - self._J_regressor = body_dict["J_regressor"] - self._weights = body_dict["weights"] - self._v_template = body_dict["v_template"] - self._posedirs = body_dict["posedirs"] - self._shapedirs = body_dict["shapedirs"] - self._faces = body_dict["f"] - - self.num_joints: int = self._weights.shape[-1] - self.num_betas: int = self._shapedirs.shape[-1] + self.J_regressor = body_dict["J_regressor"] + self.weights = body_dict["weights"] + self.v_template = body_dict["v_template"] + self.posedirs = body_dict["posedirs"] + self.shapedirs = body_dict["shapedirs"] + self.faces = body_dict["f"] + + self.num_joints: int = self.weights.shape[-1] + self.num_betas: int = self.shapedirs.shape[-1] self.parent_idx: np.ndarray = body_dict["kintree_table"][0] def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs: # Get shaped vertices + joint positions, when all local poses are identity. - v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas) - j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose) + v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas) + j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose) # Local SE(3) transforms. T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4) @@ -63,13 +63,13 @@ def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutpu # Linear blend skinning. pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten() - v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta) + v_blend = v_tpose + np.einsum("byn,n->by", self.posedirs, pose_delta) v_delta = np.ones((v_blend.shape[0], self.num_joints, 4)) v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :] v_posed = np.einsum( - "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta + "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self.weights, v_delta ) - return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint) + return SmplOutputs(v_posed, self.faces, T_world_joint, T_parent_joint) def main(model_path: Path) -> None: @@ -86,6 +86,13 @@ def main(model_path: Path) -> None: num_joints=model.num_joints, parent_idx=model.parent_idx, ) + body_handle = server.scene.add_mesh_simple( + "/human", + model.v_template, + model.faces, + wireframe=gui_elements.gui_wireframe.value, + color=gui_elements.gui_rgb.value, + ) while True: # Do nothing if no change. time.sleep(0.02) @@ -94,7 +101,7 @@ def main(model_path: Path) -> None: gui_elements.changed = False - # Compute SMPL outputs. + # If anything has changed, re-compute SMPL outputs. smpl_outputs = model.get_outputs( betas=np.array([x.value for x in gui_elements.gui_betas]), joint_rotmats=tf.SO3.exp( @@ -102,13 +109,12 @@ def main(model_path: Path) -> None: np.array([x.value for x in gui_elements.gui_joints]) ).as_matrix(), ) - server.scene.add_mesh_simple( - "/human", - smpl_outputs.vertices, - smpl_outputs.faces, - wireframe=gui_elements.gui_wireframe.value, - color=gui_elements.gui_rgb.value, - ) + + # Update the mesh properties based on the SMPL model output + GUI + # elements. + body_handle.vertices = smpl_outputs.vertices + body_handle.wireframe = gui_elements.gui_wireframe.value + body_handle.color = gui_elements.gui_rgb.value # Match transform control gizmos to joint positions. for i, control in enumerate(gui_elements.transform_controls): @@ -146,7 +152,7 @@ def set_changed(_) -> None: with tab_group.add_tab("View", viser.Icon.VIEWFINDER): gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255)) gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False) - gui_show_controls = server.gui.add_checkbox("Handles", initial_value=False) + gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True) gui_rgb.on_update(set_changed) gui_wireframe.on_update(set_changed) diff --git a/examples/25_smpl_visualizer_skinned.py b/examples/25_smpl_visualizer_skinned.py index c6c8d6d1..91faa511 100644 --- a/examples/25_smpl_visualizer_skinned.py +++ b/examples/25_smpl_visualizer_skinned.py @@ -40,21 +40,27 @@ def __init__(self, model_path: Path) -> None: assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!" body_dict = dict(**np.load(model_path, allow_pickle=True)) - self._J_regressor = body_dict["J_regressor"] - self._weights = body_dict["weights"] - self._v_template = body_dict["v_template"] - self._posedirs = body_dict["posedirs"] - self._shapedirs = body_dict["shapedirs"] - self._faces = body_dict["f"] - - self.num_joints: int = self._weights.shape[-1] - self.num_betas: int = self._shapedirs.shape[-1] + self.J_regressor = body_dict["J_regressor"] + self.weights = body_dict["weights"] + self.v_template = body_dict["v_template"] + self.posedirs = body_dict["posedirs"] + self.shapedirs = body_dict["shapedirs"] + self.faces = body_dict["f"] + + self.num_joints: int = self.weights.shape[-1] + self.num_betas: int = self.shapedirs.shape[-1] self.parent_idx: np.ndarray = body_dict["kintree_table"][0] + def get_tpose(self, betas: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + # Get shaped vertices + joint positions, when all local poses are identity. + v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas) + j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose) + return v_tpose, j_tpose + def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs: # Get shaped vertices + joint positions, when all local poses are identity. - v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas) - j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose) + v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas) + j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose) # Local SE(3) transforms. T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4) @@ -69,13 +75,13 @@ def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutpu # Linear blend skinning. pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten() - v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta) + v_blend = v_tpose + np.einsum("byn,n->by", self.posedirs, pose_delta) v_delta = np.ones((v_blend.shape[0], self.num_joints, 4)) v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :] v_posed = np.einsum( - "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta + "jxy,vj,vjy->vx", T_world_joint[:, :3, :], self.weights, v_delta ) - return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint) + return SmplOutputs(v_posed, self.faces, T_world_joint, T_parent_joint) def main(model_path: Path) -> None: @@ -92,23 +98,14 @@ def main(model_path: Path) -> None: num_joints=model.num_joints, parent_idx=model.parent_idx, ) - smpl_outputs = model.get_outputs( - betas=np.array([x.value for x in gui_elements.gui_betas]), - joint_rotmats=np.zeros((model.num_joints, 3, 3)) + np.eye(3), - ) - - bone_wxyzs = np.array( - [tf.SO3.from_matrix(R).wxyz for R in smpl_outputs.T_world_joint[:, :3, :3]] - ) - bone_positions = smpl_outputs.T_world_joint[:, :3, 3] - - skinned_handle = server.scene.add_mesh_skinned( + v_tpose, j_tpose = model.get_tpose(np.zeros((model.num_betas,))) + mesh_handle = server.scene.add_mesh_skinned( "/human", - smpl_outputs.vertices, - smpl_outputs.faces, - bone_wxyzs=bone_wxyzs, - bone_positions=bone_positions, - skin_weights=model._weights, + v_tpose, + model.faces, + bone_wxyzs=tf.SO3.identity(batch_axes=(model.num_joints,)).wxyz, + bone_positions=j_tpose, + skin_weights=model.weights, wireframe=gui_elements.gui_wireframe.value, color=gui_elements.gui_rgb.value, ) @@ -119,10 +116,19 @@ def main(model_path: Path) -> None: if not gui_elements.changed: continue + # Shapes changed: update vertices / joint positions. + if gui_elements.betas_changed: + v_tpose, j_tpose = model.get_tpose( + np.array([gui_beta.value for gui_beta in gui_elements.gui_betas]) + ) + mesh_handle.vertices = v_tpose + mesh_handle.bone_positions = j_tpose + gui_elements.changed = False + gui_elements.betas_changed = False # Render as wireframe? - skinned_handle.wireframe = gui_elements.gui_wireframe.value + mesh_handle.wireframe = gui_elements.gui_wireframe.value # Compute SMPL outputs. smpl_outputs = model.get_outputs( @@ -139,10 +145,10 @@ def main(model_path: Path) -> None: # Match transform control gizmos to joint positions. for i, control in enumerate(gui_elements.transform_controls): control.position = smpl_outputs.T_parent_joint[i, :3, 3] - skinned_handle.bones[i].wxyz = tf.SO3.from_matrix( + mesh_handle.bones[i].wxyz = tf.SO3.from_matrix( smpl_outputs.T_world_joint[i, :3, :3] ).wxyz - skinned_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3] + mesh_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3] @dataclass @@ -156,7 +162,10 @@ class GuiElements: transform_controls: List[viser.TransformControlsHandle] changed: bool - """This flag will be flipped to True whenever the mesh needs to be re-generated.""" + """This flag will be flipped to True whenever any input is changed.""" + + betas_changed: bool + """This flag will be flipped to True whenever the shape changes.""" def make_gui_elements( @@ -170,7 +179,11 @@ def make_gui_elements( tab_group = server.gui.add_tab_group() def set_changed(_) -> None: - out.changed = True # out is define later! + out.changed = True # out is defined later! + + def set_betas_changed(_) -> None: + out.betas_changed = True + out.changed = True # GUI elements: mesh settings + visibility. with tab_group.add_tab("View", viser.Icon.VIEWFINDER): @@ -220,7 +233,7 @@ def _(_): f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0 ) gui_betas.append(beta) - beta.on_update(set_changed) + beta.on_update(set_betas_changed) # GUI elements: joint angles. with tab_group.add_tab("Joints", viser.Icon.ANGLE): @@ -295,6 +308,7 @@ def _(_) -> None: gui_joints, transform_controls=transform_controls, changed=True, + betas_changed=False, ) return out diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 555d745a..b06856fc 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -551,10 +551,10 @@ class SkinnedMeshProps(MeshProps): Vertices are internally canonicalized to float32, faces to uint32.""" - bone_wxyzs: Tuple[Tuple[float, float, float, float], ...] - """Tuple of quaternions representing bone orientations. Synchronized automatically when assigned.""" - bone_positions: Tuple[Tuple[float, float, float], ...] - """Tuple of positions representing bone positions. Synchronized automatically when assigned.""" + bone_wxyzs: npt.NDArray[np.float32] + """Array of quaternions representing bone orientations (B, 4). Synchronized automatically when assigned.""" + bone_positions: npt.NDArray[np.float32] + """Array of positions representing bone positions (B, 3). Synchronized automatically when assigned.""" skin_indices: npt.NDArray[np.uint16] """Array of skin indices. Should have shape (V, 4). Synchronized automatically when assigned.""" skin_weights: npt.NDArray[np.float32] @@ -562,6 +562,9 @@ class SkinnedMeshProps(MeshProps): def __post_init__(self): # Check shapes. + assert self.bone_wxyzs.shape[-1] == 4 + assert self.bone_positions.shape[-1] == 3 + assert self.bone_wxyzs.shape[0] == self.bone_positions.shape[0] assert self.vertices.shape[-1] == 3 assert self.faces.shape[-1] == 3 assert self.skin_weights is not None diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 507d0678..902f1688 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -1065,19 +1065,8 @@ def add_mesh_skinned( flat_shading=flat_shading, side=side, material=material, - bone_wxyzs=tuple( - ( - float(wxyz[0]), - float(wxyz[1]), - float(wxyz[2]), - float(wxyz[3]), - ) - for wxyz in bone_wxyzs.astype(np.float32) - ), - bone_positions=tuple( - (float(xyz[0]), float(xyz[1]), float(xyz[2])) - for xyz in bone_positions.astype(np.float32) - ), + bone_wxyzs=bone_wxyzs.astype(np.float32), + bone_positions=bone_positions.astype(np.float32), skin_indices=top4_skin_indices.astype(np.uint16), skin_weights=top4_skin_weights.astype(np.float32), ), diff --git a/src/viser/_scene_handles.py b/src/viser/_scene_handles.py index 1cc9364a..0a721ed0 100644 --- a/src/viser/_scene_handles.py +++ b/src/viser/_scene_handles.py @@ -58,8 +58,13 @@ def __setattr__(self, name: str, value: Any) -> None: elif hint == onpt.NDArray[np.uint8] and "color" in name: value = colors_to_uint8(value) - if getattr(handle._impl.props, name) == value: - # Do nothing. Assumes equality is defined for the prop value. + current_value = getattr(handle._impl.props, name) + + # Do nothing if the value hasn't changed. + if isinstance(current_value, np.ndarray): + if current_value.data == value.data: + return + elif current_value == value: return setattr(handle._impl.props, name, value) diff --git a/src/viser/client/src/MessageHandler.tsx b/src/viser/client/src/MessageHandler.tsx index e6cd4a89..2d42c8fd 100644 --- a/src/viser/client/src/MessageHandler.tsx +++ b/src/viser/client/src/MessageHandler.tsx @@ -81,12 +81,34 @@ function useMessageHandler() { initialized: false, poses: [], }; + + const bone_wxyzs = new Float32Array( + message.props.bone_wxyzs.buffer.slice( + message.props.bone_wxyzs.byteOffset, + message.props.bone_wxyzs.byteOffset + + message.props.bone_wxyzs.byteLength, + ), + ); + const bone_positions = new Float32Array( + message.props.bone_positions.buffer.slice( + message.props.bone_positions.byteOffset, + message.props.bone_positions.byteOffset + + message.props.bone_positions.byteLength, + ), + ); for (let i = 0; i < message.props.bone_wxyzs!.length; i++) { - const wxyz = message.props.bone_wxyzs[i]; - const position = message.props.bone_positions[i]; viewer.skinnedMeshState.current[message.name].poses.push({ - wxyz: wxyz, - position: position, + wxyz: [ + bone_wxyzs[4 * i], + bone_wxyzs[4 * i + 1], + bone_wxyzs[4 * i + 2], + bone_wxyzs[4 * i + 3], + ], + position: [ + bone_positions[3 * i], + bone_positions[3 * i + 1], + bone_positions[3 * i + 2], + ], }); } } diff --git a/src/viser/client/src/ThreeAssets.tsx b/src/viser/client/src/ThreeAssets.tsx index 0a0ad98c..5c74e2a4 100644 --- a/src/viser/client/src/ThreeAssets.tsx +++ b/src/viser/client/src/ThreeAssets.tsx @@ -502,26 +502,52 @@ export const ViserMesh = React.forwardRef< let skeleton = undefined; if (message.type === "SkinnedMeshMessage") { // Skinned mesh. + const bone_wxyzs = new Float32Array( + message.props.bone_wxyzs.buffer.slice( + message.props.bone_wxyzs.byteOffset, + message.props.bone_wxyzs.byteOffset + + message.props.bone_wxyzs.byteLength, + ), + ); + const bone_positions = new Float32Array( + message.props.bone_positions.buffer.slice( + message.props.bone_positions.byteOffset, + message.props.bone_positions.byteOffset + + message.props.bone_positions.byteLength, + ), + ); + const bones: THREE.Bone[] = []; bonesRef.current = bones; - for (let i = 0; i < message.props.bone_wxyzs!.length; i++) { + for (let i = 0; i < bone_positions.length / 3; i++) { bones.push(new THREE.Bone()); } const boneInverses: THREE.Matrix4[] = []; const xyzw_quat = new THREE.Quaternion(); bones.forEach((bone, i) => { - const wxyz = message.props.bone_wxyzs[i]; - const position = message.props.bone_positions[i]; - xyzw_quat.set(wxyz[1], wxyz[2], wxyz[3], wxyz[0]); + xyzw_quat.set( + bone_wxyzs[i * 4 + 1], + bone_wxyzs[i * 4 + 2], + bone_wxyzs[i * 4 + 3], + bone_wxyzs[i * 4 + 0], + ); const boneInverse = new THREE.Matrix4(); boneInverse.makeRotationFromQuaternion(xyzw_quat); - boneInverse.setPosition(position[0], position[1], position[2]); + boneInverse.setPosition( + bone_positions[i * 3 + 0], + bone_positions[i * 3 + 1], + bone_positions[i * 3 + 2], + ); boneInverse.invert(); boneInverses.push(boneInverse); bone.quaternion.copy(xyzw_quat); - bone.position.set(position[0], position[1], position[2]); + bone.position.set( + bone_positions[i * 3 + 0], + bone_positions[i * 3 + 1], + bone_positions[i * 3 + 2], + ); }); skeleton = new THREE.Skeleton(bones, boneInverses); diff --git a/src/viser/client/src/WebsocketMessages.ts b/src/viser/client/src/WebsocketMessages.ts index c49e0804..8f34b0c3 100644 --- a/src/viser/client/src/WebsocketMessages.ts +++ b/src/viser/client/src/WebsocketMessages.ts @@ -347,8 +347,8 @@ export interface SkinnedMeshMessage { flat_shading: boolean; side: "front" | "back" | "double"; material: "standard" | "toon3" | "toon5"; - bone_wxyzs: [number, number, number, number][]; - bone_positions: [number, number, number][]; + bone_wxyzs: Uint8Array; + bone_positions: Uint8Array; skin_indices: Uint8Array; skin_weights: Uint8Array; };