diff --git a/test/test-util-bbox.py b/test/test-util-bbox.py new file mode 100644 index 0000000..101ce62 --- /dev/null +++ b/test/test-util-bbox.py @@ -0,0 +1,23 @@ +import numpy as np +from wirecell.util.bbox import * + + +def test_slice(): + assert slice(0,4) == union_slice(slice(0,4), slice(1,2)) # inside + assert slice(0,4) == union_slice(slice(1,2), slice(0,4)) # outside + assert slice(0,4) == union_slice(slice(0,1), slice(2,4)) # gap + assert slice(0,4) == union_slice(slice(0,4), slice(1,4)) # overlap + +def test_slice(): + assert np.all(np.array([0,1,2,3]) == union_array(slice(0,4), slice(1,2), order="ascending")) + assert np.all(np.array([1,0,2,3]) == union_array(slice(1,2), slice(0,4), order="seen")) + assert np.all(np.array([1,0,1,2,3]) == union_array(slice(1,2), slice(0,4), order=None)) + assert np.all(np.array([2,3,0]) == union_array(slice(2,4), slice(0,1), order="seen")) + assert np.all(np.array([3,2,1,0]) == union_array(slice(0,4), slice(1,4), order="descending")) + +def test_bbox(): + bb1 = (slice(0,2), slice(10,12)) + bb2 = (slice(3,5), slice(13,15)) + u1 = union(bb1, bb2, form="slices") + assert slice(0,5) == u1[0] + assert slice(10,15) == u1[1] diff --git a/test/test-util-peaks.py b/test/test-util-peaks.py new file mode 100644 index 0000000..c26ac9c --- /dev/null +++ b/test/test-util-peaks.py @@ -0,0 +1,19 @@ +#!/usr/bin/env pytest + +from wirecell.util.peaks import * +from wirecell.util.codec import JsonEncoder + +def roundtrip_dataclass(DCType): + t = DCType() + d = t.to_dict() + print(d) + j = json.dumps(d, cls=JsonEncoder) + d2 = json.loads(j) + print(d2) + assert d == d2 + t2 = DCType.from_dict(d2) + + +def test_dataclasses(): + roundtrip_dataclass(Peak1d) + diff --git a/wirecell/test/__main__.py b/wirecell/test/__main__.py index 96e6e85..d59b469 100644 --- a/wirecell/test/__main__.py +++ b/wirecell/test/__main__.py @@ -1,9 +1,15 @@ import math import click +import numpy +import functools +import dataclasses from wirecell.util import ario, plottools - +from wirecell.util.plottools import pages from wirecell.util.cli import context, log +from wirecell.test import ssss +from wirecell.util.peaks import select_activity +from wirecell.util.codec import json_dumps @context("test") def cli(ctx): @@ -30,6 +36,98 @@ def plot(ctx, name, datafile, output): +def ssss_args(func): + @click.option("--channel-ranges", default="0,800,1600,2560", + help="comma-separated list of channel idents defining ranges") + @click.option("--nsigma", default=3.0, + help="Relative threshold on signal in units of number of sigma of noise width") + @click.option("--nbins", default=50, + help="Number of bins over which to fit relative signal-splat difference") + @click.option("-o",'--output', default='/dev/stdout') + @click.argument("splat") + @click.argument("signal") + @functools.wraps(func) + def wrapper(*args, **kwds): + + kwds["splat"] = ssss.load_frame(kwds.pop("splat")) + kwds["signal"] = ssss.load_frame(kwds.pop("signal")) + + channel_ranges = kwds.pop("channel_ranges") + if channel_ranges: + channel_ranges = list(map(int,channel_ranges.split(","))) + channel_ranges = [slice(*cr) for cr in zip(channel_ranges[:-1], channel_ranges[1:])] + kwds["channel_ranges"] = channel_ranges + return func(*args, **kwds) + return wrapper + + +@cli.command("plot-ssss") +@click.option('--title', default='', help='extra title for plots') +@ssss_args +def plot_ssss(channel_ranges, nsigma, nbins, splat, signal, output, + title, **kwds): + ''' + Perform the simple splat / sim+signal process comparison test and make plots. + ''' + + with pages(output) as out: + + ssss.plot_frames(splat, signal, channel_ranges, title) + out.savefig() + + byplane = list() + + # Per channel range plots. + for pln, ch in enumerate(channel_ranges): + + spl = select_activity(splat.frame, ch, nsigma) + sig = select_activity(signal.frame, ch, nsigma) + + # Find the bbox that bounds the biggest splat object. + biggest = spl.plats.sort_by("sums")[-1] + bbox = spl.plats.bboxes[biggest] + + spl_act = spl.thresholded[bbox] + sig_act = sig.thresholded[bbox] + letter = "UVW"[pln] + ssss.plot_plane(spl_act, sig_act, nsigma=nsigma, + title=f'{letter}-plane {title}') + out.savefig() + + spl_qch = numpy.sum(spl.activity[bbox], axis=1) + sig_qch = numpy.sum(sig.activity[bbox], axis=1) + byplane.append((spl_qch, sig_qch)) + + + ssss.plot_metrics(byplane, nbins=nbins, + title=f'(splat - signal)/splat {title}') + + out.savefig() + +@cli.command("ssss-metrics") +@ssss_args +def ssss_metrics(channel_ranges, nsigma, nbins, splat, signal, output, **kwds): + ''' + Write the simple splat / sim+signal process comparison metrics to file. + ''' + + metrics = list() + for pln, ch in enumerate(channel_ranges): + spl = select_activity(splat.frame, ch, nsigma) + sig = select_activity(signal.frame, ch, nsigma) + + biggest = spl.plats.sort_by("sums")[-1] + bbox = spl.plats.bboxes[biggest] + + spl_qch = numpy.sum(spl.activity[bbox], axis=1) + sig_qch = numpy.sum(sig.activity[bbox], axis=1) + + m = ssss.calc_metrics(spl_qch, sig_qch, nbins) + metrics.append(dataclasses.asdict(m)) + + open(output,"w").write(json_dumps(metrics, indent=4)) + + def main(): cli(obj=dict()) diff --git a/wirecell/test/ssss.py b/wirecell/test/ssss.py new file mode 100644 index 0000000..3678bed --- /dev/null +++ b/wirecell/test/ssss.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python +'''The simple splat / sim+SP (ssss) test is used, in part, to reproduce the +signal biase, efficiency and resolution metric used in the MicroBooNE SP-1 +paper. + +''' + +import dataclasses +import numpy +import matplotlib.pyplot as plt +from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec +from wirecell import units +from wirecell.util.peaks import ( + BaselineNoise, + baseline_noise, + gauss as gauss_func +) + +def relbias(a,b): + ''' + Return (a-b)/a where a is nonzero, return zero o.w.. + ''' + rb = numpy.zeros_like(a) + ok = b>0 + rb[ok] = a[ok]/b[ok] - 1 + return rb + +@dataclasses.dataclass +class Frame: + ''' + Represent a "frame" loaded from file. + ''' + + filename: str | None + ''' + Filename from which the frame was taken. + ''' + frame: numpy.ndarray + ''' + The frame array + ''' + extent: tuple + ''' + Frame extent in time and channel: (t0,tf,cmin,cmax+1) + ''' + origin: str + ''' + Origin option for imshow(). + ''' + tick: float + ''' + Sample period in WCT system of units + ''' + +def load_frame(fname): + ''' + Load a frame with time values in explicit units. + ''' + fp = numpy.load(fname) + f = fp["frame_*_0"] + t = fp["tickinfo_*_0"] + c = fp["channels_*_0"] + + c2 = numpy.array(c) + numpy.sort(c2) + assert numpy.all(c == c2) + + cmin = numpy.min(c) + cmax = numpy.max(c) + nch = cmax-cmin+1 + ff = numpy.zeros((nch, f.shape[1]), dtype=f.dtype) + for irow, ch in enumerate(c): + ff[cmin-ch] = f[irow] + + t0 = t[0] + tick = t[1] + tf = t0 + f.shape[1]*tick + + # edge values + extent=(t0, tf+tick, cmin, cmax+1) + origin = "lower" # lower flips putting [0,0] at bottom + ff = numpy.flip(ff, axis=0) + return Frame(fname, ff, extent, origin, tick) + +def plot_frame(gs, fr, channel_ranges=None, which="splat", tit=""): + ''' + Plot one Frame + ''' + t0,tf,c0,cf = fr.extent + t0_us = t0/units.us + tf_us = tf/units.us + + gs = GridSpecFromSubplotSpec(2,2, subplot_spec=gs, + height_ratios = [5,1], width_ratios = [6,1]) + + fax = plt.subplot(gs[0,0]) + tax = plt.subplot(gs[1,0], sharex=fax) + cax = plt.subplot(gs[0,1], sharey=fax) + + cax.set_xlabel(which) + fax.set_ylabel("channel") + if which=="signal": + tax.set_xlabel("time [us]") + + if tit: + plt.title(tit) + plt.setp(fax.get_xticklabels(), visible=False) + plt.setp(cax.get_yticklabels(), visible=False) + if which=="splat": + plt.setp(tax.get_xticklabels(), visible=False) + + im = fax.imshow(fr.frame, extent=fr.extent, origin=fr.origin, + aspect='auto', vmax=500, cmap='hot_r') + + tval = fr.frame.sum(axis=0) + t = numpy.linspace(fr.extent[0], fr.extent[1], fr.frame.shape[1]+1,endpoint=True) + tax.plot(t[:-1], tval) + if channel_ranges: + for p,chans in zip("UVW",channel_ranges): # fixme: map to plane labels is only an assumption! + val = fr.frame[chans,:].sum(axis=0) + c1 = chans.start + c2 = chans.stop + tax.plot(t[:-1], val, label=p) + fax.plot([t0_us,tf_us], [c1,c1]) + fax.text(t0_us + 0.1*(tf_us-t0_us), c1 + 0.5*(c2-c1), p) + fax.plot([t0_us,tf_us], [c2-1,c2-1]) + tax.legend() + + cval = fr.frame.sum(axis=1) + c = numpy.linspace(fr.extent[2],fr.extent[3],fr.frame.shape[0]+1,endpoint=True) + cax.plot(cval, c[:-1]) + + return im + +def plot_frames(spl, sig, channel_ranges, title=""): + '''Plot the two Frame objects spl (splat) and sig (signal). + + Channel ranges gives list of pair of channel min/max to interpret as + contiguous rows on the Frame.array. + + ''' + fig = plt.figure() + pgs = GridSpec(1,2, figure=fig, width_ratios = [7,0.2]) + gs = GridSpecFromSubplotSpec(2, 1, pgs[0,0]) + im1 = plot_frame(gs[0], spl, channel_ranges, which="splat") + im2 = plot_frame(gs[1], sig, channel_ranges, which="signal") + fig.colorbar(im2, cax=plt.subplot(pgs[0,1])) + if title: + plt.suptitle(title) + plt.tight_layout() + +def plot_plane(spl_act, sig_act, nsigma=3.0, title=""): + ''' + Plot splat and signal activity for one plane. + + ''' + # bias of first w.r.t. second + bias1 = relbias(sig_act, spl_act) + bias2 = relbias(spl_act, sig_act) + + plt.clf() + fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True) + if title: + plt.suptitle(title) + args=dict(aspect='auto') + im1 = axes[0,0].imshow(sig_act, **args) + fig.colorbar(im1, ax=axes[0,0]) + im2 = axes[0,1].imshow(spl_act, **args) + fig.colorbar(im2, ax=axes[0,1]) + + args = dict(args, cmap='jet', vmin=-50, vmax=50) + + im3 = axes[1,0].imshow(100*bias1, **args) + fig.colorbar(im3, ax=axes[1,0]) + + im4 = axes[1,1].imshow(100*bias2, **args) + fig.colorbar(im4, ax=axes[1,1]) + + axes[0,0].set_title(f'signal {nsigma=}') + axes[0,1].set_title(f'splat {nsigma=}') + + axes[1,0].set_title(f'splat/signal - 1 [%]') + axes[1,1].set_title(f'signal/splat - 1 [%]') + + chan_tit = 'chans (rel)' + tick_tit = 'ticks (rel)' + axes[0,0].set_ylabel(chan_tit) + axes[1,0].set_ylabel(chan_tit) + axes[1,0].set_xlabel(tick_tit) + axes[1,1].set_xlabel(tick_tit) + + fig.subplots_adjust(right=0.85) + plt.tight_layout() + + +@dataclasses.dataclass +class Metrics: + '''Metrics about a signal vs splat''' + + neor: int + ''' Number of channels over which the rest are calculated. This can be less + than the number of channels in the original "activity" arrays if any given + channel has zero activity in both "signal" and "splat". ''' + + ineff: float + ''' The relative inefficiency. This is the fraction of channels with splat + but with zero signal. ''' + + fit: BaselineNoise + ''' + Gaussian fit to relative difference. .mu is bias and .sigma is resolution. + ''' + +def calc_metrics(spl_qch, sig_qch, nbins=50): + '''Return Metrics instance for splat and signal "channel activity" arrays. + - spl_qch :: 1D array giving total charge per channel from splat + - sig_qch :: 1D array giving total charge per channel from signala + - nbins :: the number of bins over which to fit the relative difference. + ''' + + # either-or, exclude channels where both are zero + eor = numpy.logical_or (spl_qch > 0, sig_qch > 0) + # both are nonzero + both = numpy.logical_and(spl_qch > 0, sig_qch > 0) + nosig = numpy.logical_and(spl_qch > 0, sig_qch == 0) + nospl = numpy.logical_and(spl_qch == 0, sig_qch > 0) + + neor = numpy.sum(eor) + nboth = numpy.sum(both) + # inefficiency + nnosig = numpy.sum(nosig) + ineff = nnosig/neor + # "over" efficiency + nnospl = numpy.sum(nospl) + oveff = nnospl/neor + + reldiff = (spl_qch[both] - sig_qch[both])/spl_qch[both], + vrange = 0.01*nbins/2 + bln = baseline_noise(reldiff, nbins, vrange) + + return Metrics(neor, ineff, bln) + +def plot_metrics(splat_signal_activity_pairs, nbins=50, title="", letters="UVW"): + plt.clf() + fig, axes = plt.subplots(nrows=2, ncols=3, sharey="row") + for pln, (spl_qch, sig_qch) in enumerate(splat_signal_activity_pairs): + + m = calc_metrics(spl_qch, sig_qch, nbins) + counts, edges = m.fit.hist + model = gauss_func(edges[:-1], m.fit.A, m.fit.mu, m.fit.sigma) + + letter = letters[pln] + + ax1,ax2 = axes[:,pln] + + ax1.plot(sig_qch, label='signal') + ax1.plot(spl_qch, label='splat') + ax1.set_xlabel('chans (rel)') + ax1.set_ylabel('electrons') + ax1.set_title(f'{letter} ineff={100*m.ineff:.1f}%') + ax1.legend() + + ax2.step(edges[:-1], counts, label='data') + ax2.plot(edges[:-1], model, label='fit') + ax2.set_title(f'mu={100*m.fit.mu:.2f}%\nsig={100*m.fit.sigma:.2f}%') + ax2.set_xlabel('difference [%]') + ax2.set_ylabel('counts') + ax2.legend() + + if title: + plt.suptitle(title) + else: + plt.suptitle('(splat - signal) / splat') + plt.tight_layout() diff --git a/wirecell/util/bbox.py b/wirecell/util/bbox.py new file mode 100644 index 0000000..d37a1bf --- /dev/null +++ b/wirecell/util/bbox.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +'''Utilities for dealing with arrays indices as "bounding boxes" + +Here, a "bbox" is an object that may be used to index a numpy array, potentially +a multi-dimensional one. + +A bbox is represented as a sequence of ranges, one for each dimension of an +array. + +A range may be in "slice form" or "array form". A slice range is a slice() +instance. An array range is any list-like sequence of integer indices. Indices +may span the array dimension sparsely, out of order and may repeat. + +Bbox functions may expand from slice to array form and may apply selection, +uniqueness and ordering to indices in an array form. + +''' +import numpy + +def union_array(*ranges, order="stack"): + ''' + Form a union of ranges in array form. + + The union may be sparse and ordered. + + Order determines post-processing of the union. + + - "ascending" :: sort unique indices in ascending order. + - "descending" :: sort unique indices in descending order. + - "seen" value :: sort unique indices in first-seen order. + - "stack" :: simply concatenate indices of each range (default). + + ''' + if not ranges: + return numpy.array((0,)) + u = list() + for one in ranges: + u.append(numpy.r_[one]) + u = numpy.hstack(u) + if order == "ascending": + return numpy.unique(u) + if order == "descending": + u = numpy.unique(u) + return u[::-1] + if order == "seen": + g,i = numpy.unique(u, return_index=True) + return g[numpy.argsort(i)] + + return u # "stack" by default + + +def union_slice(*slices): + '''Form union of ranges. + + A slice is returned that spans the union of the given slices. + + Input slices may define a .step value which is considered in forming the + union. In any case, the returned slice has no step defined. + + ''' + if not slices: + return slice(None,None,None) + inds = numpy.hstack([numpy.r_[s] for s in slices]) + return slice(numpy.min(inds), 1+numpy.max(inds)) + + +def union(*bboxes, form="slices"): + '''Form union of bboxes. + + If form is "slices" then the ranges of each dimension of the bboxes must all + be slices and the returned bbox will have union of ranges formed with + union_slice(). Otherwise, union_array() is used and "form" is passed as the + "order". + + ''' + if form == "slices": + return tuple([union_slice(*ranges) for ranges in zip(*bboxes)]) + return tuple([union_array(*ranges, order=form) for ranges in zip(*bboxes)]) + + +# todo: +# +# - bounds(array) -> return smaller array that removes any rows/cols that are +# fully masked. bonus to work on N-dimensions. Surprisingly, such a function +# is not found in numpy, scipy? diff --git a/wirecell/util/codec.py b/wirecell/util/codec.py new file mode 100644 index 0000000..72e5643 --- /dev/null +++ b/wirecell/util/codec.py @@ -0,0 +1,67 @@ +import json +import numpy +import dataclasses + +def to_pod(v): + ''' + Try hard to return v as POD + ''' + if isinstance(v, numpy.ndarray): + return v.tolist() + if isinstance(v, slice): + return [v.start, v.stop, v.step] + if dataclasses.is_dataclass(v): + return dataclasses.asdict(v, dict_factory = dict_factory) + return v + +def dict_factory(kv): + ''' + Try hard to convert dataclass key/values to POD. + + Sutable for calls like: + + ddict = dataclasses.asdict(dclass, dict_factory=dict_factory) + ''' + return {k:to_pod(v) for k,v in kv} + +@classmethod +def from_dict(cls, obj = {}): + ''' + Return instance of dataclass from dict-like POD. + ''' + dat = {f.name: f.type(obj.get(f.name, f.default)) + for f in dataclasses.fields(cls)} + return cls(**dat) + +def to_dict(self): + ''' + Try hard to return a dataclass as a dict of POD. + ''' + return dataclasses.asdict(self, dict_factory=dict_factory) + +def dataclass_dictify(cls): + ''' + Decorate a dataclass to add from_dict(cls) and to_dict(self) methods. + ''' + cls.from_dict = from_dict + cls.to_dict = to_dict + return cls + +class JsonEncoder(json.JSONEncoder): + + def default(self, obj): + if isinstance(obj, numpy.integer): + return int(obj) + if isinstance(obj, numpy.floating): + return float(obj) + if isinstance(obj, numpy.ndarray): + return obj.tolist() + if isinstance(obj, slice): + return (obj.start, obj.stop, obj.step) + if dataclasses.is_dataclass(obj): + return dataclasses.asdict(obj) + return super().default(obj) + + +def json_dumps(obj, **kwds): + return json.dumps(obj, cls=JsonEncoder, **kwds) diff --git a/wirecell/util/peaks.py b/wirecell/util/peaks.py new file mode 100644 index 0000000..8c327c6 --- /dev/null +++ b/wirecell/util/peaks.py @@ -0,0 +1,339 @@ +#!/usr/bin/env python +''' +Find and represent peaks in 1D and 2D frame and waveform arrays. +''' +import numpy +import numpy.ma as ma +import json +from pathlib import Path +import dataclasses +from typing import List, Tuple +from scipy.optimize import curve_fit +from scipy.signal import find_peaks, peak_widths +from scipy import ndimage +from math import sqrt, pi + +from wirecell.util.codec import dataclass_dictify +from wirecell.util.bbox import union as union_bbox + +sqrt2pi = sqrt(2*pi) + +def gauss(x, A, mu, sigma, *p): + ''' + Gaussian distribution model for fitting. + ''' + return A*numpy.exp(-0.5*((x-mu)/sigma)**2)/(sigma*sqrt2pi) + +@dataclasses.dataclass +class BaselineNoise: + ''' + Characterize baseline noise. + ''' + + A : float + ''' + Normalization constant + ''' + mu : float + ''' + Mean + ''' + sigma : float + ''' + Width + ''' + med : float + ''' + Median + ''' + cov : numpy.ndarray | None = None + ''' + Covariance matrix of fit. Non implies A,mu,sigma are statistical. + ''' + hist: tuple | None = None + ''' + The bin content and edges of the histgram that was fit + ''' + +def baseline_noise(array, bins=200, vrange=100): + '''Return a BaselineNoise derived from array a. + + This attempts to fit a Gaussian model to a histogram of array values + spanning given number of bins of value range given by vrange. The vrange + defines an extent about the MEDIAN VALUE. If it is a tuple it gives this + extent explicitly or if scalar the extent is symmetric, ie median+/-vrange. + + ''' + med = numpy.median(array) + if not isinstance(vrange, tuple): + vrange=(-vrange, vrange) + vrange=(med+vrange[0], med+vrange[1]) + + hist = numpy.histogram(array, bins=bins, range=vrange) + counts, edges = hist + + A = numpy.sum(counts) + mu = med + sig = sqrt(numpy.average(edges[:-1]**2, weights=counts)) + p0 = (A, mu, sig) + + try: + (A,mu,sig),cov = curve_fit(gauss, edges[:-1], counts, p0=p0) + except RuntimeError: + cov = None + return BaselineNoise(A, mu, sig, med, cov, hist) + + +@dataclasses.dataclass +@dataclass_dictify +class Peak1d: + ''' + Information about a peak in a 1D array. + ''' + + peak: int = 0 + '''The where along the waveform the peak resides.''' + + fwhm: float = 0.0 + '''Full-width of the peak at half-max.''' + + hh: float = 0.0 + '''Half of the height of the peak.''' + + left: float = 0.0 + '''The left side of the width measured in fractional indices.''' + + right: float = 0.0 + '''The right side of the width measured in fractional indices.''' + + tot: float = 0.0 + '''The sum of waveform values over the peak.''' + + mask: slice = slice(0,0) + '''A mask that captures the peak.''' + + A: float = 0.0 + '''The fit Gaussian normalization fit parameter. See gauss().''' + + mu: float = 0.0 + '''The fit Gaussian mean fit parameter. See gauss().''' + + sigma: float = 0.0 + '''The fit Gaussian sigma fit parameter. See gauss().''' + + cov: numpy.ndarray = numpy.zeros((0,)) + '''The covariance matrix of the fit.''' + + +def find1d(wave, npeaks=None, threshold=0): + '''Return measures of peaks in 1d waveform. + + - npeaks :: return only the number npeaks highest peaks. None returns all. + + - threshold :: the minimum value for a sample to be considered a peak. + + ''' + + # Find the peaks + peaks = find_peaks(wave, height = threshold)[0] + peaks = list(sorted(peaks, key=lambda p: wave[p])) + if npeaks is not None: + peaks = peaks[0:npeaks] + + # Characterize that peak full width half max + info = [numpy.array(peaks)] + info += peak_widths(wave, peaks, rel_height=0.5) + + # the "x" values in fits below are simply indices. + iota = numpy.arange(wave.size, dtype=int) + + # fit each peak + ret = list() + for peak, fwhm, hh, left, right in zip(*info): + + # Zero out activity outside of current peak. Fixme: this assumes peaks + # are well separated. + mask = slice(int(round(left-fwhm)), int(round(right+fwhm))) + tofit = numpy.zeros_like(wave) + tofit[mask] = wave[mask] + + # Guess initial parameters + A = numpy.sum(tofit) + p0 = (A, peak, 0.5*fwhm) + try: + fit,cov = curve_fit(gauss, iota, tofit, p0=p0) + except RuntimeError: + nnz = numpy.sum(tofit>0) + fit=p0 + cov=None + + one = Peak1d(peak=peak,fwhm=fwhm,hh=hh,left=left,right=right, + tot=A, mask=mask, + A = fit[0], mu = fit[1], sigma = fit[2], + cov = cov) + + ret.append(one) + + return ret; + + +@dataclasses.dataclass +class Plateaus: + + number: int = 0 + ''' + The number of objects. + ''' + + @property + def labels(self): + 'Array of labels (1-based counts)' + return 1 + self.indices + @property + def indices(self): + 'Array of indices (0-based counts)' + return numpy.arange(self.number) + + labeled: numpy.ndarray = numpy.zeros((0,0)) + ''' + The frame with labeled pixels for each object. + ''' + + bboxes: List[Tuple[slice]] = () + ''' + Bounding boxes of each object. + ''' + + sums: numpy.ndarray = numpy.zeros((0,)) + ''' + The total value of each object. + ''' + + counts: numpy.ndarray = numpy.zeros((0,)) + ''' + The number of pixels of each object. + ''' + + coms: numpy.ndarray = numpy.zeros((0,0)) + ''' + The center of mass of objects in pixel space. + ''' + + threshold: float = 0 + ''' + The threshold used to select the plateaus. + ''' + + def sort_by(self, what="sums", reverse=False): + ''' + Return an array of indices that orders the plateaus by some value. + ''' + what = getattr(self, what) + order = sorted(self.indices, key=lambda i: what[i]) + if reverse: + order.reverse() + return numpy.array(order) + + +def plateaus(frame, vthreshold=None): + '''Label contiguous regions above value threshold. + + Return a Plateaus. + + The vthreshold is compared directly to pixel values. If not given, it + default to min + 0.001 * (max-min) + + ''' + if vthreshold is None: + vmin = numpy.min(frame) + vmax = numpy.max(frame) + vthreshold = vmin + 0.001 * (vmax-vmin) + + thresh = frame > vthreshold + labels, nlabels = ndimage.label(thresh) + labs = numpy.arange(nlabels)+1 + + return Plateaus( + number = nlabels, + labeled = labels, + bboxes = ndimage.find_objects(labels, nlabels), + sums = numpy.array(ndimage.sum_labels(frame, labels, labs)), + counts = numpy.array(ndimage.sum_labels(thresh, labels, labs)), + coms = numpy.array(ndimage.center_of_mass(frame, labels, labs)), + threshold = vthreshold + ) + + +@dataclasses.dataclass +class SelectActivity: + ''' + Characterize activity in a subset of a frame + ''' + + selection: numpy.ndarray + ''' + The array over selected channels + ''' + + channels: list | slice + ''' + The channels in this selection + ''' + + bln: BaselineNoise + ''' + Calcualted baseline noise + ''' + + plats: list + ''' + The plateaus found + ''' + + bbox: tuple + ''' + A bounding box for all objects found + ''' + + nsigma: float + ''' + The number of sigma above median for plateau threshold. + ''' + + @property + def activity(self): + ''' + The selection array reduced by the median. + ''' + return self.selection - self.bln.med + + @property + def thresholded(self): + ''' + The activity with below threshold values masked + ''' + act = self.activity + return ma.array(act, mask = act <= self.plats.threshold) + +def select_activity(frame, ch, nsigma=3.0): + '''Select activity from a "frame" array spanning many channels. + + Given a full frame array, select channel rows given by ch, apply a threshold + that is nsigma*sigma above median to find bounding box of activity. + + Return the selected frame that has below-threshold pixels masked. + + ''' + plane = frame[ch, :] # select channels + bln = baseline_noise(plane) + thresh = bln.med + nsigma*bln.sigma + plats = plateaus(plane, thresh) + assert plats.number > 0 + bbox = union_bbox(*plats.bboxes) + return SelectActivity( + selection = plane, + channels = ch, + bln = bln, + plats = plats, + bbox = bbox, + nsigma = nsigma)