From a82eb344cad7086af6dc14672c8c9e2061bac624 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sat, 2 Sep 2023 14:17:56 -0700 Subject: [PATCH] Appease pyright --- viser/_message_api.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/viser/_message_api.py b/viser/_message_api.py index d2a38823a..e2409d5cf 100644 --- a/viser/_message_api.py +++ b/viser/_message_api.py @@ -215,9 +215,10 @@ def add_spline_catmull_rom( ) -> None: """Add spline using Catmull-Rom interpolation.""" if isinstance(positions, onp.ndarray): - assert len(positions.shape) == 2 - positions = tuple(map(tuple, positions)) - + assert len(positions.shape) == 2 and positions.shape[1] == 3 + positions = tuple(map(tuple, positions)) # type: ignore + assert len(positions[0]) == 3 + assert isinstance(positions, tuple) self._queue( _messages.CatmullRomSplineMessage( name, @@ -241,12 +242,14 @@ def add_spline_cubic_bezier( """Add spline using Cubic Bezier interpolation.""" if isinstance(positions, onp.ndarray): - assert len(positions.shape) == 2 - positions = tuple(map(tuple, positions)) + assert len(positions.shape) == 2 and positions.shape[1] == 3 + positions = tuple(map(tuple, positions)) # type: ignore if isinstance(control_points, onp.ndarray): - assert len(control_points.shape) == 2 - control_points = tuple(map(tuple, control_points)) + assert len(control_points.shape) == 2 and control_points.shape[1] == 3 + control_points = tuple(map(tuple, control_points)) # type: ignore + assert isinstance(positions, tuple) + assert isinstance(control_points, tuple) assert len(control_points) == (2 * len(positions) - 2) self._queue( _messages.CubicBezierSplineMessage(