Skip to content

Commit

Permalink
feat(morph-plot): fix closed cylinder generation
Browse files Browse the repository at this point in the history
  • Loading branch information
sanjayankur31 committed Sep 4, 2024
1 parent 7af3c82 commit a3ee75f
Showing 1 changed file with 60 additions and 38 deletions.
98 changes: 60 additions & 38 deletions pyneuroml/plot/PlotMorphologyVispy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,6 +1405,64 @@ def create_spherical_mesh(
)


@lru_cache(maxsize=100)
def compute_faces_of_cylindrical_mesh(rows: int, cols: int, closed: bool):
"""Compute faces for cylindrical meshes
Since we tend to use a constant set of rows and cols, this function should
be called repeatedly with the same values, and thus benefits from caching.
:param rows: number of rows in mesh
:type rows: int
:param cols: number of cols in mesh
:type cols: int
:param closed: toggle whether mesh is closed or open
:type closed: bool
:returns: numpy array with faces (triplets of vertex indices)
"""
faces = numpy.empty((rows * cols * 2, 3), dtype=numpy.uint32)
rowtemplate1 = (
(numpy.arange(cols).reshape(cols, 1) + numpy.array([[0, 1, 0]])) % cols
) + numpy.array([[0, 0, cols]])
logger.debug(f"Template1 is: {rowtemplate1}")

rowtemplate2 = (
(numpy.arange(cols).reshape(cols, 1) + numpy.array([[0, 1, 1]])) % cols
) + numpy.array([[cols, 0, cols]])
# logger.debug(f"Template2 is: {rowtemplate2}")

for row in range(rows):
start = row * cols * 2
faces[start : start + cols] = rowtemplate1 + row * cols
faces[start + cols : start + (cols * 2)] = rowtemplate2 + row * cols

num_verts = (rows + 1) * cols
# used below:
# index of center of first cap = num_verts
# index of center of second cap = num_verts + 1

# add extra faces to cover the caps
if closed is True:
cap1 = numpy.arange(cols).reshape(cols, 1)
cap1 = numpy.concatenate(
(numpy.full((cols, 1), num_verts), cap1, numpy.roll(cap1, 1)), axis=1
)
logger.debug(f"cap1 is {cap1}")

cap2 = numpy.arange(rows * cols, (rows + 1) * cols).reshape(cols, 1)
cap2 = numpy.concatenate(
(numpy.full((cols, 1), num_verts + 1), cap2, numpy.roll(cap2, 1)), axis=1
)
logger.debug(f"cap2 is {cap2}")

faces = numpy.append(faces, cap1, axis=0)
faces = numpy.append(faces, cap2, axis=0)

logger.debug(f"Faces are: {faces}")
return faces


@lru_cache(maxsize=10000)
def create_cylindrical_mesh(
rows: int,
Expand Down Expand Up @@ -1456,43 +1514,7 @@ def create_cylindrical_mesh(
verts = numpy.append(verts, [[0.0, 0.0, 0.0], [0.0, 0.0, length]], axis=0)
logger.debug(f"Verts are: {verts}")

# compute faces
faces = numpy.empty((rows * cols * 2, 3), dtype=numpy.uint32)
rowtemplate1 = (
(numpy.arange(cols).reshape(cols, 1) + numpy.array([[0, 1, 0]])) % cols
) + numpy.array([[0, 0, cols]])
logger.debug(f"Template1 is: {rowtemplate1}")

rowtemplate2 = (
(numpy.arange(cols).reshape(cols, 1) + numpy.array([[0, 1, 1]])) % cols
) + numpy.array([[cols, 0, cols]])
# logger.debug(f"Template2 is: {rowtemplate2}")

for row in range(rows):
start = row * cols * 2
faces[start : start + cols] = rowtemplate1 + row * cols
faces[start + cols : start + (cols * 2)] = rowtemplate2 + row * cols

# add extra faces to cover the caps
if closed is True:
cap1 = (numpy.arange(cols).reshape(cols, 1) + numpy.array([[0, 0, 1]])) % cols
cap1[..., 0] = len(verts) - 2
logger.debug(f"cap1 is {cap1}")

cap2_start = rows * cols
cap2 = numpy.arange(cols).reshape(cols, 1)
logger.info(f"cap2 is {cap2}")
cap2 = cap2 + numpy.array([[cap2_start, cap2_start, cap2_start + 1]])
logger.info(f"cap2 is {cap2}")
cap2 = cap2 % cap2_start + cap2 * int(cap2 / cap2_start - 1)
logger.info(f"cap2 is {cap2}")
cap2[..., 0] = len(verts) - 1
logger.info(f"cap2 is {cap2}")

faces = numpy.append(faces, cap1, axis=0)
faces = numpy.append(faces, cap2, axis=0)

logger.debug(f"Faces are: {faces}")
faces = compute_faces_of_cylindrical_mesh(rows, cols, closed)

return MeshData(vertices=verts, faces=faces)

Expand Down Expand Up @@ -1568,7 +1590,7 @@ def create_mesh(meshdata, plot_type, current_view, min_width, save_mesh_to):
seg_mesh = create_cylindrical_mesh(
rows=rows, cols=9, radius=(r1, r2), length=length, closed=True
)
logger.info(
logger.debug(
f"Created cylinderical mesh template with radii {r1}, {r2}, {length}"
)

Expand Down

0 comments on commit a3ee75f

Please sign in to comment.