Inferring stochastic low-rank RNNs with variational sequential Monte Carlo / particle filtering

Code accompanying the paper "Inferring stochastic low-rank recurrent neural networks from neural data".

Matthijs Pals, A Erdem Sağtekin, Felix Pei, Manuel Gloeckler and Jakob H Macke, 2024

Quick start:

cd smc_rnns
conda env create -f smc_rnn_env.yml
conda activate smc_rnn_env
python train_scripts/eeg/run_eeg.py

Finding fixed points

We include code for finding fixed points (and cycles) in piecewise linear low-rank RNNs. The fixed_points folder contains scripts for both our 'semi'-analytic method and a modified version of SCYFI (Eisenmann et al. 2023, GNU General Public Licence), in which the search space can additionally be constrained to the subregions that can contain fixed points.

Example usage in:

generate_figures/Fig_8_find_fixed_points.ipynb
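The core idea behind fixed-point search in piecewise linear RNNs can be illustrated with a naive brute-force sketch. It assumes latent dynamics of the form z_{t+1} = a z_t + W relu(M z_t + b), where W (R x N), M (N x R), b and the leak a are hypothetical names chosen for illustration, not necessarily the parameterization used in the fixed_points folder. Within one ReLU activation pattern the map is linear, so each candidate fixed point solves a linear system, and it is kept only if it actually lies in the region whose pattern was assumed. Enumerating all 2^N patterns is only viable for small N; this is neither the repository's semi-analytic method nor SCYFI.

```python
import itertools

import numpy as np


def find_fixed_points(W, M, b, a):
    """Brute-force fixed points of z -> a*z + W @ relu(M @ z + b).

    Hypothetical parameterization: W is (R, N), M is (N, R), b is (N,),
    a is a scalar leak with 0 <= a < 1. Cost grows as 2**N.
    """
    R, N = W.shape
    fixed_points = []
    for pattern in itertools.product([0.0, 1.0], repeat=N):
        D = np.diag(pattern)
        # Inside this region relu(M z + b) = D (M z + b), so the fixed-point
        # equation z = a z + W D (M z + b) becomes linear in z.
        A = (1.0 - a) * np.eye(R) - W @ D @ M
        rhs = W @ D @ b
        try:
            z = np.linalg.solve(A, rhs)
        except np.linalg.LinAlgError:
            continue  # singular system: no isolated fixed point in this region
        # Keep z only if its own activation pattern matches the assumed one.
        pre = M @ z + b
        if np.all((pre > 0) == (np.array(pattern) > 0)):
            fixed_points.append(z)
    return fixed_points
```

For example, with a single unit, W = [[2.0]], M = [[1.0]], b = [-1.0] and a = 0, the map z -> 2 relu(z - 1) has fixed points at z = 0 and z = 2, and the function returns both.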

Student-teacher setups

Train the teacher and student networks using:

train_scripts/student_teacher/train_student_teacher_continuous.ipynb
train_scripts/student_teacher/train_student_teacher_poisson.ipynb
train_scripts/student_teacher/train_student_teacher_conditioning.ipynb
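As a rough picture of what a teacher provides, the sketch below simulates a stochastic low-rank RNN and draws Poisson spike counts from it (compare the Poisson notebook). It assumes latent dynamics z_{t+1} = a z_t + W relu(M z_t + b) + sigma * eps_t with standard Gaussian eps_t, and observations with rates exp(C z_t + d). The function simulate_teacher and all parameter names are hypothetical, and the notebooks may use a different parameterization, noise model or readout.

```python
import numpy as np


def simulate_teacher(W, M, b, a, sigma, C, d, T, rng):
    """Simulate a hypothetical stochastic low-rank RNN teacher.

    W: (R, N), M: (N, R), b: (N,), a: scalar leak, sigma: noise std,
    C: (K, R) readout to K observed neurons, d: (K,) log-baseline rates,
    T: number of time steps, rng: numpy Generator.
    Returns latents (T, R) and Poisson spike counts (T, K).
    """
    R = W.shape[0]
    z = np.zeros(R)
    latents = np.empty((T, R))
    for t in range(T):
        # Low-rank ReLU update plus additive Gaussian process noise.
        z = a * z + W @ np.maximum(M @ z + b, 0.0) + sigma * rng.standard_normal(R)
        latents[t] = z
    # Log-linear readout keeps firing rates positive.
    rates = np.exp(latents @ C.T + d)
    spikes = rng.poisson(rates)
    return latents, spikes


# Small, stable example: leak 0.5 and weak recurrent weights.
rng = np.random.default_rng(0)
R, N, K, T = 2, 20, 30, 100
W = rng.normal(scale=0.1 / np.sqrt(N), size=(R, N))
M = rng.normal(size=(N, R))
b = rng.normal(size=N)
C = rng.normal(scale=0.5, size=(K, R))
d = np.full(K, -1.0)
latents, spikes = simulate_teacher(W, M, b, a=0.5, sigma=0.1, C=C, d=d, T=T, rng=rng)
```

A student model would then be fit to spikes alone, and its inferred latents and dynamics compared with those of the teacher.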

Generate figure:

generate_figures/Fig_3_create_figure.ipynb

EEG

We used openly accessible electroencephalogram (EEG) data from Schalk et al. 2004, available from PhysioNet (Moody et al. 2000; ODC-BY licence). This repo includes preprocessed data from session S001R01.

Train models using:

python train_scripts/eeg/run_eeg.py

Generate figure and table:

generate_figures/Fig_4_plot_EEG.ipynb
generate_figures/Table_1_calc_stats.ipynb

Rodent hippocampal datasets

To run these experiments, you first need to obtain the data from CRCNS: hc2 (Mizuseki et al. 2009) and hc11 (Grosmark et al. 2016).

Data can be preprocessed using the notebooks:

train_scripts/hpc/hpc2_spike_preprocessing.ipynb
train_scripts/hpc/hpc1_spike_preprocessing.ipynb
train_scripts/hpc/hpc11_lfp_preprocessing.ipynb

Train models using:

wandb sweep train_scripts/hpc/sweep.yaml
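wandb agent [name]/smc_rnns-train_scripts_hpc/[sweep-id]

Here [name] is your Weights & Biases entity (username or team) and [sweep-id] is the ID printed by wandb sweep.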

Generate figures:

generate_figures/Fig_5_hpc2.ipynb
generate_figures/Fig_6_S5_hpc11.ipynb

Macaque Reach

To run these experiments, you first need to obtain the MC_Maze dataset from the Neural Latents Benchmark (Pei et al. 2021; CC-BY-4.0 licence).

Process the dataset to extract either the spikes plus context input (first line) or only the spikes (second line, used for the NLB evaluation):

python train_scripts/macaque_reach/make_tensors_conditioning.py --binsize 20
python train_scripts/macaque_reach/make_tensors_nlb.py --binsize 20
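The --binsize 20 flag presumably selects 20 ms bins, matching the dataset name mc_maze_20ms_val_nlb used below. Binning turns spike times into per-bin spike counts. A minimal NumPy sketch follows; bin_spikes and its arguments are hypothetical and not the interface of the make_tensors scripts, which may handle trial alignment and partial bins differently.

```python
import numpy as np


def bin_spikes(spike_times_ms, t_start_ms, t_stop_ms, binsize_ms=20.0):
    """Count spikes per neuron in consecutive bins of width binsize_ms.

    spike_times_ms: list with one array of spike times (ms) per neuron.
    A trailing partial bin is dropped. Returns an int array (n_bins, n_neurons).
    """
    n_bins = int((t_stop_ms - t_start_ms) // binsize_ms)
    edges = t_start_ms + binsize_ms * np.arange(n_bins + 1)
    # np.histogram bins are half-open [left, right), except the last one, which is closed.
    counts = [np.histogram(np.asarray(st, dtype=float), bins=edges)[0] for st in spike_times_ms]
    return np.stack(counts, axis=1)
```

For example, two neurons with spikes at 1, 5 and 25 ms and at 39 ms, binned over 0 to 40 ms, give counts [[2, 0], [1, 1]] (rows are bins, columns are neurons).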

Train models using:

python train_scripts/macaque_reach/train_single_conditioning.py --run_name reach_conditioning
python train_scripts/macaque_reach/train_single_nlb.py --run_name reach_nlb 'dataset=mc_maze_20ms_val_nlb'
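The primary NLB metric is co-smoothing bits per spike: the Poisson log-likelihood of the predicted rates relative to a null model that predicts each neuron's mean rate, normalized by the number of spikes and converted to bits. A minimal NumPy sketch of that formula follows. The function name is hypothetical, it is not the nlb_tools implementation, and the official score is computed on the held-out neurons of the evaluation split rather than on arbitrary arrays.

```python
import numpy as np


def bits_per_spike(rates, spikes):
    """Poisson log-likelihood gain over a mean-rate null model, in bits per spike.

    rates, spikes: arrays of identical shape (..., n_neurons),
    e.g. (trials, time_bins, neurons); rates are expected counts per bin.
    """
    rates = np.asarray(rates, dtype=float)
    spikes = np.asarray(spikes, dtype=float)
    # Null model: each neuron's mean count per bin across all trials and bins.
    null_rates = spikes.reshape(-1, spikes.shape[-1]).mean(axis=0)

    def poisson_ll(lam):
        # log(y!) is omitted: it is identical for both models and cancels.
        # Clipping avoids log(0) for neurons that never fire.
        return np.sum(spikes * np.log(np.clip(lam, 1e-9, None)) - lam)

    ll_model = poisson_ll(rates)
    ll_null = poisson_ll(np.broadcast_to(null_rates, spikes.shape))
    return (ll_model - ll_null) / (spikes.sum() * np.log(2))
```

A model that simply predicts each neuron's mean rate scores 0; positive values mean the predicted rates explain the spikes better than that baseline.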

Generate figures:

generate_figures/Fig_7_Maze.ipynb

Feel free to reach out with any questions!