Merge pull request #321 from dmgav/lazy-loading
'Lazy' loading of Dask arrays from Tiled (SRX)
dmgav authored Jun 9, 2024
2 parents 4419b9b + dcc11c4 commit 0467cb9
Showing 2 changed files with 73 additions and 21 deletions.
10 changes: 5 additions & 5 deletions pyxrf/model/lineplot.py
@@ -10,7 +10,7 @@
 import numpy as np
 from atom.api import Atom, Bool, Dict, Float, Int, List, Str, Typed, observe
 from matplotlib.axes import Axes
-from matplotlib.collections import BrokenBarHCollection
+from matplotlib.collections import PolyCollection
 from matplotlib.colors import LogNorm
 from matplotlib.figure import Figure
 from matplotlib.lines import Line2D
@@ -122,7 +122,7 @@ class LinePlotModel(Atom):
     _fig_preview = Typed(Figure)
     _ax_preview = Typed(Axes)
     _lines_preview = List()
-    _bahr_preview = Typed(BrokenBarHCollection)
+    _bahr_preview = Typed(PolyCollection)

     plot_type_preview = Typed(PlotTypes)
     energy_range_preview = Typed(EnergyRangePresets)
@@ -177,7 +177,7 @@ class LinePlotModel(Atom):
     show_exp_opt = Bool(False)  # Flag: show spectrum preview

     # Reference to artist responsible for displaying the selected range of energies on the plot
-    plot_energy_barh = Typed(BrokenBarHCollection)
+    plot_energy_barh = Typed(PolyCollection)
     t_bar = Typed(object)

     plot_exp_list = List()
@@ -727,7 +727,7 @@ def plot_selected_energy_range_original(self, *, e_low=None, e_high=None):
             self.plot_energy_barh.remove()

         # Create the new plot (based on new parameters if necessary)
-        self.plot_energy_barh = BrokenBarHCollection.span_where(
+        self.plot_energy_barh = PolyCollection.span_where(
             x_v, ymin=y_min, ymax=y_max, where=ss, facecolor="white", edgecolor="yellow", alpha=1
         )
         self._ax.add_collection(self.plot_energy_barh)
@@ -1471,7 +1471,7 @@ def plot_selected_energy_range(self, *, axes, barh_existing, e_low=None, e_high=
             barh_existing.remove()

         # Create the new plot (based on new parameters if necessary)
-        barh_new = BrokenBarHCollection.span_where(
+        barh_new = PolyCollection.span_where(
             x_v, ymin=y_min, ymax=y_max, where=ss, facecolor="white", edgecolor="yellow", alpha=1
         )
         axes.add_collection(barh_new)
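
Background on this file's change: BrokenBarHCollection (and its span_where helper) was deprecated in Matplotlib 3.7 and has been removed in recent releases. On Matplotlib versions where PolyCollection does not provide a span_where method, an equivalent highlighted-span artist can be produced with Axes.fill_between, which returns a PolyCollection. A minimal sketch with made-up energy values mirroring the calls above:

import numpy as np
import matplotlib.pyplot as plt

x_v = np.linspace(0.0, 20.0, 500)    # energy axis in keV (made-up values)
e_low, e_high = 4.0, 10.0            # hypothetical selected energy range
ss = (x_v > e_low) & (x_v < e_high)  # mask for the selected span
y_min, y_max = 0.0, 1.0

fig, ax = plt.subplots()
barh = ax.fill_between(
    x_v, y_min, y_max, where=ss, facecolor="white", edgecolor="yellow", alpha=1
)
# fill_between attaches the artist to the axes itself, so no add_collection()
# call is needed; barh.remove() detaches it, matching the remove() calls above.
plt.show()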
84 changes: 68 additions & 16 deletions pyxrf/model/load_data_from_db.py
@@ -2318,20 +2318,24 @@ def map_data2D_srx_new_tiled(

     d_xs, d_xs_sum, N_xs, d_xs2, d_xs2_sum, N_xs2 = None, None, 0, None, None, 0
     if "xs_fluor" in data_stream0:
-        d_xs = data_stream0["xs_fluor"]
+        # The type of the loaded data is tiled.client.array.DaskArrayClient.
+        # If the data is not explicitly converted to da.array, then da.sum
+        # will pull the whole dataset, which is not desirable. This could be
+        # fixed in future versions of Tiled.
+        d_xs = da.array(data_stream0["xs_fluor"])
         d_xs_sum = da.sum(d_xs, 2)
         N_xs = d_xs.shape[2]
     elif "fluor" in data_stream0:  # Old format
-        d_xs = data_stream0["fluor"]
+        d_xs = da.array(data_stream0["fluor"])
         d_xs_sum = da.sum(d_xs, 2)
         N_xs = d_xs.shape[2]

     if "xs_fluor_xs2" in data_stream0:
-        d_xs2 = data_stream0["xs_fluor_xs2"]
+        d_xs2 = da.array(data_stream0["xs_fluor_xs2"])
         d_xs2_sum = da.sum(d_xs2, 2)
         N_xs2 = d_xs2.shape[2]
     elif "fluor_xs2" in data_stream0:  # Old format
-        d_xs2 = data_stream0["fluor_xs2"]
+        d_xs2 = da.array(data_stream0["fluor_xs2"])
         d_xs2_sum = da.sum(d_xs2, 2)
         N_xs2 = d_xs2.shape[2]
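
The da.array() conversion above is the core of the lazy-loading change. A stand-in sketch of the pattern, with an in-memory NumPy array playing the role of the remote tiled.client.array.DaskArrayClient (any array-like with .shape, .dtype and NumPy-style slicing can be wrapped the same way):

import dask.array as da
import numpy as np

# Stand-in for data_stream0["xs_fluor"]; the (rows, cols, energy_bins) shape is made up.
raw = np.random.rand(8, 10, 4096).astype(np.float32)

d_xs = da.array(raw)             # wrap as a dask array
d_xs_sum = da.sum(d_xs, 2)       # lazy: builds a task graph, moves no data
N_xs = d_xs.shape[2]             # shape metadata is available without computing
print(d_xs_sum.compute().shape)  # (8, 10); with Tiled, chunks are pulled only here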

@@ -2536,10 +2540,12 @@ def swap_axes():

     # Replace NaNs with 0s (in corrupt data rows).
     loaded_data = {}
-    loaded_data["det_sum"] = np.nan_to_num(tmp_data_sum.compute())
+    # loaded_data["det_sum"] = np.nan_to_num(tmp_data_sum.compute())
+    loaded_data["det_sum"] = tmp_data_sum
     if create_each_det:
         for i in range(num_det):
-            loaded_data["det" + str(i + 1)] = np.nan_to_num(da.squeeze(tmp_data[:, :, i, :]).compute())
+            # loaded_data["det" + str(i + 1)] = np.nan_to_num(da.squeeze(tmp_data[:, :, i, :]).compute())
+            loaded_data["det" + str(i + 1)] = da.squeeze(tmp_data[:, :, i, :])

     if save_scaler:
         loaded_data["scaler_data"] = sclr.compute()
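
With both .compute() calls commented out, loaded_data["det_sum"] and the per-detector entries now hold lazy dask arrays, so the NaN cleanup mentioned in the context comment is deferred until the data is written to disk (the per-batch np.nan_to_num in save_data_to_hdf5 below). A small sketch of the deferred cleanup, using a made-up 2x2 array:

import dask.array as da
import numpy as np

tmp_data_sum = da.from_array(np.array([[1.0, np.nan], [3.0, 4.0]]), chunks=1)
det_sum = tmp_data_sum                         # stays lazy at load time
batch = np.nan_to_num(np.array(det_sum[0:1]))  # cleaned when a row batch is saved
print(batch)                                   # [[1. 0.]]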
@@ -3519,6 +3525,8 @@ def save_data_to_hdf5(
         Failed to write data to HDF5 file.
     """

+    time_start = ttime.time()
+
     fpath = os.path.expanduser(fpath)
     fpath = os.path.abspath(fpath)

Expand All @@ -3539,22 +3547,23 @@ def incorrect_type_msg(channel, data_type):
f"The data is converted from '{data_type}' to 'np.float32' before saving to file."
)

if "det_sum" in data and isinstance(data["det_sum"], np.ndarray):
if "det_sum" in data and isinstance(data["det_sum"], (np.ndarray, da.core.Array)):
if data["det_sum"].dtype != np.float32:
incorrect_type_msg("det_sum", data["det_sum"].dtype)
data["det_sum"] = data["det_sum"].astype(np.float32, copy=False)
sum_data = data["det_sum"]
sum_data_exists = True

for detname in xrf_det_list:
if detname in data and isinstance(data[detname], np.ndarray):
if detname in data and isinstance(data[detname], (np.ndarray, da.core.Array)):
if data[detname].dtype != np.float32:
incorrect_type_msg(detname, data[detname].dtype)
data[detname] = data[detname].astype(np.float32, copy=False)

if not sum_data_exists: # Don't compute it if it already exists
if sum_data is None:
sum_data = np.copy(data[detname])
# sum_data = np.copy(data[detname])
sum_data = data[detname].copy()
else:
sum_data += data[detname]
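
The widened isinstance checks let either an in-memory NumPy channel or a lazy dask channel flow through the same dtype-conversion path; for a dask array, astype(np.float32, copy=False) and .copy() only add nodes to the task graph. A sketch with one stand-in channel of each kind (shapes are made up):

import dask.array as da
import numpy as np

channels = {
    "det1": np.ones((3, 4, 5), dtype=np.float64),            # in-memory channel
    "det2": da.ones((3, 4, 5), dtype=np.float64, chunks=2),  # lazy channel
}
for name, arr in channels.items():
    if isinstance(arr, (np.ndarray, da.core.Array)):
        if arr.dtype != np.float32:
            arr = arr.astype(np.float32, copy=False)  # lazy for the dask array
    print(name, type(arr).__name__, arr.dtype)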

@@ -3598,18 +3607,59 @@ def incorrect_type_msg(channel, data_type):
         for key, value in metadata_prepared.items():
             metadata_grp.attrs[key] = value

+        # The following parameters control how data is loaded from Tiled
+        n_pixels_in_batch = 40000
+        n_tiled_download_retries = 10
+
+        def compute_batch_params(shape):
+            n_rows, n_cols = shape[0], shape[1]
+            n_rows_batch = max(int(n_pixels_in_batch / n_cols), 1)  # Save at least one row
+            n_batches = int(n_rows / n_rows_batch)
+            if n_rows % n_rows_batch:
+                n_batches += 1
+            return n_rows, n_rows_batch, n_batches
+
+        def download_dataset(dset, data):
+            n_rows, n_rows_batch, n_batches = compute_batch_params(sum_data.shape)
+            for n in range(n_batches):
+                ns, ne = n * n_rows_batch, min((n + 1) * n_rows_batch, n_rows)
+                for retry in range(n_tiled_download_retries):
+                    try:
+                        dset[ns:ne, ...] = np.nan_to_num(np.array(data[ns:ne, ...]))
+                        break
+                    except Exception as ex:
+                        logger.error(f"Failed to load the batch: {ex}")
+                        if retry >= n_tiled_download_retries - 1:
+                            raise TimeoutError("Failed to download data from Tiled server")
+                print(f" Number of saved rows: {ne}")
+
         if create_each_det is True:
             for detname in xrf_det_list:
-                new_data = data[detname]
-                dataGrp = f.create_group(interpath + "/" + detname)
-                ds_data = dataGrp.create_dataset("counts", data=new_data, compression="gzip")
-                ds_data.attrs["comments"] = "Experimental data from {}".format(detname)
+                if not isinstance(sum_data, da.core.Array):
+                    new_data = data[detname]
+                    dataGrp = f.create_group(interpath + "/" + detname)
+                    ds_data = dataGrp.create_dataset("counts", data=new_data, compression="gzip")
+                    ds_data.attrs["comments"] = "Experimental data from {}".format(detname)
+                else:
+                    new_data = data[detname]
+                    dataGrp = f.create_group(interpath + "/" + detname)
+                    ds_data = dataGrp.create_dataset("counts", new_data.shape, compression="gzip")
+                    print(f"Downloading data: channel {detname!r} ...")
+                    download_dataset(ds_data, new_data)
+                    ds_data.attrs["comments"] = "Experimental data from {}".format(detname)

         # summed data
         if sum_data is not None:
-            dataGrp = f.create_group(interpath + "/detsum")
-            ds_data = dataGrp.create_dataset("counts", data=sum_data, compression="gzip")
-            ds_data.attrs["comments"] = "Experimental data from channel sum"
+            if not isinstance(sum_data, da.core.Array):
+                dataGrp = f.create_group(interpath + "/detsum")
+                ds_data = dataGrp.create_dataset("counts", data=sum_data, compression="gzip")
+                ds_data.attrs["comments"] = "Experimental data from channel sum"
+            else:
+                dataGrp = f.create_group(interpath + "/detsum")
+                ds_data = dataGrp.create_dataset("counts", sum_data.shape, compression="gzip")
+                print("Downloading data: the sum of all channels ...")
+                download_dataset(ds_data, sum_data)
+                ds_data.attrs["comments"] = "Experimental data from channel sum"

         # add positions
         if "pos_names" in data:
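
download_dataset writes each HDF5 dataset in batches of roughly n_pixels_in_batch pixels (whole rows at a time), retrying each batch up to n_tiled_download_retries times before raising. Note that compute_batch_params is called with sum_data.shape even when an individual channel is downloaded; this works because all channels share the same (rows, cols) map dimensions. A worked example of the batch geometry, with a made-up map size:

n_pixels_in_batch = 40000
n_rows, n_cols = 120, 340  # hypothetical 120 x 340 map

n_rows_batch = max(int(n_pixels_in_batch / n_cols), 1)                # 117 rows per batch
n_batches = int(n_rows / n_rows_batch) + bool(n_rows % n_rows_batch)  # 2 batches
print(n_rows_batch, n_batches)  # 117 2 -> rows 0-116, then rows 117-119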
@@ -3627,6 +3677,8 @@ def incorrect_type_msg(channel, data_type):
             dataGrp.create_dataset("name", data=helper_encode_list(scaler_names))
             dataGrp.create_dataset("val", data=scaler_data)

+    logger.info(f"Total data saving time: {ttime.time() - time_start}")
+
     return fpath


