Skip to content

Commit

Permalink
Skinned mesh + SMPL example adjustments (#291)
Browse files Browse the repository at this point in the history
* Fix equality check, tweak SMPL visualizer

* SMPL example adjustments

* ruff
  • Loading branch information
brentyi committed Sep 25, 2024
1 parent a2f4abd commit d72e85b
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 90 deletions.
52 changes: 29 additions & 23 deletions examples/08_smpl_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,21 @@ def __init__(self, model_path: Path) -> None:
assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!"
body_dict = dict(**np.load(model_path, allow_pickle=True))

self._J_regressor = body_dict["J_regressor"]
self._weights = body_dict["weights"]
self._v_template = body_dict["v_template"]
self._posedirs = body_dict["posedirs"]
self._shapedirs = body_dict["shapedirs"]
self._faces = body_dict["f"]

self.num_joints: int = self._weights.shape[-1]
self.num_betas: int = self._shapedirs.shape[-1]
self.J_regressor = body_dict["J_regressor"]
self.weights = body_dict["weights"]
self.v_template = body_dict["v_template"]
self.posedirs = body_dict["posedirs"]
self.shapedirs = body_dict["shapedirs"]
self.faces = body_dict["f"]

self.num_joints: int = self.weights.shape[-1]
self.num_betas: int = self.shapedirs.shape[-1]
self.parent_idx: np.ndarray = body_dict["kintree_table"][0]

def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs:
# Get shaped vertices + joint positions, when all local poses are identity.
v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas)
j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose)
v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)

# Local SE(3) transforms.
T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4)
Expand All @@ -63,13 +63,13 @@ def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutpu

# Linear blend skinning.
pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten()
v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta)
v_blend = v_tpose + np.einsum("byn,n->by", self.posedirs, pose_delta)
v_delta = np.ones((v_blend.shape[0], self.num_joints, 4))
v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :]
v_posed = np.einsum(
"jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta
"jxy,vj,vjy->vx", T_world_joint[:, :3, :], self.weights, v_delta
)
return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint)
return SmplOutputs(v_posed, self.faces, T_world_joint, T_parent_joint)


def main(model_path: Path) -> None:
Expand All @@ -86,6 +86,13 @@ def main(model_path: Path) -> None:
num_joints=model.num_joints,
parent_idx=model.parent_idx,
)
body_handle = server.scene.add_mesh_simple(
"/human",
model.v_template,
model.faces,
wireframe=gui_elements.gui_wireframe.value,
color=gui_elements.gui_rgb.value,
)
while True:
# Do nothing if no change.
time.sleep(0.02)
Expand All @@ -94,21 +101,20 @@ def main(model_path: Path) -> None:

gui_elements.changed = False

# Compute SMPL outputs.
# If anything has changed, re-compute SMPL outputs.
smpl_outputs = model.get_outputs(
betas=np.array([x.value for x in gui_elements.gui_betas]),
joint_rotmats=tf.SO3.exp(
# (num_joints, 3)
np.array([x.value for x in gui_elements.gui_joints])
).as_matrix(),
)
server.scene.add_mesh_simple(
"/human",
smpl_outputs.vertices,
smpl_outputs.faces,
wireframe=gui_elements.gui_wireframe.value,
color=gui_elements.gui_rgb.value,
)

# Update the mesh properties based on the SMPL model output + GUI
# elements.
body_handle.vertices = smpl_outputs.vertices
body_handle.wireframe = gui_elements.gui_wireframe.value
body_handle.color = gui_elements.gui_rgb.value

# Match transform control gizmos to joint positions.
for i, control in enumerate(gui_elements.transform_controls):
Expand Down Expand Up @@ -146,7 +152,7 @@ def set_changed(_) -> None:
with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255))
gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False)
gui_show_controls = server.gui.add_checkbox("Handles", initial_value=False)
gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True)

gui_rgb.on_update(set_changed)
gui_wireframe.on_update(set_changed)
Expand Down
86 changes: 50 additions & 36 deletions examples/25_smpl_visualizer_skinned.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,27 @@ def __init__(self, model_path: Path) -> None:
assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!"
body_dict = dict(**np.load(model_path, allow_pickle=True))

self._J_regressor = body_dict["J_regressor"]
self._weights = body_dict["weights"]
self._v_template = body_dict["v_template"]
self._posedirs = body_dict["posedirs"]
self._shapedirs = body_dict["shapedirs"]
self._faces = body_dict["f"]

self.num_joints: int = self._weights.shape[-1]
self.num_betas: int = self._shapedirs.shape[-1]
self.J_regressor = body_dict["J_regressor"]
self.weights = body_dict["weights"]
self.v_template = body_dict["v_template"]
self.posedirs = body_dict["posedirs"]
self.shapedirs = body_dict["shapedirs"]
self.faces = body_dict["f"]

self.num_joints: int = self.weights.shape[-1]
self.num_betas: int = self.shapedirs.shape[-1]
self.parent_idx: np.ndarray = body_dict["kintree_table"][0]

def get_tpose(self, betas: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
# Get shaped vertices + joint positions, when all local poses are identity.
v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)
return v_tpose, j_tpose

def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs:
# Get shaped vertices + joint positions, when all local poses are identity.
v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas)
j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose)
v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)

# Local SE(3) transforms.
T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4)
Expand All @@ -69,13 +75,13 @@ def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutpu

# Linear blend skinning.
pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten()
v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta)
v_blend = v_tpose + np.einsum("byn,n->by", self.posedirs, pose_delta)
v_delta = np.ones((v_blend.shape[0], self.num_joints, 4))
v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :]
v_posed = np.einsum(
"jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta
"jxy,vj,vjy->vx", T_world_joint[:, :3, :], self.weights, v_delta
)
return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint)
return SmplOutputs(v_posed, self.faces, T_world_joint, T_parent_joint)


def main(model_path: Path) -> None:
Expand All @@ -92,23 +98,14 @@ def main(model_path: Path) -> None:
num_joints=model.num_joints,
parent_idx=model.parent_idx,
)
smpl_outputs = model.get_outputs(
betas=np.array([x.value for x in gui_elements.gui_betas]),
joint_rotmats=np.zeros((model.num_joints, 3, 3)) + np.eye(3),
)

bone_wxyzs = np.array(
[tf.SO3.from_matrix(R).wxyz for R in smpl_outputs.T_world_joint[:, :3, :3]]
)
bone_positions = smpl_outputs.T_world_joint[:, :3, 3]

skinned_handle = server.scene.add_mesh_skinned(
v_tpose, j_tpose = model.get_tpose(np.zeros((model.num_betas,)))
mesh_handle = server.scene.add_mesh_skinned(
"/human",
smpl_outputs.vertices,
smpl_outputs.faces,
bone_wxyzs=bone_wxyzs,
bone_positions=bone_positions,
skin_weights=model._weights,
v_tpose,
model.faces,
bone_wxyzs=tf.SO3.identity(batch_axes=(model.num_joints,)).wxyz,
bone_positions=j_tpose,
skin_weights=model.weights,
wireframe=gui_elements.gui_wireframe.value,
color=gui_elements.gui_rgb.value,
)
Expand All @@ -119,10 +116,19 @@ def main(model_path: Path) -> None:
if not gui_elements.changed:
continue

# Shapes changed: update vertices / joint positions.
if gui_elements.betas_changed:
v_tpose, j_tpose = model.get_tpose(
np.array([gui_beta.value for gui_beta in gui_elements.gui_betas])
)
mesh_handle.vertices = v_tpose
mesh_handle.bone_positions = j_tpose

gui_elements.changed = False
gui_elements.betas_changed = False

# Render as wireframe?
skinned_handle.wireframe = gui_elements.gui_wireframe.value
mesh_handle.wireframe = gui_elements.gui_wireframe.value

# Compute SMPL outputs.
smpl_outputs = model.get_outputs(
Expand All @@ -139,10 +145,10 @@ def main(model_path: Path) -> None:
# Match transform control gizmos to joint positions.
for i, control in enumerate(gui_elements.transform_controls):
control.position = smpl_outputs.T_parent_joint[i, :3, 3]
skinned_handle.bones[i].wxyz = tf.SO3.from_matrix(
mesh_handle.bones[i].wxyz = tf.SO3.from_matrix(
smpl_outputs.T_world_joint[i, :3, :3]
).wxyz
skinned_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3]
mesh_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3]


@dataclass
Expand All @@ -156,7 +162,10 @@ class GuiElements:
transform_controls: List[viser.TransformControlsHandle]

changed: bool
"""This flag will be flipped to True whenever the mesh needs to be re-generated."""
"""This flag will be flipped to True whenever any input is changed."""

betas_changed: bool
"""This flag will be flipped to True whenever the shape changes."""


def make_gui_elements(
Expand All @@ -170,7 +179,11 @@ def make_gui_elements(
tab_group = server.gui.add_tab_group()

def set_changed(_) -> None:
out.changed = True # out is define later!
out.changed = True # out is defined later!

def set_betas_changed(_) -> None:
out.betas_changed = True
out.changed = True

# GUI elements: mesh settings + visibility.
with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
Expand Down Expand Up @@ -220,7 +233,7 @@ def _(_):
f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0
)
gui_betas.append(beta)
beta.on_update(set_changed)
beta.on_update(set_betas_changed)

# GUI elements: joint angles.
with tab_group.add_tab("Joints", viser.Icon.ANGLE):
Expand Down Expand Up @@ -295,6 +308,7 @@ def _(_) -> None:
gui_joints,
transform_controls=transform_controls,
changed=True,
betas_changed=False,
)
return out

Expand Down
11 changes: 7 additions & 4 deletions src/viser/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,17 +551,20 @@ class SkinnedMeshProps(MeshProps):
Vertices are internally canonicalized to float32, faces to uint32."""

bone_wxyzs: Tuple[Tuple[float, float, float, float], ...]
"""Tuple of quaternions representing bone orientations. Synchronized automatically when assigned."""
bone_positions: Tuple[Tuple[float, float, float], ...]
"""Tuple of positions representing bone positions. Synchronized automatically when assigned."""
bone_wxyzs: npt.NDArray[np.float32]
"""Array of quaternions representing bone orientations (B, 4). Synchronized automatically when assigned."""
bone_positions: npt.NDArray[np.float32]
"""Array of positions representing bone positions (B, 3). Synchronized automatically when assigned."""
skin_indices: npt.NDArray[np.uint16]
"""Array of skin indices. Should have shape (V, 4). Synchronized automatically when assigned."""
skin_weights: npt.NDArray[np.float32]
"""Array of skin weights. Should have shape (V, 4). Synchronized automatically when assigned."""

def __post_init__(self):
# Check shapes.
assert self.bone_wxyzs.shape[-1] == 4
assert self.bone_positions.shape[-1] == 3
assert self.bone_wxyzs.shape[0] == self.bone_positions.shape[0]
assert self.vertices.shape[-1] == 3
assert self.faces.shape[-1] == 3
assert self.skin_weights is not None
Expand Down
15 changes: 2 additions & 13 deletions src/viser/_scene_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,19 +1065,8 @@ def add_mesh_skinned(
flat_shading=flat_shading,
side=side,
material=material,
bone_wxyzs=tuple(
(
float(wxyz[0]),
float(wxyz[1]),
float(wxyz[2]),
float(wxyz[3]),
)
for wxyz in bone_wxyzs.astype(np.float32)
),
bone_positions=tuple(
(float(xyz[0]), float(xyz[1]), float(xyz[2]))
for xyz in bone_positions.astype(np.float32)
),
bone_wxyzs=bone_wxyzs.astype(np.float32),
bone_positions=bone_positions.astype(np.float32),
skin_indices=top4_skin_indices.astype(np.uint16),
skin_weights=top4_skin_weights.astype(np.float32),
),
Expand Down
9 changes: 7 additions & 2 deletions src/viser/_scene_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,13 @@ def __setattr__(self, name: str, value: Any) -> None:
elif hint == onpt.NDArray[np.uint8] and "color" in name:
value = colors_to_uint8(value)

if getattr(handle._impl.props, name) == value:
# Do nothing. Assumes equality is defined for the prop value.
current_value = getattr(handle._impl.props, name)

# Do nothing if the value hasn't changed.
if isinstance(current_value, np.ndarray):
if current_value.data == value.data:
return
elif current_value == value:
return

setattr(handle._impl.props, name, value)
Expand Down
30 changes: 26 additions & 4 deletions src/viser/client/src/MessageHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,34 @@ function useMessageHandler() {
initialized: false,
poses: [],
};

const bone_wxyzs = new Float32Array(
message.props.bone_wxyzs.buffer.slice(
message.props.bone_wxyzs.byteOffset,
message.props.bone_wxyzs.byteOffset +
message.props.bone_wxyzs.byteLength,
),
);
const bone_positions = new Float32Array(
message.props.bone_positions.buffer.slice(
message.props.bone_positions.byteOffset,
message.props.bone_positions.byteOffset +
message.props.bone_positions.byteLength,
),
);
for (let i = 0; i < message.props.bone_wxyzs!.length; i++) {
const wxyz = message.props.bone_wxyzs[i];
const position = message.props.bone_positions[i];
viewer.skinnedMeshState.current[message.name].poses.push({
wxyz: wxyz,
position: position,
wxyz: [
bone_wxyzs[4 * i],
bone_wxyzs[4 * i + 1],
bone_wxyzs[4 * i + 2],
bone_wxyzs[4 * i + 3],
],
position: [
bone_positions[3 * i],
bone_positions[3 * i + 1],
bone_positions[3 * i + 2],
],
});
}
}
Expand Down
Loading

0 comments on commit d72e85b

Please sign in to comment.