Skip to content

Commit

Permalink
ENH: Updates for multistage registration
Browse files Browse the repository at this point in the history
Updates registration scripts to support multistage registration
procedures:
- Accept an initial transform describing results from previous
  registration stages with the newly introduced Elastix
  `ExternalInitialTransform`
- Update output transform composition to add the initial ITK transform
  • Loading branch information
tbirdso committed Aug 18, 2023
1 parent 756e846 commit 0dc137a
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dynamic = ["version", "readme"]

dependencies = [
'itk',
'itk-elastix',
'itk-elastix>=0.18.0',
'itk-genericlabelinterpolator',
'numpy',
'pandas',
Expand Down
31 changes: 29 additions & 2 deletions src/aind_ccf_alignment_experiments/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,37 @@ def get_physical_midpoint(
# Image block streaming helpers
###############################################################################

# Terms:
# - "block": A representation in voxel space with integer image access.
# - "physical": A representation in physical space with 3D floating-point representation.
#
# - "block region": a 2x3 voxel array representing upper and lower voxel bounds
# in ITK access order.
# If "k" is fastest and "i" is slowest:
# [ [ lower_k, lower_j, lower_i ]
# upper_k, upper_j, upper_i ] ]
#
# - "physical region": a 2x3 voxel array representing inclusive upper and lower bounds in
# physical space.
# [ [ lower_x, lower_y, lower_z ]
# upper_x, upper_y, upper_z ] ]
#
# - "ITK region": an `itk.ImageRegion[3]` representation of a block region.
# itk.ImageRegion[3]( [ [lower_k, lower_j, lower_i], [size_k, size_j, size_i] ])


def block_to_physical_size(
block_size: npt.ArrayLike,
ref_image: itk.Image,
transform: itk.Transform = None,
) -> npt.ArrayLike:
"""Convert from voxel block size to corresponding size in physical space"""
"""
Convert from voxel block size to corresponding size in physical space.
Naive transform approach assumes that both the input and output regions
are constrained along x/y/z planes aligned at two point extremes.
May not be suitable for deformable regions.
"""
block_index = [int(x) for x in block_size]

if transform:
Expand Down Expand Up @@ -198,6 +222,7 @@ def get_target_block_region(
block_region: npt.ArrayLike,
src_image: itk.Image,
target_image: itk.Image,
src_transform: itk.Transform = None,
crop_to_target: bool = False,
) -> npt.ArrayLike:
"""
Expand All @@ -206,7 +231,9 @@ def get_target_block_region(
"""
target_region = physical_to_block_region(
physical_region=block_to_physical_region(
block_region=block_region, ref_image=src_image
block_region=block_region,
ref_image=src_image,
transform=src_transform,
),
ref_image=target_image,
)
Expand Down
49 changes: 49 additions & 0 deletions src/aind_ccf_alignment_experiments/registration_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def get_elx_itk_transforms(
for transform_index, itk_transform_type in enumerate(
itk_transform_types
):
if not itk_transform_type: # skip on None
continue

elx_transform = registration_method.GetNthTransform(
transform_index
)
Expand Down Expand Up @@ -110,6 +113,7 @@ def register_elastix(
source_image: itk.Image,
target_image: itk.Image,
parameter_object: itk.ParameterObject = None,
initial_transform: itk.Transform = None,
itk_transform_types: List[type] = None,
log_filepath: str = None,
verbose: bool = False,
Expand All @@ -134,6 +138,10 @@ def register_elastix(
itk.AffineTransform[itk.D, DIMENSION],
itk.Euler3DTransform[itk.D],
]
if initial_transform and itk_transform_types[-1]:
# Cannot directly convert an external init ITK transfrom from Elastix
itk_transform_types.append(None)

if log_filepath:
os.makedirs(os.path.dirname(log_filepath), exist_ok=True)

Expand All @@ -148,6 +156,9 @@ def register_elastix(
parameter_object=parameter_object,
)

if initial_transform:
registration_method.SetExternalInitialTransform(initial_transform)

if log_filepath:
registration_method.SetLogToFile(True)
registration_method.SetOutputDirectory(os.path.dirname(log_filepath))
Expand All @@ -159,6 +170,11 @@ def register_elastix(
itk_composite_transform = get_elx_itk_transforms(
registration_method, itk_transform_types
)
if initial_transform:
itk_composite_transform.AppendTransform(initial_transform)
itk_composite_transform = flatten_composite_transform(
itk_composite_transform
)

return (
itk_composite_transform,
Expand Down Expand Up @@ -278,3 +294,36 @@ def cast_vector_image_to_double(vector_image: itk.Image) -> itk.Image:
composite_transform.AddTransform(displacement_transform)

return composite_transform, ants_result


def flatten_composite_transform(
transform: itk.Transform,
) -> itk.CompositeTransform[itk.D, 3]:
"""
Recursively flatten an `itk.CompositeTransform` that may contain
`itk.CompositeTransform` members so that the output represents a
single layer of non-composite transforms.
"""
inner_transforms = _flatten_composite_transform_recursive(transform)

output_transform = itk.CompositeTransform[itk.D, 3].New()
for transform in inner_transforms:
output_transform.AppendTransform(transform)
return output_transform


def _flatten_composite_transform_recursive(
transform: itk.Transform,
) -> List[itk.Transform]:
t = None
try:
t = itk.CompositeTransform[itk.D, 3].cast(transform)
except RuntimeError as e:
return [transform]

transform_list = []
for index in range(t.GetNumberOfTransforms()):
transform_list.append(
*_flatten_composite_transform_recursive(t.GetNthTransform(index))
)
return transform_list

0 comments on commit 0dc137a

Please sign in to comment.