Skip to content

Commit

Permalink
fix: robust channel to electrode mapping and ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
ttngu207 committed Apr 9, 2024
1 parent 7c388b5 commit 9632896
Showing 1 changed file with 31 additions and 22 deletions.
53 changes: 31 additions & 22 deletions element_array_ephys/ephys_organoids.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def make(self, key):
try:
data = intanrhdreader.load_file(file)
except OSError:
raise OSError(f"OS error occured when loading file {file.name}")
raise OSError(f"OS error occurred when loading file {file.name}")

if not header:
header = data.pop("header")
Expand Down Expand Up @@ -736,14 +736,14 @@ def make(self, key):
sorting_dir / "si_sorting.pkl"
)

unit_peak_channel_map: dict[int, int] = si.get_template_extremum_channel(
we, outputs="index"
) # {unit: peak_channel_index}
unit_peak_channel: dict[int, int] = si.get_template_extremum_channel(
we, outputs="id"
) # {unit: peak_channel_id}

spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
# {unit: spike_count}

spikes = si_sorting.to_spike_vector(extremum_channel_inds=unit_peak_channel_map)
spikes = si_sorting.to_spike_vector()

# Get electrode & channel info
probe_info = (probe.Probe * EphysSessionProbe & key).fetch1()
Expand All @@ -762,8 +762,13 @@ def make(self, key):
)
electrode_query &= f'electrode IN {tuple(probe_info["used_electrodes"])}'

channel_info = electrode_query.fetch(as_dict=True, order_by="electrode")
channel_info: dict[int, dict] = {ch_idx: ch for ch_idx, ch in enumerate(channel_info)}
channel2electrode_map = electrode_query.fetch(as_dict=True)
channel2electrode_map: dict[int, dict] = {
chn.pop("channel_idx"): chn for chn in channel2electrode_map
} # e.g., {0: {'organoid_id': 'O09',

# reorder channel2electrode_map according to recording channel ids
channel2electrode_map = {chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids}

# Get unit id to quality label mapping
try:
Expand All @@ -783,15 +788,15 @@ def make(self, key):
# Get electrode where peak unit activity is recorded
peak_electrode_ind = np.array(
[
channel_info[unit_peak_channel_map[unit_id]]["electrode"]
channel2electrode_map[unit_peak_channel[unit_id]]["electrode"]
for unit_id in si_sorting.unit_ids
]
)

# Get channel depth
channel_depth_ind = np.array(
[
channel_info[unit_peak_channel_map[unit_id]]["y_coord"]
channel2electrode_map[unit_peak_channel[unit_id]]["y_coord"]
for unit_id in si_sorting.unit_ids
]
)
Expand All @@ -816,7 +821,7 @@ def make(self, key):
units.append(
{
**key,
**channel_info[unit_peak_channel_map[unit_id]],
**channel2electrode_map[unit_peak_channel[unit_id]],
"unit": unit_id,
"cluster_quality_label": cluster_quality_label_map.get(
unit_id, "n.a."
Expand Down Expand Up @@ -908,9 +913,9 @@ def make(self, key):
)
electrode_query &= f'electrode IN {tuple(probe_info["used_electrodes"])}'

channel_info = electrode_query.fetch(as_dict=True, order_by="channel_idx")
channel_info: dict[int, dict] = {
ch.pop("channel_idx"): key | ch for ch in channel_info
channel2electrode_map = electrode_query.fetch(as_dict=True)
channel2electrode_map: dict[int, dict] = {
chn.pop("channel_idx"): chn for chn in channel2electrode_map
} # e.g., {0: {'organoid_id': 'O09',

waveform_dir = output_dir / sorter_name / "waveform"
Expand All @@ -921,8 +926,11 @@ def make(self, key):
unit_id_to_peak_channel_map: dict[int, np.ndarray] = (
si.ChannelSparsity.from_best_channels(
we, 1, peak_sign="neg"
).unit_id_to_channel_indices
) # {unit: peak_channel_index}
).unit_id_to_channel_ids
) # {unit: peak_channel_id}

# reorder channel2electrode_map according to recording channel ids
channel2electrode_map = {chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids}

# Get mean waveform for each unit from all channels
mean_waveforms = we.get_all_templates(
Expand All @@ -933,23 +941,24 @@ def make(self, key):
unit_electrode_waveforms = []

for unit in (CuratedClustering.Unit & key).fetch("KEY", order_by="unit"):
unit_waveforms = we.get_template(
unit_id=unit["unit"], mode="average", force_dense=True
) # (sample x channel)
peak_chn_idx = list(we.channel_ids).index(unit_id_to_peak_channel_map[unit["unit"]][0])
unit_peak_waveform.append(
{
**unit,
"peak_electrode_waveform": we.get_template(
unit_id=unit["unit"], mode="average", force_dense=True
)[:, unit_id_to_peak_channel_map[unit["unit"]][0]],
"peak_electrode_waveform": unit_waveforms[:, peak_chn_idx],
}
)

unit_electrode_waveforms.extend(
[
{
**unit,
**channel_info[c],
"waveform_mean": mean_waveforms[unit["unit"] - 1, :, c],
**channel2electrode_map[c],
"waveform_mean": mean_waveforms[unit["unit"] - 1, :, c_idx],
}
for c in channel_info
for c_idx, c in enumerate(channel2electrode_map)
]
)

Expand Down

0 comments on commit 9632896

Please sign in to comment.