Skip to content

Commit

Permalink
feat: discrete transfer functions
Browse files Browse the repository at this point in the history
  • Loading branch information
maartenbreddels committed Jul 7, 2023
1 parent 35b4bd0 commit 1e71b93
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 7 deletions.
60 changes: 55 additions & 5 deletions ipyvolume/pylab.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from __future__ import absolute_import
from __future__ import division
import pythreejs
from typing import List, Union


__all__ = [
'current',
Expand All @@ -29,6 +31,7 @@
'animation_control',
'gcc',
'transfer_function',
'transfer_function_discrete',
'plot_isosurface',
'volshow',
'save',
Expand Down Expand Up @@ -894,6 +897,48 @@ def gcc():
return current.container


def transfer_function_discrete(
n,
colors: List[str] = ["red", "green", "blue"],
labels: Union[None, List[str]] = None,
opacity: Union[float, List[float]] = 0.1,
enabled: Union[bool, List[bool]] = True,
controls=True,
):
"""Create a discrete transfer function with n layers.
Each (integer) value of the volumetric data maps to a single color.
:param n: number of layers
:param colors: list of colors, can be any valid HTML color string
:param labels: list of labels, if None, labels will be "Layer 0", "Layer 1", etc.
:param opacity: opacity of each layer, can be a single value or a list of values
:param enabled: whether each layer is enabled, can be a single value or a list of values
:param controls: whether to add the controls to the current container
"""
if isinstance(opacity, float):
opacity = [opacity] * len(colors)
if isinstance(enabled, bool):
enabled = [enabled] * len(colors)

def ensure_length(x):
repeat = (n + len(colors) - 1) // len(colors)
return (x * repeat)[:n]

if labels is None:
labels = []
for i in range(n):
labels.append(f"Layer {i}")

tf = ipv.TransferFunctionDiscrete(colors=ensure_length(colors), opacities=ensure_length(opacity), enabled=ensure_length(enabled), labels=ensure_length(labels))
gcf() # make sure a current container/figure exists
if controls:
current.container.children = [tf.control()] + current.container.children

return tf


def transfer_function(
level=[0.1, 0.5, 0.9], opacity=[0.01, 0.05, 0.1], level_width=0.1, controls=True, max_opacity=0.2
):
Expand Down Expand Up @@ -1029,8 +1074,7 @@ def volshow(
):
"""Visualize a 3d array using volume rendering.
Currently only 1 volume can be rendered.
If the data is of type int8 or bool, :any:`a discrete transfer function will be used <ipv.discrete_transfer_function>`
:param data: 3d numpy array
:param origin: origin of the volume data, this is to match meshes which have a different origin
Expand All @@ -1040,7 +1084,7 @@ def volshow(
:param float data_max: maximum value to consider for data, if None, computed using np.nanmax
:parap int max_shape: maximum shape for the 3d cube, if larger, the data is reduced by skipping/slicing (data[::N]),
set to None to disable.
:param tf: transfer function (or a default one)
:param tf: transfer function (or a default one, based on the data)
:param bool stereo: stereo view for virtual reality (cardboard and similar VR head mount)
:param ambient_coefficient: lighting parameter
:param diffuse_coefficient: lighting parameter
Expand All @@ -1060,12 +1104,18 @@ def volshow(
"""
fig = gcf()

if tf is None:
tf = transfer_function(level, opacity, level_width, controls=controls, max_opacity=max_opacity)
if data_min is None:
data_min = np.nanmin(data)
if data_max is None:
data_max = np.nanmax(data)
if tf is None:
if (data.dtype == np.uint8) or (data.dtype == bool):
if data.dtype == bool:
data_max = 1

tf = transfer_function_discrete(n=data_max + 1)
else:
tf = transfer_function(level, opacity, level_width, controls=controls, max_opacity=max_opacity)
if memorder == 'F':
data = data.T

Expand Down
11 changes: 10 additions & 1 deletion ipyvolume/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import pytest

import ipyvolume
import ipyvolume.pylab as p3
import ipyvolume as ipv
import ipyvolume.pylab as p3
import ipyvolume.examples
import ipyvolume.datasets
import ipyvolume.utils
Expand Down Expand Up @@ -303,6 +303,15 @@ def test_volshow():
p3.save("tmp/ipyolume_volume.html")


def test_volshow_discrete():
boolean_volume = np.random.random((10, 10, 10)) > 0.5
ipv.figure()
vol = ipv.volshow(boolean_volume)
assert isinstance(vol.tf, ipyvolume.TransferFunctionDiscrete)
assert len(vol.tf.colors) == 2
# int8_volume = np.random.randint(0, 255, size=(10, 10, 10), dtype=np.uint8)


def test_volshow_max_shape():
x, y, z = ipyvolume.examples.xyz(shape=32)
Im = x * y * z
Expand Down
21 changes: 20 additions & 1 deletion ipyvolume/transferfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import absolute_import

__all__ = ['TransferFunction', 'TransferFunctionJsBumps', 'TransferFunctionWidgetJs3', 'TransferFunctionWidget3']
__all__ = ['TransferFunction', 'TransferFunctionDiscrete', 'TransferFunctionJsBumps', 'TransferFunctionWidgetJs3', 'TransferFunctionWidget3']

import numpy as np
import ipywidgets as widgets # we should not have widgets under two names
Expand All @@ -12,6 +12,7 @@

import ipyvolume._version
from ipyvolume import serialize
import ipyvuetify as v


N = 1024
Expand All @@ -26,11 +27,29 @@ class TransferFunction(widgets.DOMWidget):
_model_module = Unicode('ipyvolume').tag(sync=True)
_view_module = Unicode('ipyvolume').tag(sync=True)
style = Unicode("height: 32px; width: 100%;").tag(sync=True)
# rgba should be a 2d array of shape (N, 4), where the last dimension is the rgba value
# with values between 0 and 1
rgba = Array(default_value=None, allow_none=True).tag(sync=True, **serialize.ndarray_serialization)
_view_module_version = Unicode(semver_range_frontend).tag(sync=True)
_model_module_version = Unicode(semver_range_frontend).tag(sync=True)


class TransferFunctionDiscrete(TransferFunction):
_model_name = Unicode('TransferFunctionDiscreteModel').tag(sync=True)
colors = traitlets.List(traitlets.Unicode(), default_value=["red", "#0f0"]).tag(sync=True)
opacities = traitlets.List(traitlets.CFloat(), default_value=[0.01, 0.01]).tag(sync=True)
enabled = traitlets.List(traitlets.Bool(), default_value=[True, True]).tag(sync=True)
labels = traitlets.List(traitlets.Unicode(), default_value=["label1", "label2"]).tag(sync=True)

def control(self):
return TransferFunctionDiscreteView(tf=self)


class TransferFunctionDiscreteView(v.VuetifyTemplate):
template_file = (__file__, 'vue/tf_discrete.vue')
tf = traitlets.Instance(TransferFunctionDiscrete).tag(sync=True, **widgets.widget_serialization)


class TransferFunctionJsBumps(TransferFunction):
_model_name = Unicode('TransferFunctionJsBumpsModel').tag(sync=True)
_model_module = Unicode('ipyvolume').tag(sync=True)
Expand Down
49 changes: 49 additions & 0 deletions js/src/tf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import * as widgets from "@jupyter-widgets/base";
import {default as ndarray_pack} from "ndarray-pack";
import * as serialize from "./serialize.js";
import {semver_range} from "./utils";
import _ from "underscore";
import * as THREE from "three";

export
class TransferFunctionView extends widgets.DOMWidgetView {
Expand Down Expand Up @@ -112,6 +114,53 @@ class TransferFunctionJsBumpsModel extends TransferFunctionModel {
}
}

export
class TransferFunctionDiscreteModel extends TransferFunctionModel {

constructor(...args) {
super(...args);
this.on("change:colors", this.recalculate_rgba, this);
this.on("change:opacities", this.recalculate_rgba, this);
this.on("change:enabled", this.recalculate_rgba, this);
this.recalculate_rgba();
}
defaults() {
return {
...super.defaults(),
_model_name : "TransferFunctionDiscreteModel",
color: ["red", "#0f0"],
opacities: [0.01, 0.01],
enabled: [true, true],
};
}

recalculate_rgba() {
const rgba = [];
const colors = _.map(this.get("colors"), (color : string) => {
return (new THREE.Color(color)).toArray();
});
const enabled = this.get("enabled");
const opacities = this.get("opacities");
(window as any).rgba = rgba;
(window as any).tfjs = this;
const N = colors.length;
for (let i = 0; i < N; i++) {
const color = [...colors[i], opacities[i]]; // red, green, blue and alpha
color[3] = Math.min(1, color[3]); // clip alpha
if(!enabled[i]) {
color[3] = 0;
}
rgba.push(color);
}
// because we want the shader to sample the center pixel, if we add one extra pixel in the texture
// all samples should be shiften by epsilon so the sample the center of the transfer function
rgba.push([0, 0, 0, 0]);
const rgba_array = ndarray_pack(rgba);
this.set("rgba", rgba_array);
this.save_changes();
}
}

export
class TransferFunctionWidgetJs3Model extends TransferFunctionModel {

Expand Down

0 comments on commit 1e71b93

Please sign in to comment.