Skip to content

Commit

Permalink
Added fix for kmeans_plusplus for a very large number of spikes
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobpennington committed Aug 13, 2024
1 parent a3d30d3 commit a6ba59a
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions kilosort/clustering_qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,22 @@ def kmeans_plusplus(Xg, niter = 200, seed = 1, device=torch.device('cuda')):
#Xg = torch.from_numpy(Xd).to(dev)
vtot = (Xg**2).sum(1)

n1 = vtot.shape[0]
if n1 > 2**24:
# Need to subsample v2 because torch.multinomial doesn't allow more than
# 2**24 categories (elements of v2). We're just using this to sample some
# spikes, so it's fine to not use all of them.
n2 = n1 - 2**24 # number of spikes to remove before sampling
remove = np.round(np.linspace(0, n1-1, n2)).astype(int)
idx = np.ones(n1, dtype=bool)
idx[remove] = False
# Also need to map the indices from the subset back to indices for
# the full tensor.
rev_idx = idx.nonzero()[0]
subsample = True
else:
subsample = False

torch.manual_seed(seed)
np.random.seed(seed)

Expand All @@ -176,8 +192,11 @@ def kmeans_plusplus(Xg, niter = 200, seed = 1, device=torch.device('cuda')):

iclust = torch.zeros((NN,), dtype = torch.int, device = device)
for j in range(niter):
v2 = torch.relu(vtot - vexp0)
isamp = torch.multinomial(v2, ntry)
v2 = torch.relu(vtot - vexp0)
if subsample:
isamp = rev_idx[torch.multinomial(v2[idx], ntry)]
else:
isamp = torch.multinomial(v2, ntry)

Xc = Xg[isamp]
vexp = 2 * Xg @ Xc.T - (Xc**2).sum(1)
Expand Down

0 comments on commit a6ba59a

Please sign in to comment.