Skip to content

Commit

Permalink
Factor code from WCT test helper ssss-pdsp.py into more proper wcpy.
Browse files Browse the repository at this point in the history
  • Loading branch information
brettviren committed Jan 5, 2024
1 parent a90b916 commit d456b6b
Show file tree
Hide file tree
Showing 7 changed files with 906 additions and 1 deletion.
23 changes: 23 additions & 0 deletions test/test-util-bbox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np
from wirecell.util.bbox import *


def test_slice():
assert slice(0,4) == union_slice(slice(0,4), slice(1,2)) # inside
assert slice(0,4) == union_slice(slice(1,2), slice(0,4)) # outside
assert slice(0,4) == union_slice(slice(0,1), slice(2,4)) # gap
assert slice(0,4) == union_slice(slice(0,4), slice(1,4)) # overlap

def test_slice():
assert np.all(np.array([0,1,2,3]) == union_array(slice(0,4), slice(1,2), order="ascending"))
assert np.all(np.array([1,0,2,3]) == union_array(slice(1,2), slice(0,4), order="seen"))
assert np.all(np.array([1,0,1,2,3]) == union_array(slice(1,2), slice(0,4), order=None))
assert np.all(np.array([2,3,0]) == union_array(slice(2,4), slice(0,1), order="seen"))
assert np.all(np.array([3,2,1,0]) == union_array(slice(0,4), slice(1,4), order="descending"))

def test_bbox():
bb1 = (slice(0,2), slice(10,12))
bb2 = (slice(3,5), slice(13,15))
u1 = union(bb1, bb2, form="slices")
assert slice(0,5) == u1[0]
assert slice(10,15) == u1[1]
19 changes: 19 additions & 0 deletions test/test-util-peaks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env pytest

from wirecell.util.peaks import *
from wirecell.util.codec import JsonEncoder

def roundtrip_dataclass(DCType):
t = DCType()
d = t.to_dict()
print(d)
j = json.dumps(d, cls=JsonEncoder)
d2 = json.loads(j)
print(d2)
assert d == d2
t2 = DCType.from_dict(d2)


def test_dataclasses():
roundtrip_dataclass(Peak1d)

100 changes: 99 additions & 1 deletion wirecell/test/__main__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import math
import click
import numpy
import functools
import dataclasses

from wirecell.util import ario, plottools

from wirecell.util.plottools import pages
from wirecell.util.cli import context, log
from wirecell.test import ssss
from wirecell.util.peaks import select_activity
from wirecell.util.codec import json_dumps

@context("test")
def cli(ctx):
Expand All @@ -30,6 +36,98 @@ def plot(ctx, name, datafile, output):



def ssss_args(func):
@click.option("--channel-ranges", default="0,800,1600,2560",
help="comma-separated list of channel idents defining ranges")
@click.option("--nsigma", default=3.0,
help="Relative threshold on signal in units of number of sigma of noise width")
@click.option("--nbins", default=50,
help="Number of bins over which to fit relative signal-splat difference")
@click.option("-o",'--output', default='/dev/stdout')
@click.argument("splat")
@click.argument("signal")
@functools.wraps(func)
def wrapper(*args, **kwds):

kwds["splat"] = ssss.load_frame(kwds.pop("splat"))
kwds["signal"] = ssss.load_frame(kwds.pop("signal"))

channel_ranges = kwds.pop("channel_ranges")
if channel_ranges:
channel_ranges = list(map(int,channel_ranges.split(",")))
channel_ranges = [slice(*cr) for cr in zip(channel_ranges[:-1], channel_ranges[1:])]
kwds["channel_ranges"] = channel_ranges
return func(*args, **kwds)
return wrapper


@cli.command("plot-ssss")
@click.option('--title', default='', help='extra title for plots')
@ssss_args
def plot_ssss(channel_ranges, nsigma, nbins, splat, signal, output,
title, **kwds):
'''
Perform the simple splat / sim+signal process comparison test and make plots.
'''

with pages(output) as out:

ssss.plot_frames(splat, signal, channel_ranges, title)
out.savefig()

byplane = list()

# Per channel range plots.
for pln, ch in enumerate(channel_ranges):

spl = select_activity(splat.frame, ch, nsigma)
sig = select_activity(signal.frame, ch, nsigma)

# Find the bbox that bounds the biggest splat object.
biggest = spl.plats.sort_by("sums")[-1]
bbox = spl.plats.bboxes[biggest]

spl_act = spl.thresholded[bbox]
sig_act = sig.thresholded[bbox]
letter = "UVW"[pln]
ssss.plot_plane(spl_act, sig_act, nsigma=nsigma,
title=f'{letter}-plane {title}')
out.savefig()

spl_qch = numpy.sum(spl.activity[bbox], axis=1)
sig_qch = numpy.sum(sig.activity[bbox], axis=1)
byplane.append((spl_qch, sig_qch))


ssss.plot_metrics(byplane, nbins=nbins,
title=f'(splat - signal)/splat {title}')

out.savefig()

@cli.command("ssss-metrics")
@ssss_args
def ssss_metrics(channel_ranges, nsigma, nbins, splat, signal, output, **kwds):
'''
Write the simple splat / sim+signal process comparison metrics to file.
'''

metrics = list()
for pln, ch in enumerate(channel_ranges):
spl = select_activity(splat.frame, ch, nsigma)
sig = select_activity(signal.frame, ch, nsigma)

biggest = spl.plats.sort_by("sums")[-1]
bbox = spl.plats.bboxes[biggest]

spl_qch = numpy.sum(spl.activity[bbox], axis=1)
sig_qch = numpy.sum(sig.activity[bbox], axis=1)

m = ssss.calc_metrics(spl_qch, sig_qch, nbins)
metrics.append(dataclasses.asdict(m))

open(output,"w").write(json_dumps(metrics, indent=4))


def main():
cli(obj=dict())

Expand Down
Loading

0 comments on commit d456b6b

Please sign in to comment.