Skip to content

Commit

Permalink
reverted cluster amp calculation to use whitened templates
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobpennington committed Apr 2, 2024
1 parent c00f68e commit ebdaf5d
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions kilosort/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
whitening_mat
+ 1e-5 * torch.eye(whitening_mat.shape[0]).to(whitening_mat.device)
)
#whitening_mat_inv = np.linalg.inv(whitening_mat + 1e-5 * np.eye(whitening_mat.shape[0]))
np.save((results_dir / 'whitening_mat.npy'), whitening_mat.cpu())
np.save((results_dir / 'whitening_mat_inv.npy'), whitening_mat_inv.cpu())

Expand Down Expand Up @@ -195,19 +194,19 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,

# template properties
similar_templates = CCG.similarity(Wall, ops['wPCA'].contiguous(), nt=ops['nt'])
temp_amplitudes = ((Wall**2).sum(axis=(-2,-1))**0.5).cpu().numpy()
template_amplitudes = ((Wall**2).sum(axis=(-2,-1))**0.5).cpu().numpy()
templates = (Wall.unsqueeze(-1).cpu() * ops['wPCA'].cpu()).sum(axis=-2).numpy()
templates = templates.transpose(0,2,1)
# normalize templates by amplitude
templates = templates / temp_amplitudes[:, np.newaxis, np.newaxis]
# TODO: check if this helps / hurts going between snippets and templates
# scale should not change when switching between
# TODO: post issue on phy github asking where 'amp' is actually coming from,
# other issue answers are old and don't seem to point anywhere relevant
templates = templates / template_amplitudes[:, np.newaxis, np.newaxis]
templates_ind = np.tile(np.arange(Wall.shape[1])[np.newaxis, :], (templates.shape[0],1))
np.save((results_dir / 'similar_templates.npy'), similar_templates)
np.save((results_dir / 'templates.npy'), templates)
np.save((results_dir / 'templates_ind.npy'), templates_ind)
# get unwhitened template amplitudes to use as cluster_Amplitudes
iwrot = whitening_mat_inv.to(Wall.device)
unwhitened = torch.einsum('jk, ikl -> ijl', iwrot, Wall)
template_amplitudes = ((unwhitened**2).sum(axis=(-2,-1))**0.5).cpu().numpy()

# contamination ratio
acg_threshold = ops['settings']['acg_threshold']
Expand Down

0 comments on commit ebdaf5d

Please sign in to comment.