diff --git a/kilosort/clustering_qr.py b/kilosort/clustering_qr.py index 374d9098..205e09eb 100644 --- a/kilosort/clustering_qr.py +++ b/kilosort/clustering_qr.py @@ -166,6 +166,22 @@ def kmeans_plusplus(Xg, niter = 200, seed = 1, device=torch.device('cuda')): #Xg = torch.from_numpy(Xd).to(dev) vtot = (Xg**2).sum(1) + n1 = vtot.shape[0] + if n1 > 2**24: + # Need to subsample v2, torch.multinomial doesn't allow more than 2**24 + # elements. We're just using this to sample some spikes, so it's fine to + # not use all of them. + n2 = n1 - 2**24 # number of spikes to remove before sampling + remove = np.round(np.linspace(0, n1-1, n2)).astype(int) + idx = np.ones(n1, dtype=bool) + idx[remove] = False + # Also need to map the indices from the subset back to indices for + # the full tensor. + rev_idx = idx.nonzero()[0] + subsample = True + else: + subsample = False + torch.manual_seed(seed) np.random.seed(seed) @@ -176,8 +192,11 @@ def kmeans_plusplus(Xg, niter = 200, seed = 1, device=torch.device('cuda')): iclust = torch.zeros((NN,), dtype = torch.int, device = device) for j in range(niter): - v2 = torch.relu(vtot - vexp0) - isamp = torch.multinomial(v2, ntry) + v2 = torch.relu(vtot - vexp0) + if subsample: + isamp = rev_idx[torch.multinomial(v2[idx], ntry)] + else: + isamp = torch.multinomial(v2, ntry) Xc = Xg[isamp] vexp = 2 * Xg @ Xc.T - (Xc**2).sum(1)