diff --git a/scale_aware_air_sea/utils.py b/scale_aware_air_sea/utils.py index b22a243..15f2049 100644 --- a/scale_aware_air_sea/utils.py +++ b/scale_aware_air_sea/utils.py @@ -16,19 +16,27 @@ def open_zarr(mapper, chunks={}): inline_array=True ) -def maybe_save_and_reload(ds, path, overwrite=False, fs=None): +def maybe_save_and_reload(ds, path, overwrite=False, fs=None, to_zarr_kwargs={}, open_dataset_kwargs={}): if fs is None: fs = gcsfs.GCSFileSystem() + + open_dataset_kwargs.setdefault('engine','zarr') + open_dataset_kwargs.setdefault('chunks',{}) + + if overwrite + to_zarr_kwargs.setdefault('mode','w') if not fs.exists(path): print(f'Saving the dataset to zarr at {path}') - ds.to_zarr(path) + ds.to_zarr(path, **to_zarr_kwargs) elif fs.exists(path) and overwrite: + print(f'Overwriting dataset at {path}') - ds.to_zarr(path, mode='w') + ds.to_zarr(path, **to_zarr_kwargs) print(f"Reload dataset from {path}") - ds_reloaded = xr.open_dataset(path, engine='zarr', chunks={}) + + ds_reloaded = xr.open_dataset(path, **open_dataset_kwargs) return ds_reloaded def filter_inputs(