Skip to content

Commit

Permalink
Merge pull request #137 from DUNE/feature_flash_bug_fix
Browse files Browse the repository at this point in the history
Minor bug fix for flash making
  • Loading branch information
mjkramer authored Aug 31, 2024
2 parents f9307b4 + 73a7cd9 commit 34d2dd3
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 44 deletions.
1 change: 0 additions & 1 deletion env-nompi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,3 @@ dependencies:
- h5flow>=0.1.0
- adc64format>=0.1.1
- git+https://github.com/cuddandr/pylandau.git
- dbscan1d
1 change: 0 additions & 1 deletion env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,3 @@ dependencies:
- h5flow>=0.1.0
- adc64format>=0.1.1
- git+https://github.com/cuddandr/pylandau.git
- dbscan1d
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,5 @@
'scikit-learn',
'h5flow>=0.2.0',
'pylandau @ git+https://github.com/cuddandr/pylandau.git#egg=pylandau',
'dbscan1d',
]
)
83 changes: 42 additions & 41 deletions src/proto_nd_flow/reco/light/flash_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from h5flow.core import H5FlowStage, resources
from h5flow.data import dereference

from dbscan1d.core import DBSCAN1D
import sklearn.cluster as cluster

import proto_nd_flow.util.units as units

Expand Down Expand Up @@ -74,7 +74,7 @@ def init(self, source_name):
self.sum_hits_dset = self.data_manager.get_dset(self.sum_hits_dset_name)
self.sipm_hits_dset = self.data_manager.get_dset(self.sipm_hits_dset_name)

self.dbs = DBSCAN1D(eps=self.eps, min_samples=self.min_samples)
self.dbs = cluster.DBSCAN(eps=self.eps, min_samples=self.min_samples)

# get waveform shape information
self.nadc = cwvfm_dset.dtype['samples'].shape[0]
Expand Down Expand Up @@ -117,6 +117,11 @@ def get_tpc_channels(self,itpc):

return(return_arr)

def get_extrema(self, input_array):
return np.column_stack((
input_array.min(),
input_array.max()))

def run(self, source_name, source_slice, cache):
super(FlashFinder, self).run(source_name, source_slice, cache)
events = cache[source_name]
Expand Down Expand Up @@ -153,9 +158,12 @@ def run(self, source_name, source_slice, cache):
tpc_hits = sum_hits[i,tpc_mask]
tpc_hits_idx = sum_hits_idx[i,tpc_mask]
if np.any(tpc_mask):
labels = self.dbs.fit_predict(tpc_hits["sample_idx"])
labels = self.dbs.fit_predict(tpc_hits["sample_idx"].reshape(-1,1)) # single feature

n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
# skip if there's no light hits
if n_clusters == 0:
continue
n_noise = np.count_nonzero(labels == -1)
tpc_flashes = np.empty((n_clusters+n_noise),dtype=self.flash_dtype)
ev_ref = np.empty((n_clusters+n_noise),dtype='u4') #ev ID for each flash
Expand All @@ -176,9 +184,9 @@ def run(self, source_name, source_slice, cache):
tpc_flashes[cl]["n_sum_hits"] = np.count_nonzero(labels==cl)

#Timing information
tpc_flashes[cl]["sample_range"] = get_extrema(tpc_hits[labels==cl]["sample_idx"])
tpc_flashes[cl]['hit_time_range'] = get_extrema(tpc_hits[labels==cl]["busy_ns"])
tpc_flashes[cl]['rising_spline_range'] = get_extrema(tpc_hits[labels==cl]["busy_ns"]
tpc_flashes[cl]["sample_range"] = self.get_extrema(tpc_hits[labels==cl]["sample_idx"])
tpc_flashes[cl]['hit_time_range'] = self.get_extrema(tpc_hits[labels==cl]["busy_ns"])
tpc_flashes[cl]['rising_spline_range'] = self.get_extrema(tpc_hits[labels==cl]["busy_ns"]
+tpc_hits[labels==cl]["rising_spline"])

#Hit intensity information
Expand All @@ -197,43 +205,36 @@ def run(self, source_name, source_slice, cache):
cwvfms[i,ch_idx[...,0], ch_idx[...,1], flash_slice],
axis=-1)

#Handle Noise events
# NOT NEEDED IF min_sample==1

# when it gets here, there must be at least one non-noisy clusters
ev_ref[:] = np.r_[source_slice][i]
sum_ref[:,0] = tpc_hits_idx
sum_ref[:,1] = labels
sum_ref[:,0] = tpc_hits_idx[labels>=0]
sum_ref[:,1] = labels[labels>=0]

flash_list.append(tpc_flashes)
ev_ref_list.append(ev_ref)
sum_ref_list.append(sum_ref)

flash_data = np.concatenate(flash_list)

# save data
flash_slice = self.data_manager.reserve_data(
self.flash_dset_name, len(flash_data))
if len(flash_data):
flash_data['id'] = np.r_[flash_slice]
self.data_manager.write_data(self.flash_dset_name, flash_slice, flash_data)

# save references
ev_ref_data = np.concatenate(ev_ref_list)
ref = np.array([(ev_idx,flash_data[flash_slice_idx]['id']) for flash_slice_idx, ev_idx in enumerate(ev_ref_data)])
self.data_manager.write_ref(source_name, self.flash_dset_name, ref)

flash_list_struc = [arr.shape[0] for arr in flash_list]
flash_list = np.split(flash_data, np.cumsum(flash_list_struc)[:-1])
for j, tpc_flash_slice in enumerate(flash_list):
sum_ref_list[j] = np.c_[sum_ref_list[j][:,0], tpc_flash_slice["id"][sum_ref_list[j][:,1]]]
ref_sum = np.concatenate(sum_ref_list)
self.data_manager.write_ref(
self.sum_hits_dset_name, self.flash_dset_name, ref_sum)
#self.data_manager.write_ref(
# self.sum_hits_dset_name, self.flash_dset_name, ref_sipm)

@staticmethod
def get_extrema(input_array):
return np.column_stack((
input_array.min(),
input_array.max()))

if len(flash_list) and len(ev_ref_list) and len(sum_ref_list):
flash_data = np.concatenate(flash_list)

# save data
flash_slice = self.data_manager.reserve_data(
self.flash_dset_name, len(flash_data))
if len(flash_data):
flash_data['id'] = np.r_[flash_slice]
self.data_manager.write_data(self.flash_dset_name, flash_slice, flash_data)

# save references
ev_ref_data = np.concatenate(ev_ref_list)
ref = np.array([(ev_idx,flash_data[flash_slice_idx]['id']) for flash_slice_idx, ev_idx in enumerate(ev_ref_data)])
self.data_manager.write_ref(source_name, self.flash_dset_name, ref)

flash_list_struc = [arr.shape[0] for arr in flash_list]
flash_list = np.split(flash_data, np.cumsum(flash_list_struc)[:-1])
for j, tpc_flash_slice in enumerate(flash_list):
sum_ref_list[j] = np.c_[sum_ref_list[j][:,0], tpc_flash_slice["id"][sum_ref_list[j][:,1]]]
ref_sum = np.concatenate(sum_ref_list)
self.data_manager.write_ref(
self.sum_hits_dset_name, self.flash_dset_name, ref_sum)
#self.data_manager.write_ref(
# self.sum_hits_dset_name, self.flash_dset_name, ref_sipm)

0 comments on commit 34d2dd3

Please sign in to comment.