Skip to content

Commit

Permalink
Add FormatESSNMX to read data from the NMX detector at ESS.
Browse files Browse the repository at this point in the history
  • Loading branch information
toastisme committed Oct 8, 2024
1 parent 898072e commit 8f4fdf8
Show file tree
Hide file tree
Showing 2 changed files with 302 additions and 0 deletions.
1 change: 1 addition & 0 deletions newsfragments/XXX.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add format class to read data from the NMX ESS detector.
301 changes: 301 additions & 0 deletions src/dxtbx/format/FormatESSNMX.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,301 @@
from __future__ import annotations

import logging
from typing import List, Tuple

import h5py
import numpy as np

import cctbx.array_family.flex as flex

import dxtbx_flumpy as flumpy
from dxtbx import IncorrectFormatError
from dxtbx.format.FormatHDF5 import FormatHDF5
from dxtbx.model import Detector
from dxtbx.model.beam import BeamFactory, PolychromaticBeam, Probe
from dxtbx.model.goniometer import Goniometer, GoniometerFactory
from dxtbx.model.scan import Scan, ScanFactory

logger = logging.getLogger(__name__)


class FormatESSNMX(FormatHDF5):
"""
Class to read files from NMX
https://europeanspallationsource.se/instruments/nmx
preprocessed files in scipp to obtain binned data
"""

def __init__(self, image_file, **kwargs) -> None:
if not FormatESSNMX.understand(image_file):
raise IncorrectFormatError(self, image_file)
self._nxs_file = h5py.File(image_file, "r")
self._raw_data = None

@staticmethod
def understand(image_file: str) -> bool:
try:
return FormatESSNMX.is_nmx_file(image_file)
except (IOError, KeyError):
return False

@staticmethod
def is_nmx_file(image_file: str) -> bool:
def get_name(image_file):
try:
with h5py.File(image_file, "r") as handle:
return handle["NMX_data"].attrs["name"]
except (IOError, KeyError, AttributeError):
return ""

return get_name(image_file) == "NMX"

def get_instrument_name(self) -> str:
return "NMX"

def get_experiment_description(self) -> str:
return "Simulated data"

def _load_raw_data(self) -> None:
raw_data = []
image_size = self._get_image_size()
total_pixels = image_size[0] * image_size[1]
num_images = self.get_num_images()
for i in range(self._get_num_panels()):
spectra = self._nxs_file["NMX_data"]["detector_1"]["counts"][
0, total_pixels * i : total_pixels * (i + 1), :
]
spectra = np.reshape(spectra, (image_size[0], image_size[1], num_images))
raw_data.append(flumpy.from_numpy(spectra))

self._raw_data = tuple(raw_data)

def get_raw_data(
self, index: int, use_loaded_data: bool = False
) -> Tuple[flex.int]:
raw_data = []
image_size = self._get_image_size()
total_pixels = image_size[0] * image_size[1]

if use_loaded_data:
if self._raw_data is None:
self._load_raw_data()
for panel in self._raw_data:
data = panel[:, :, index : index + 1]
data.reshape(flex.grid(panel.all()[0], panel.all()[1]))
data.matrix_transpose_in_place()
raw_data.append(data)

else:
for i in range(self._get_num_panels()):
spectra = self._nxs_file["NMX_data"]["detector_1"]["counts"][
0, total_pixels * i : total_pixels * (i + 1), index : index + 1
]
spectra = np.reshape(spectra, image_size)
raw_data.append(flumpy.from_numpy(np.ascontiguousarray(spectra)))

return tuple(raw_data)

def _get_time_channel_bins(self) -> List[float]:
# (usec)
return self._nxs_file["NMX_data"]["detector_1"]["t_bin"][:] * 10**6

def _get_time_of_flight(self) -> List[float]:
# (usec)
bins = self._get_time_channel_bins()
return [float((bins[i] + bins[i + 1]) * 0.5) for i in range(len(bins) - 1)]

def get_num_images(self) -> int:
return len(self._get_time_of_flight())

def get_detector(self, index: int = None) -> Detector:
num_panels = self._get_num_panels()
panel_names = self._get_panel_names()
panel_type = self._get_panel_type()
image_size = self._get_image_size()
trusted_range = self._get_panel_trusted_range()
pixel_size = self._get_pixel_size()
fast_axes = self._get_panel_fast_axes()
slow_axes = self._get_panel_slow_axes()
panel_origins = self._get_panel_origins()
gain = self._get_panel_gain()
panel_projections = self._get_panel_projections_2d()
detector = Detector()
root = detector.hierarchy()

for i in range(num_panels):
panel = root.add_panel()
panel.set_type(panel_type)
panel.set_name(panel_names[i])
panel.set_image_size(image_size)
panel.set_trusted_range(trusted_range)
panel.set_pixel_size(pixel_size)
panel.set_local_frame(fast_axes[i], slow_axes[i], panel_origins[i])
panel.set_gain(gain)
r, t = panel_projections[i]
r = tuple(map(int, r))
t = tuple(map(int, t))
panel.set_projection_2d(r, t)

return detector

def _get_num_panels(self) -> int:
return self._nxs_file["NMX_data/instrument"].attrs["nr_detector"]

def _get_panel_names(self) -> List[str]:
return [
"%02d" % (i + 1)
for i in range(self._nxs_file["NMX_data/instrument"].attrs["nr_detector"])
]

def _get_panel_type(self) -> str:
return "SENSOR_PAD"

def _get_image_size(self) -> Tuple[int, int]:
# (px)
return (1280, 1280)

def _get_panel_trusted_range(self) -> Tuple[int, int]:
# 4 * 1280**2 plus buffer
return (-1, 7000000)

def _get_pixel_size(self) -> Tuple[float, float]:
# (mm)
return (0.4, 0.4)

def _get_panel_fast_axes(self) -> Tuple[Tuple[float, float, float]]:
return ((1.0, 0.0, 0.0), (0.0, 0.0, 1.0), (0.0, 0.0, -1.0))

def _get_panel_slow_axes(self) -> Tuple[Tuple[float, float, float]]:
return ((0.0, 1.0, 0.0), (0.0, 1.0, 0.0), (0.0, 1.0, 0.0))

def _get_panel_origins(self) -> Tuple[Tuple[float, float, float]]:
# (mm)
return ((-250, -250.0, -292.0), (290, -250.0, -250), (-290, -250.0, 250.0))

def _get_panel_projections_2d(self) -> dict[int : Tuple[Tuple, Tuple]]:
p_w, p_h = self._get_image_size()
p_w += 10
p_h += 10
panel_pos = {
0: ((-1, 0, 0, -1), (p_h, 0)),
1: ((-1, 0, 0, -1), (p_h, p_w)),
2: ((-1, 0, 0, -1), (p_h, -p_w)),
}

return panel_pos

def get_beam(self, index: int = None) -> PolychromaticBeam:
direction = self._get_sample_to_source_direction()
distance = self._get_sample_to_source_distance()
wavelength_range = self._get_wavelength_range()
return BeamFactory.make_polychromatic_beam(
direction=direction,
sample_to_source_distance=distance,
probe=Probe.neutron,
wavelength_range=wavelength_range,
)

def _get_sample_to_source_direction(self) -> Tuple[float, float, float]:
return (0, 0, 1)

def _get_wavelength_range(self) -> Tuple[float, float]:
# (A)
return (1.8, 2.55)

def _get_sample_to_source_distance(self) -> float:
try:
dist = abs(self._nxs_file["NMX_data/NXsource/distance"][...]) * 1000
return dist
except (KeyError, ValueError):
logger.warning("sample to moderator_distance not found, using dummy value")
return 157406

def _get_panel_gain(self) -> float:
return 1.0

def get_goniometer_phi_angle(self) -> float:
return self.get_goniometer_orientations()[0]

def get_goniometer(self, index: int = None) -> Goniometer:
rotation_axis = (0.0, 1.0, 0.0)
fixed_rotation = (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
goniometer = GoniometerFactory.make_goniometer(rotation_axis, fixed_rotation)
try:
angles = self.get_goniometer_orientations()
except KeyError:
return goniometer
axes = ((1, 0, 0), (0, 1, 0), (0, 0, 1))
for idx, angle in enumerate(angles):
goniometer.rotate_around_origin(axes[idx], -angle)
return goniometer

def get_goniometer_orientations(self) -> Tuple[float, float, float]:
# Angles in deg along x, y, z
return self._nxs_file["NMX_data/crystal_orientation"][...]

def get_scan(self, index=None) -> Scan:
image_range = (1, self.get_num_images())
properties = {"time_of_flight": tuple(self._get_time_of_flight())}
return ScanFactory.make_scan_from_properties(
image_range=image_range, properties=properties
)

def get_flattened_data(
self, image_range: None | Tuple = None, scale_data: bool = True
) -> Tuple[flex.int]:
"""
Image data summed along the time-of-flight direction
"""

panel_size = self._get_image_size()
total_pixels = panel_size[0] * panel_size[1]
max_val = None
num_tof_bins = len(self._get_time_channel_bins()) - 1
raw_data = []
for panel_idx in range(self._get_num_panels()):
panel_data = self._nxs_file["NMX_data"]["detector_1"]["counts"][
0, total_pixels * panel_idx : total_pixels * (panel_idx + 1), :
]
panel_data = np.reshape(
panel_data, (panel_size[0], panel_size[1], num_tof_bins)
)
if image_range is not None:
assert (
len(image_range) == 2
), "expected image_range to be only two values"
assert (
image_range[0] >= 0 and image_range[0] < image_range[1]
), "image_range[0] out of range"
assert image_range[1] <= num_tof_bins, "image_range[1] out of range"
panel_data = np.sum(
panel_data[:, :, image_range[0] : image_range[1]], axis=2
).T

else:
panel_data = np.sum(panel_data, axis=2).T
panel_max_val = np.max(panel_data)
if max_val is None or max_val < panel_max_val:
max_val = panel_max_val
raw_data.append(panel_data)

if scale_data:
return tuple([(i / max_val).tolist() for i in raw_data])

return tuple([i.tolist() for i in raw_data])

def get_flattened_pixel_data(
self, panel_idx: int, x: int, y: int
) -> Tuple[Tuple, Tuple]:
time_channels = self._get_time_of_flight()
panel_size = self._get_image_size()
height = panel_size[1]
total_pixels = panel_size[0] * panel_size[1]
idx = (panel_idx * total_pixels) + panel_idx + x * height + y
return (
time_channels,
tuple(self._nxs_file["NMX_data/detector_1/counts"][0, idx, :].tolist()),
)

def get_proton_charge(self) -> float:
return self._nxs_file["NMX_data"]["proton_charge"][...]

0 comments on commit 8f4fdf8

Please sign in to comment.