'Lazy' loading of Dask arrays from Tiled (SRX) #321

Merged 3 commits on Jun 9, 2024
10 changes: 5 additions & 5 deletions pyxrf/model/lineplot.py
@@ -10,7 +10,7 @@
 import numpy as np
 from atom.api import Atom, Bool, Dict, Float, Int, List, Str, Typed, observe
 from matplotlib.axes import Axes
-from matplotlib.collections import BrokenBarHCollection
+from matplotlib.collections import PolyCollection
 from matplotlib.colors import LogNorm
 from matplotlib.figure import Figure
 from matplotlib.lines import Line2D
@@ -122,7 +122,7 @@ class LinePlotModel(Atom):
     _fig_preview = Typed(Figure)
     _ax_preview = Typed(Axes)
     _lines_preview = List()
-    _bahr_preview = Typed(BrokenBarHCollection)
+    _bahr_preview = Typed(PolyCollection)

     plot_type_preview = Typed(PlotTypes)
     energy_range_preview = Typed(EnergyRangePresets)
@@ -177,7 +177,7 @@ class LinePlotModel(Atom):
     show_exp_opt = Bool(False)  # Flag: show spectrum preview

     # Reference to the artist responsible for displaying the selected range of energies on the plot
-    plot_energy_barh = Typed(BrokenBarHCollection)
+    plot_energy_barh = Typed(PolyCollection)
     t_bar = Typed(object)

     plot_exp_list = List()
@@ -727,7 +727,7 @@ def plot_selected_energy_range_original(self, *, e_low=None, e_high=None):
             self.plot_energy_barh.remove()

         # Create the new plot (based on new parameters, if necessary)
-        self.plot_energy_barh = BrokenBarHCollection.span_where(
+        self.plot_energy_barh = PolyCollection.span_where(
             x_v, ymin=y_min, ymax=y_max, where=ss, facecolor="white", edgecolor="yellow", alpha=1
         )
         self._ax.add_collection(self.plot_energy_barh)
@@ -1471,7 +1471,7 @@ def plot_selected_energy_range(self, *, axes, barh_existing, e_low=None, e_high=
         barh_existing.remove()

     # Create the new plot (based on new parameters, if necessary)
-    barh_new = BrokenBarHCollection.span_where(
+    barh_new = PolyCollection.span_where(
         x_v, ymin=y_min, ymax=y_max, where=ss, facecolor="white", edgecolor="yellow", alpha=1
     )
     axes.add_collection(barh_new)
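
Context for the change above: BrokenBarHCollection (a thin subclass of PolyCollection) was deprecated in matplotlib 3.7 and removed in a later release, which is what motivates the switch. If the span_where classmethod is not available on the collection class in a given matplotlib version, the same highlight can be built directly from a PolyCollection. A minimal sketch (make_span_collection is an illustrative helper, not pyxrf code):

import numpy as np
from matplotlib.collections import PolyCollection

def make_span_collection(x, ymin, ymax, where, **kwargs):
    # Illustrative helper (not part of pyxrf): build one rectangle per
    # contiguous run of True values in `where`, mimicking the behavior
    # of the old BrokenBarHCollection.span_where() classmethod.
    x = np.asarray(x)
    mask = np.concatenate(([False], np.asarray(where, dtype=bool), [False]))
    edges = np.flatnonzero(np.diff(mask.astype(int)))
    verts = []
    for start, stop in zip(edges[::2], edges[1::2]):
        x0, x1 = x[start], x[min(stop, len(x) - 1)]
        verts.append([(x0, ymin), (x0, ymax), (x1, ymax), (x1, ymin)])
    return PolyCollection(verts, **kwargs)

# Hypothetical usage with the names from the diff above:
# ss = (x_v > e_low) & (x_v < e_high)
# barh_new = make_span_collection(x_v, y_min, y_max, ss,
#                                 facecolor="white", edgecolor="yellow", alpha=1)
# axes.add_collection(barh_new)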
84 changes: 68 additions & 16 deletions pyxrf/model/load_data_from_db.py
@@ -2318,20 +2318,24 @@ def map_data2D_srx_new_tiled(

     d_xs, d_xs_sum, N_xs, d_xs2, d_xs2_sum, N_xs2 = None, None, 0, None, None, 0
     if "xs_fluor" in data_stream0:
-        d_xs = data_stream0["xs_fluor"]
+        # The type of the loaded data is tiled.client.array.DaskArrayClient.
+        # If the data is not explicitly converted to da.array, then da.sum
+        # will pull the whole dataset, which is not desirable. This could be
+        # fixed in future versions of Tiled.
+        d_xs = da.array(data_stream0["xs_fluor"])
         d_xs_sum = da.sum(d_xs, 2)
         N_xs = d_xs.shape[2]
     elif "fluor" in data_stream0:  # Old format
-        d_xs = data_stream0["fluor"]
+        d_xs = da.array(data_stream0["fluor"])
         d_xs_sum = da.sum(d_xs, 2)
         N_xs = d_xs.shape[2]

     if "xs_fluor_xs2" in data_stream0:
-        d_xs2 = data_stream0["xs_fluor_xs2"]
+        d_xs2 = da.array(data_stream0["xs_fluor_xs2"])
         d_xs2_sum = da.sum(d_xs2, 2)
         N_xs2 = d_xs2.shape[2]
     elif "fluor_xs2" in data_stream0:  # Old format
-        d_xs2 = data_stream0["fluor_xs2"]
+        d_xs2 = da.array(data_stream0["fluor_xs2"])
         d_xs2_sum = da.sum(d_xs2, 2)
         N_xs2 = d_xs2.shape[2]
@@ -2536,10 +2540,12 @@ def swap_axes():

     # Replace NaNs with 0s (in corrupt data rows).
     loaded_data = {}
-    loaded_data["det_sum"] = np.nan_to_num(tmp_data_sum.compute())
+    # loaded_data["det_sum"] = np.nan_to_num(tmp_data_sum.compute())
+    loaded_data["det_sum"] = tmp_data_sum
     if create_each_det:
         for i in range(num_det):
-            loaded_data["det" + str(i + 1)] = np.nan_to_num(da.squeeze(tmp_data[:, :, i, :]).compute())
+            # loaded_data["det" + str(i + 1)] = np.nan_to_num(da.squeeze(tmp_data[:, :, i, :]).compute())
+            loaded_data["det" + str(i + 1)] = da.squeeze(tmp_data[:, :, i, :])

     if save_scaler:
         loaded_data["scaler_data"] = sclr.compute()
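
With the .compute() calls removed, det_sum and the per-detector arrays stay lazy, and the NaN cleanup is deferred to save time (see download_dataset in the diff below, which applies np.nan_to_num per batch). The same cleanup could also be expressed lazily on the dask side; a small sketch, not pyxrf code:

import dask.array as da
import numpy as np

arr = da.from_array(np.array([[1.0, np.nan], [np.nan, 4.0]]), chunks=2)
cleaned = arr.map_blocks(np.nan_to_num)  # NaN -> 0, recorded lazily per block
print(cleaned.compute())                 # [[1. 0.] [0. 4.]]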
@@ -3519,6 +3525,8 @@ def save_data_to_hdf5(
         Failed to write data to HDF5 file.
     """

+    time_start = ttime.time()
+
     fpath = os.path.expanduser(fpath)
     fpath = os.path.abspath(fpath)
@@ -3539,22 +3547,23 @@ def incorrect_type_msg(channel, data_type):
             f"The data is converted from '{data_type}' to 'np.float32' before saving to file."
         )

-    if "det_sum" in data and isinstance(data["det_sum"], np.ndarray):
+    if "det_sum" in data and isinstance(data["det_sum"], (np.ndarray, da.core.Array)):
         if data["det_sum"].dtype != np.float32:
             incorrect_type_msg("det_sum", data["det_sum"].dtype)
             data["det_sum"] = data["det_sum"].astype(np.float32, copy=False)
         sum_data = data["det_sum"]
         sum_data_exists = True

     for detname in xrf_det_list:
-        if detname in data and isinstance(data[detname], np.ndarray):
+        if detname in data and isinstance(data[detname], (np.ndarray, da.core.Array)):
             if data[detname].dtype != np.float32:
                 incorrect_type_msg(detname, data[detname].dtype)
                 data[detname] = data[detname].astype(np.float32, copy=False)

             if not sum_data_exists:  # Don't compute it if it already exists
                 if sum_data is None:
-                    sum_data = np.copy(data[detname])
+                    # sum_data = np.copy(data[detname])
+                    sum_data = data[detname].copy()
                 else:
                     sum_data += data[detname]
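
The widened isinstance checks work because the operations used here (astype, .copy(), +=) exist on both numpy and dask arrays; with dask inputs the detector sum is accumulated lazily. A self-contained sketch of the duck typing (as_float32 is an illustrative name, not a pyxrf helper):

import dask.array as da
import numpy as np

def as_float32(arr):
    # Illustrative helper: the same code path handles numpy and dask
    # arrays; for dask, the cast is deferred until compute time.
    if isinstance(arr, (np.ndarray, da.core.Array)) and arr.dtype != np.float32:
        arr = arr.astype(np.float32, copy=False)
    return arr

print(as_float32(np.ones(3, dtype=np.float64)).dtype)            # float32
print(as_float32(da.ones(3, dtype=np.float64, chunks=3)).dtype)  # float32 (lazy)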

@@ -3598,18 +3607,59 @@ def incorrect_type_msg(channel, data_type):
         for key, value in metadata_prepared.items():
             metadata_grp.attrs[key] = value

+    # The following parameters control how data is loaded from Tiled
+    n_pixels_in_batch = 40000
+    n_tiled_download_retries = 10
+
+    def compute_batch_params(shape):
+        n_rows, n_cols = shape[0], shape[1]
+        n_rows_batch = max(int(n_pixels_in_batch / n_cols), 1)  # Save at least one row
+        n_batches = int(n_rows / n_rows_batch)
+        if n_rows % n_rows_batch:
+            n_batches += 1
+        return n_rows, n_rows_batch, n_batches
+
+    def download_dataset(dset, data):
+        n_rows, n_rows_batch, n_batches = compute_batch_params(sum_data.shape)
+        for n in range(n_batches):
+            ns, ne = n * n_rows_batch, min((n + 1) * n_rows_batch, n_rows)
+            for retry in range(n_tiled_download_retries):
+                try:
+                    dset[ns:ne, ...] = np.nan_to_num(np.array(data[ns:ne, ...]))
+                    break
+                except Exception as ex:
+                    logger.error(f"Failed to load the batch: {ex}")
+                    if retry >= n_tiled_download_retries - 1:
+                        raise TimeoutError("Failed to download data from Tiled server")
+            print(f" Number of saved rows: {ne}")
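
To make the batching concrete: with n_pixels_in_batch = 40000, a map with 170 columns is saved 235 rows at a time (40000 // 170 = 235), so a 1000-row scan takes 5 batches, each downloaded with up to 10 retries. A standalone restatement of the arithmetic (illustrative shape):

def compute_batch_params(shape, n_pixels_in_batch=40000):
    # Standalone copy of the helper above, for checking the arithmetic.
    n_rows, n_cols = shape[0], shape[1]
    n_rows_batch = max(n_pixels_in_batch // n_cols, 1)  # at least one row per batch
    n_batches = n_rows // n_rows_batch + (1 if n_rows % n_rows_batch else 0)
    return n_rows, n_rows_batch, n_batches

print(compute_batch_params((1000, 170, 4096)))  # (1000, 235, 5)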

     if create_each_det is True:
         for detname in xrf_det_list:
-            new_data = data[detname]
-            dataGrp = f.create_group(interpath + "/" + detname)
-            ds_data = dataGrp.create_dataset("counts", data=new_data, compression="gzip")
-            ds_data.attrs["comments"] = "Experimental data from {}".format(detname)
+            if not isinstance(sum_data, da.core.Array):
+                new_data = data[detname]
+                dataGrp = f.create_group(interpath + "/" + detname)
+                ds_data = dataGrp.create_dataset("counts", data=new_data, compression="gzip")
+                ds_data.attrs["comments"] = "Experimental data from {}".format(detname)
+            else:
+                new_data = data[detname]
+                dataGrp = f.create_group(interpath + "/" + detname)
+                ds_data = dataGrp.create_dataset("counts", new_data.shape, compression="gzip")
+                print(f"Downloading data: channel {detname!r} ...")
+                download_dataset(ds_data, new_data)
+                ds_data.attrs["comments"] = "Experimental data from {}".format(detname)

     # summed data
     if sum_data is not None:
-        dataGrp = f.create_group(interpath + "/detsum")
-        ds_data = dataGrp.create_dataset("counts", data=sum_data, compression="gzip")
-        ds_data.attrs["comments"] = "Experimental data from channel sum"
+        if not isinstance(sum_data, da.core.Array):
+            dataGrp = f.create_group(interpath + "/detsum")
+            ds_data = dataGrp.create_dataset("counts", data=sum_data, compression="gzip")
+            ds_data.attrs["comments"] = "Experimental data from channel sum"
+        else:
+            dataGrp = f.create_group(interpath + "/detsum")
+            ds_data = dataGrp.create_dataset("counts", sum_data.shape, compression="gzip")
+            print("Downloading data: the sum of all channels ...")
+            download_dataset(ds_data, sum_data)
+            ds_data.attrs["comments"] = "Experimental data from channel sum"
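
In the dask branches above, the HDF5 dataset is allocated by shape only and then filled slice by slice, so the full array never has to be resident in memory at once. A self-contained sketch of that pattern (file name, shapes, and batch size are illustrative):

import dask.array as da
import h5py
import numpy as np

# Illustrative stand-in for the lazily summed detector data.
arr = da.random.random((100, 50, 4096), chunks=(25, 50, 4096))

with h5py.File("demo_detsum.h5", "w") as f:
    grp = f.create_group("xrfmap/detsum")
    # Allocate by shape only -- no data is passed, so nothing is computed yet.
    dset = grp.create_dataset("counts", arr.shape, dtype=np.float32, compression="gzip")
    for ns in range(0, arr.shape[0], 25):  # write 25 rows per batch
        ne = min(ns + 25, arr.shape[0])
        dset[ns:ne, ...] = np.nan_to_num(np.asarray(arr[ns:ne, ...]))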

     # add positions
     if "pos_names" in data:
@@ -3627,6 +3677,8 @@ def incorrect_type_msg(channel, data_type):
         dataGrp.create_dataset("name", data=helper_encode_list(scaler_names))
         dataGrp.create_dataset("val", data=scaler_data)

+    logger.info(f"Total data saving time: {ttime.time() - time_start}")
+
     return fpath

