Optimized memory use for lazy loading of dask arrays from Tiled.
dmgav committed Jun 8, 2024
1 parent 4419b9b commit c93f542
Showing 1 changed file with 68 additions and 16 deletions.

pyxrf/model/load_data_from_db.py
@@ -2318,20 +2318,24 @@ def map_data2D_srx_new_tiled(
 
     d_xs, d_xs_sum, N_xs, d_xs2, d_xs2_sum, N_xs2 = None, None, 0, None, None, 0
     if "xs_fluor" in data_stream0:
-        d_xs = data_stream0["xs_fluor"]
+        # The type of the loaded data is tiled.client.array.DaskArrayClient.
+        # If the data is not explicitly converted to da.array, then da.sum
+        # will load the whole dataset, which is not desirable. This could be
+        # fixed in future versions of Tiled.
+        d_xs = da.array(data_stream0["xs_fluor"])
         d_xs_sum = da.sum(d_xs, 2)
         N_xs = d_xs.shape[2]
     elif "fluor" in data_stream0:  # Old format
-        d_xs = data_stream0["fluor"]
+        d_xs = da.array(data_stream0["fluor"])
         d_xs_sum = da.sum(d_xs, 2)
         N_xs = d_xs.shape[2]
 
     if "xs_fluor_xs2" in data_stream0:
-        d_xs2 = data_stream0["xs_fluor_xs2"]
+        d_xs2 = da.array(data_stream0["xs_fluor_xs2"])
         d_xs2_sum = da.sum(d_xs2, 2)
         N_xs2 = d_xs2.shape[2]
     elif "fluor_xs2" in data_stream0:  # Old format
-        d_xs2 = data_stream0["fluor_xs2"]
+        d_xs2 = da.array(data_stream0["fluor_xs2"])
         d_xs2_sum = da.sum(d_xs2, 2)
         N_xs2 = d_xs2.shape[2]
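The da.array wrapper above is the heart of the change: it keeps the fluorescence stack lazy, so da.sum only builds a task graph instead of pulling the full dataset from the Tiled server. A minimal self-contained sketch of the pattern; the stand-in array and function name are illustrative, not part of pyxrf:

```python
import numpy as np
import dask.array as da

def lazy_fluor_sum(client_array):
    # Wrap the (lazy) Tiled array client in a dask array; da.sum then
    # builds a task graph instead of downloading the whole dataset.
    d = da.array(client_array)
    d_sum = da.sum(d, 2)  # reduce over axis 2, the detector-element axis
    return d, d_sum, d.shape[2]

# Stand-in for a Tiled array client: (rows, cols, n_det, spectrum).
stack = np.zeros((5, 7, 4, 4096), dtype=np.float32)
d, d_sum, n_det = lazy_fluor_sum(da.from_array(stack, chunks=(5, 7, 1, 4096)))
print(d_sum.shape, n_det)  # (5, 7, 4096) 4 -- nothing computed yet
```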

@@ -2536,10 +2540,12 @@ def swap_axes():
 
     # Replace NaNs with 0s (in corrupt data rows).
     loaded_data = {}
-    loaded_data["det_sum"] = np.nan_to_num(tmp_data_sum.compute())
+    # loaded_data["det_sum"] = np.nan_to_num(tmp_data_sum.compute())
+    loaded_data["det_sum"] = tmp_data_sum
    if create_each_det:
        for i in range(num_det):
-            loaded_data["det" + str(i + 1)] = np.nan_to_num(da.squeeze(tmp_data[:, :, i, :]).compute())
+            # loaded_data["det" + str(i + 1)] = np.nan_to_num(da.squeeze(tmp_data[:, :, i, :]).compute())
+            loaded_data["det" + str(i + 1)] = da.squeeze(tmp_data[:, :, i, :])
 
     if save_scaler:
         loaded_data["scaler_data"] = sclr.compute()
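Returning the dask objects directly defers both the download and the reduction until the data is actually saved. The eager NaN cleanup is dropped along with the compute() calls; if it is still needed, it can be expressed lazily as well. A small sketch with a stand-in array, using da.nan_to_num from dask's numpy-compatible API:

```python
import numpy as np
import dask.array as da

# Stand-in for the lazy per-scan sum built above.
tmp_data_sum = da.from_array(np.array([[1.0, np.nan], [3.0, 4.0]]), chunks=2)

det_sum_lazy = da.nan_to_num(tmp_data_sum)  # lazy elementwise NaN -> 0
print(det_sum_lazy.compute())               # [[1. 0.] [3. 4.]] -- computed only here
```
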
@@ -3519,6 +3525,8 @@ def save_data_to_hdf5(
         Failed to write data to HDF5 file.
     """
 
+    time_start = ttime.time()
+
     fpath = os.path.expanduser(fpath)
     fpath = os.path.abspath(fpath)
 
@@ -3539,22 +3547,23 @@ def incorrect_type_msg(channel, data_type):
             f"The data is converted from '{data_type}' to 'np.float32' before saving to file."
         )
 
-    if "det_sum" in data and isinstance(data["det_sum"], np.ndarray):
+    if "det_sum" in data and isinstance(data["det_sum"], (np.ndarray, da.core.Array)):
         if data["det_sum"].dtype != np.float32:
             incorrect_type_msg("det_sum", data["det_sum"].dtype)
             data["det_sum"] = data["det_sum"].astype(np.float32, copy=False)
         sum_data = data["det_sum"]
         sum_data_exists = True
 
     for detname in xrf_det_list:
-        if detname in data and isinstance(data[detname], np.ndarray):
+        if detname in data and isinstance(data[detname], (np.ndarray, da.core.Array)):
             if data[detname].dtype != np.float32:
                 incorrect_type_msg(detname, data[detname].dtype)
                 data[detname] = data[detname].astype(np.float32, copy=False)
 
             if not sum_data_exists:  # Don't compute it if it already exists
                 if sum_data is None:
-                    sum_data = np.copy(data[detname])
+                    # sum_data = np.copy(data[detname])
+                    sum_data = data[detname].copy()
                 else:
                     sum_data += data[detname]
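Widening the isinstance checks to (np.ndarray, da.core.Array) works because numpy and dask arrays share the dtype/astype API; likewise data[detname].copy() replaces np.copy(...), which would have materialized a dask array. A compact sketch of the shared-API idea; the helper name is illustrative:

```python
import numpy as np
import dask.array as da

def ensure_float32(arr):
    # Works for np.ndarray and da.core.Array alike: for numpy, copy=False
    # may avoid a copy; for dask, it only appends a lazy cast node.
    if isinstance(arr, (np.ndarray, da.core.Array)) and arr.dtype != np.float32:
        arr = arr.astype(np.float32, copy=False)
    return arr

print(ensure_float32(np.ones(3, dtype=np.float64)).dtype)  # float32
print(ensure_float32(da.ones(3, dtype=np.float64)).dtype)  # float32, still lazy
```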

@@ -3598,18 +3607,59 @@ def incorrect_type_msg(channel, data_type):
             for key, value in metadata_prepared.items():
                 metadata_grp.attrs[key] = value
 
+        # The following parameters control how data is loaded from Tiled.
+        n_pixels_in_batch = 40000
+        n_tiled_download_retries = 10
+
+        def compute_batch_params(shape):
+            n_rows, n_cols = shape[0], shape[1]
+            n_rows_batch = max(int(n_pixels_in_batch / n_cols), 1)  # Save at least one row
+            n_batches = int(n_rows / n_rows_batch)
+            if n_rows % n_rows_batch:
+                n_batches += 1
+            return n_rows, n_rows_batch, n_batches
+
+        def download_dataset(dset, data):
+            n_rows, n_rows_batch, n_batches = compute_batch_params(sum_data.shape)
+            for n in range(n_batches):
+                ns, ne = n * n_rows_batch, min((n + 1) * n_rows_batch, n_rows)
+                for retry in range(n_tiled_download_retries):
+                    try:
+                        dset[ns:ne, ...] = np.array(data[ns:ne, ...])
+                        break
+                    except Exception as ex:
+                        logger.error(f"Failed to load the batch: {ex}")
+                        if retry >= n_tiled_download_retries - 1:
+                            raise TimeoutError("Failed to download data from Tiled server")
+                print(f" Number of saved rows: {ne}")
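To make the batching arithmetic concrete, here is a worked example for a hypothetical map of 120 rows by 350 columns (numbers illustrative). Note that download_dataset takes the row count from sum_data.shape, whose leading (rows, cols) dimensions are the same for every detector channel:

```python
n_pixels_in_batch = 40000
n_rows, n_cols = 120, 350                                # hypothetical map size
n_rows_batch = max(int(n_pixels_in_batch / n_cols), 1)   # 40000 / 350 -> 114
n_batches = int(n_rows / n_rows_batch)                   # 120 / 114 -> 1
if n_rows % n_rows_batch:                                # 6 leftover rows
    n_batches += 1                                       # -> 2 batches
# Batch 0 covers rows 0..113, batch 1 covers rows 114..119.
```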

         if create_each_det is True:
             for detname in xrf_det_list:
-                new_data = data[detname]
-                dataGrp = f.create_group(interpath + "/" + detname)
-                ds_data = dataGrp.create_dataset("counts", data=new_data, compression="gzip")
-                ds_data.attrs["comments"] = "Experimental data from {}".format(detname)
+                if not isinstance(sum_data, da.core.Array):
+                    new_data = data[detname]
+                    dataGrp = f.create_group(interpath + "/" + detname)
+                    ds_data = dataGrp.create_dataset("counts", data=new_data, compression="gzip")
+                    ds_data.attrs["comments"] = "Experimental data from {}".format(detname)
+                else:
+                    new_data = data[detname]
+                    dataGrp = f.create_group(interpath + "/" + detname)
+                    ds_data = dataGrp.create_dataset("counts", new_data.shape, compression="gzip")
+                    print(f"Downloading data: channel {detname!r} ...")
+                    download_dataset(ds_data, new_data)
+                    ds_data.attrs["comments"] = "Experimental data from {}".format(detname)
 
         # summed data
         if sum_data is not None:
-            dataGrp = f.create_group(interpath + "/detsum")
-            ds_data = dataGrp.create_dataset("counts", data=sum_data, compression="gzip")
-            ds_data.attrs["comments"] = "Experimental data from channel sum"
+            if not isinstance(sum_data, da.core.Array):
+                dataGrp = f.create_group(interpath + "/detsum")
+                ds_data = dataGrp.create_dataset("counts", data=sum_data, compression="gzip")
+                ds_data.attrs["comments"] = "Experimental data from channel sum"
+            else:
+                dataGrp = f.create_group(interpath + "/detsum")
+                ds_data = dataGrp.create_dataset("counts", sum_data.shape, compression="gzip")
+                print("Downloading data: the sum of all channels ...")
+                download_dataset(ds_data, sum_data)
+                ds_data.attrs["comments"] = "Experimental data from channel sum"
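The else branches above implement a streaming save: the HDF5 dataset is created from a shape only, then filled slab by slab, so peak memory stays near one batch rather than the whole map. A self-contained sketch of the same pattern, with a hypothetical lazy array and file name:

```python
import h5py
import numpy as np
import dask.array as da

# Hypothetical lazy XRF stack: (rows, cols, spectrum).
arr = da.random.random((100, 1000, 256), chunks=(20, 1000, 256)).astype(np.float32)

with h5py.File("example.h5", "w") as f:
    dset = f.create_dataset("counts", arr.shape, dtype=arr.dtype, compression="gzip")
    rows_per_batch = max(40000 // arr.shape[1], 1)  # 40 rows per ~40000-pixel slab
    for ns in range(0, arr.shape[0], rows_per_batch):
        ne = min(ns + rows_per_batch, arr.shape[0])
        # Materialize only this slab; nothing else is held in memory.
        dset[ns:ne, ...] = np.array(arr[ns:ne, ...])
```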

         # add positions
         if "pos_names" in data:
@@ -3627,6 +3677,8 @@ def incorrect_type_msg(channel, data_type):
             dataGrp.create_dataset("name", data=helper_encode_list(scaler_names))
             dataGrp.create_dataset("val", data=scaler_data)
 
+    logger.info(f"Total data saving time: {ttime.time() - time_start}")
+
     return fpath

