Skip to content

Commit

Permalink
NXmx: handle multidimensional arrays (cctbx#612)
Browse files Browse the repository at this point in the history
Data in NeXus can be 3 or 4 dimensional.
3D: Nimages by slow by fast
4D: Nimages by Nmodules by slow fast

Slice image_size and reshape the raw_data in these cases.

Co-authored-by: Richard Gildea <[email protected]>
  • Loading branch information
2 people authored and toastisme committed Jul 11, 2024
1 parent 8835796 commit 84a8171
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 9 deletions.
1 change: 1 addition & 0 deletions newsfragments/612.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
NXmx files with multidimensional arrays (images, modules, or both) are now handled.
28 changes: 19 additions & 9 deletions src/dxtbx/nexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import itertools
import logging
from typing import Literal, Optional, Tuple, cast
from typing import Literal, Optional

import h5py
import numpy as np
Expand Down Expand Up @@ -404,9 +404,12 @@ def equipment_component_key(dependency):
origin -= nxdetector.beam_center_y.magnitude * pixel_size[1] * slow_axis

# dxtbx requires image size in the order fast, slow - which is the reverse of what
# is stored in module.data_size
image_size = cast(Tuple[int, int], tuple(map(int, module.data_size[::-1])))
assert len(image_size) == 2
# is stored in module.data_size. Additionally, data_size can have more than 2
# dimensions, for multi-module detectors. So take the last two dimensions and reverse
# them. Examples:
# [1,2,3] --> (3, 2)
# [1,2] --> (2, 1)
image_size = (int(module.data_size[-1]), int(module.data_size[-2]))
underload = (
float(nxdetector.underload_value)
if nxdetector.underload_value is not None
Expand Down Expand Up @@ -475,13 +478,17 @@ def get_static_mask(nxdetector: nxmx.NXdetector) -> tuple[flex.bool, ...] | None
pixel_mask = nxdetector.pixel_mask
except KeyError:
return None
if pixel_mask is None or not pixel_mask.size or pixel_mask.ndim != 2:
if pixel_mask is None or not pixel_mask.size:
return None
all_slices = get_detector_module_slices(nxdetector)
return tuple(
flumpy.from_numpy(np.ascontiguousarray(pixel_mask[slices])) == 0
for slices in all_slices
)
all_mask_slices = []
for slices in all_slices:
mask_slice = flumpy.from_numpy(np.ascontiguousarray(pixel_mask[slices])) == 0
mask_slice.reshape(
flex.grid(mask_slice.all()[-2:])
) # handle 3 or 4 dimension arrays
all_mask_slices.append(mask_slice)
return tuple(all_mask_slices)


def _dataset_as_flex(
Expand Down Expand Up @@ -562,5 +569,8 @@ def get_raw_data(
data_as_flex = _dataset_as_flex(
sliced_outer, tuple(module_slices), bit_depth=bit_depth
)
data_as_flex.reshape(
flex.grid(data_as_flex.all()[-2:])
) # handle 3 or 4 dimension arrays
all_data.append(data_as_flex)
return tuple(all_data)
103 changes: 103 additions & 0 deletions tests/nexus/test_build_dxtbx_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,109 @@ def test_get_dxtbx_detector_beam_center_fallback(nxmx_example):
)


@pytest.fixture
def detector_with_multiple_modules():

with h5py.File(" ", "w", **pytest.h5_in_memory) as f:

detector = f.create_group("/entry/instrument/detector")
detector.attrs["NX_class"] = "NXdetector"
detector["beam_center_x"] = 2079.79727597266
detector["beam_center_y"] = 2225.38773853771
detector["count_time"] = 0.00285260857097799
detector["depends_on"] = "/entry/instrument/detector/transformations/det_z"
detector["description"] = "Eiger 16M"
detector["distance"] = 0.237015940260233
detector.create_dataset("data", data=np.zeros((100, 100)))
detector["sensor_material"] = "Silicon"
detector["sensor_thickness"] = 0.00045
detector["sensor_thickness"].attrs["units"] = b"m"
detector["x_pixel_size"] = 7.5e-05
detector["y_pixel_size"] = 7.5e-05
detector["underload_value"] = 0
detector["saturation_value"] = 9266
detector["frame_time"] = 0.1
detector["frame_time"].attrs["units"] = "s"
detector["bit_depth_readout"] = np.array(32)
mask = np.zeros((2, 100, 200), dtype="i8")
detector.create_dataset("pixel_mask", data=mask)

detector_transformations = detector.create_group("transformations")
detector_transformations.attrs["NX_class"] = "NXtransformations"
det_z = detector_transformations.create_dataset("det_z", data=np.array([289.3]))
det_z.attrs["depends_on"] = b"."
det_z.attrs["transformation_type"] = b"translation"
det_z.attrs["units"] = b"mm"
det_z.attrs["vector"] = np.array([0.0, 0.0, 1.0])

def make_module(name, depends_on, data_origin, fast_direction, slow_direction):
module = detector.create_group(name)
module.attrs["NX_class"] = "NXdetector_module"
module.create_dataset("data_size", data=np.array([1, 100, 200]))
module.create_dataset("data_origin", data=np.array(data_origin))
fast = module.create_dataset("fast_pixel_direction", data=0.075)
fast.attrs["transformation_type"] = "translation"
fast.attrs["depends_on"] = depends_on
fast.attrs["vector"] = np.array(fast_direction)
fast.attrs["units"] = "mm"
slow = module.create_dataset("slow_pixel_direction", data=0.075)
slow.attrs["transformation_type"] = "translation"
slow.attrs["depends_on"] = depends_on
slow.attrs["vector"] = np.array(slow_direction)
slow.attrs["units"] = "mm"

make_module(
name="m0",
depends_on="/entry/instrument/detector/transformations/det_z",
data_origin=[0, 0, 0],
fast_direction=[-0.999998, -0.001781, 0],
slow_direction=[-0.001781, 0.999998, 0],
)
make_module(
name="m1",
depends_on="/entry/instrument/detector/transformations/det_z",
data_origin=[1, 0, 0],
fast_direction=[-0.999998, -0.001781, 0],
slow_direction=[-0.001781, 0.999998, 0],
)

nxdata = f.create_group("/entry/data")
nxdata.attrs["NX_class"] = "NXdata"
nxdata.create_dataset(
"data",
data=np.array(
[np.full((2, 100, 200), i, dtype=np.int32) for i in range(3)]
),
)
nxdata.attrs["signal"] = "/entry/data/data"

yield f


def test_get_dxtbx_detector_with_multiple_modules(detector_with_multiple_modules):
det = nxmx.NXdetector(detector_with_multiple_modules["/entry/instrument/detector"])
wavelength = 1

detector = dxtbx.nexus.get_dxtbx_detector(det, wavelength)
assert len(detector) == 2
for panel in detector:
assert panel.get_image_size() == (200, 100)

nxdata = nxmx.NXdata(detector_with_multiple_modules["/entry/data"])
for i in range(3):
raw_data = dxtbx.nexus.get_raw_data(nxdata, det, i)
assert len(raw_data) == 2
for module_data in raw_data:
assert module_data.all() == (100, 200)
assert module_data.all_eq(i)

mask = dxtbx.nexus.get_static_mask(det)
assert len(mask) == 2
for module_mask in mask:
assert isinstance(module_mask, flex.bool)
assert module_mask.all() == (100, 200)


@pytest.fixture
def detector_with_two_theta():
with h5py.File(" ", "w", **pytest.h5_in_memory) as f:
Expand Down

0 comments on commit 84a8171

Please sign in to comment.