diff --git a/pydatalab/pydatalab/apps/mri/blocks.py b/pydatalab/pydatalab/apps/mri/blocks.py index 69c2b6abf..8ec11d511 100644 --- a/pydatalab/pydatalab/apps/mri/blocks.py +++ b/pydatalab/pydatalab/apps/mri/blocks.py @@ -3,14 +3,15 @@ import pandas as pd from pydatalab.blocks.base import DataBlock -from pydatalab.bokeh_plots import mytheme +from pydatalab.logger import LOGGER +from pydatalab.bokeh_plots import DATALAB_BOKEH_THEME from pydatalab.file_utils import get_file_info_by_id class MRIBlock(DataBlock): blocktype = "mri" description = "In situ MRI" - accepted_file_extensions = (".csv", "2dseq") + accepted_file_extensions = (".csv", "2dseq", "visu_pars") @property def plot_functions(self): @@ -18,10 +19,10 @@ def plot_functions(self): @classmethod def load_2dseq( - self, + cls, location: str, image_size: tuple[int, int] = (512, 512), - ) -> pd.DataFrame: + ) -> list[np.ndarray]: if not isinstance(location, str): location = str(location) @@ -38,20 +39,28 @@ def load_2dseq( for i in range(num_images): # grab an image_size square slice from arrays image_arrays.append( - arr[i * image_pixels : (i + 1) * image_pixels].reshape(*image_size) + arr[i * image_pixels : (i + 1) * image_pixels].reshape(*image_size).copy() ) return image_arrays - def generate_mri_plots(self): + def generate_mri_plot(self): + """Generate image plots of MRI data.""" + from bokeh.layouts import column + from bokeh.models import ColumnDataSource, CustomJS, Slider, ColorBar, LinearColorMapper + from bokeh.plotting import figure + + if "file_id" not in self.data: + return None file_info = get_file_info_by_id(self.data["file_id"], update_if_live=True) image_array = self.load_2dseq( file_info["location"], ) - from bokeh.layout import column - from bokeh.models import ColumnDataSource, CustomJS, Slider - from bokeh.plotting import figure + + if len(image_array) == 0: + raise RuntimeError(f"Could not find any images in {file_info['location']}") + return None p = figure( sizing_mode="scale_width", @@ -61,6 +70,10 @@ def generate_mri_plots(self): ) image_source = ColumnDataSource(data={"image": [image_array[0]]}) + + cmap = bokeh.palettes.Viridis256 + color_mapper = bokeh.models.LogColorMapper(palette=cmap, low=1, high=np.max(image_array)) + p.image( image="image", source=image_source, @@ -68,22 +81,35 @@ def generate_mri_plots(self): y=0, dw=10, dh=10, - palette="Sunset11", - level="image", + color_mapper=color_mapper, ) - slider = Slider(start=0, end=len(image_array), step=1, value=0, title="Select image") + p.axis.visible = False + p.xgrid.visible = False + p.ygrid.visible = False + + # Set limits to edge of images defined by 10x10 + p.x_range.start = 0 + p.x_range.end = 10 + p.y_range.start = 0 + p.y_range.end = 10 + + color_bar = ColorBar(color_mapper=color_mapper) + + slider = Slider(start=0, end=len(image_array) - 1, step=1, value=0, title=f"Select image ({len(image_array)} images)") slider_callback = CustomJS( - args=dict(image_source=image_source, slider=slider), + args=dict(image_source=image_source, image_array=image_array, slider=slider), code=""" var selected_image_index = slider.value; -image_source.data["image"] = [image_arrays[selected_image_index]]; +image_source.data["image"] = [image_array[selected_image_index]]; image_source.change.emit(); """, ) slider.js_on_change("value", slider_callback) + p.add_layout(color_bar, "right") + layout = column(slider, p) - self.data["bokeh_plot_data"] = bokeh.embed.json_item(layout, theme=mytheme) + self.data["bokeh_plot_data"] = bokeh.embed.json_item(layout, theme=DATALAB_BOKEH_THEME) diff --git a/pydatalab/pydatalab/routes/v0_1/blocks.py b/pydatalab/pydatalab/routes/v0_1/blocks.py index d9a005c85..1d3a0c450 100644 --- a/pydatalab/pydatalab/routes/v0_1/blocks.py +++ b/pydatalab/pydatalab/routes/v0_1/blocks.py @@ -139,17 +139,30 @@ def _save_block_to_db(block: DataBlock) -> bool: overwriting previous data saved there. returns true if successful, false if unsuccessful """ - if block.data.get("item_id"): + + updated_block = block.to_db() + update = {"$set": {f"blocks_obj.{block.block_id}": updated_block}} + + print(updated_block.keys()) + + if block.data.get("collection_id"): + match = {"collection_id": block.data["collection_id"], **get_default_permissions(user_only=False)} + else: + match = {"block_id": block.data["block_id"], **get_default_permissions(user_only=False)} + + + try: result = flask_mongo.db.items.update_one( - {"item_id": block.data["item_id"], **get_default_permissions(user_only=False)}, - {"$set": {f"blocks_obj.{block.block_id}": block.to_db()}}, + match, + update, ) - elif block.data.get("collection_id"): - result = flask_mongo.db.collections.update_one( - {"item_id": block.data["collection_id"], **get_default_permissions(user_only=False)}, - {"$set": {f"blocks_obj.{block.block_id}": block.to_db()}}, + except pymongo.errors.DocumentTooLarge: + LOGGER.warning( + f"_save_block_to_db failed, trying to strip down the block plot to fit in the 16MB limit" ) + return False + if result.matched_count != 1: LOGGER.warning( f"_save_block_to_db failed, likely because item_id ({block.data.get('item_id')}), collection_id ({block.data.get('collection_id')}) and/or block_id ({block.block_id}) wasn't found"