diff --git a/visual_behavior_glm/GLM_perturbation_tools.py b/visual_behavior_glm/GLM_perturbation_tools.py index b045c4cc..f57c44da 100644 --- a/visual_behavior_glm/GLM_perturbation_tools.py +++ b/visual_behavior_glm/GLM_perturbation_tools.py @@ -10,43 +10,25 @@ def analysis(weights_beh, run_params, kernel,experience_level='Familiar',savefig=False): - if kernel in ['hits','misses']: - lims=[[-.0005,.002],[-.01,.005]] - if kernel in ['omissions','all-images']: - lims=[[-.0005,.001],[-.015,.03]] - # Plot 1D summaries out1 = gst.strategy_kernel_comparison(weights_beh.query('visual_strategy_session'), run_params, kernel,session_filter=[experience_level]) out2 =gst.strategy_kernel_comparison(weights_beh.query('not visual_strategy_session'), run_params, kernel,session_filter=[experience_level]) - out3 = gst.strategy_kernel_comparison(weights_beh.query('visual_strategy_session'), - run_params, 'all-images',session_filter=[experience_level]) - out4 =gst.strategy_kernel_comparison(weights_beh.query('not visual_strategy_session'), - run_params, 'all-images',session_filter=[experience_level]) - # Unify ylimits and add time markers ylim1 = out1[2].get_ylim() ylim2 = out2[2].get_ylim() - ylim3 = out3[2].get_ylim() - ylim4 = out4[2].get_ylim() - ylim = [np.min([ylim1[0],ylim2[0],ylim3[0],ylim4[0]]), \ - np.max([ylim1[1],ylim2[1],ylim3[1],ylim4[1]])] + ylim = [np.min([ylim1[0],ylim2[0]]), \ + np.max([ylim1[1],ylim2[1]])] out1[2].set_ylim(ylim) out2[2].set_ylim(ylim) - out3[2].set_ylim(ylim) - out4[2].set_ylim(ylim) out1[2].set_ylabel(kernel+' kernel\n(Ca$^{2+}$ events)',fontsize=16) out2[2].set_ylabel(kernel+' kernel\n(Ca$^{2+}$ events)',fontsize=16) - out3[2].set_ylabel('image kernel\n(Ca$^{2+}$ events)',fontsize=16) - out4[2].set_ylabel('image kernel\n(Ca$^{2+}$ events)',fontsize=16) - out1[2].set_title('Visual strategy',fontsize=16) out2[2].set_title('Timing strategy',fontsize=16) - out3[2].set_title('Visual strategy',fontsize=16) - out4[2].set_title('Timing strategy',fontsize=16) + if kernel =='omissions': out1[2].plot(0,ylim[0],'co',zorder=10,clip_on=False) out2[2].plot(0,ylim[0],'co',zorder=10,clip_on=False) @@ -54,8 +36,8 @@ def analysis(weights_beh, run_params, kernel,experience_level='Familiar',savefig out1[2].plot(0,ylim[0],'ro',zorder=10,clip_on=False) out2[2].plot(0,ylim[0],'ro',zorder=10,clip_on=False) elif kernel == 'misses': - out1[2].plot(0,ylim[0],'rx',zorder=10,clip_on=False) - out2[2].plot(0,ylim[0],'rx',zorder=10,clip_on=False) + out1[2].plot(0,ylim[0],'ro',zorder=10,clip_on=False) + out2[2].plot(0,ylim[0],'ro',zorder=10,clip_on=False) else: out1[2].plot(0,ylim[0],'ko',zorder=10,clip_on=False) out2[2].plot(0,ylim[0],'ko',zorder=10,clip_on=False) @@ -65,10 +47,6 @@ def analysis(weights_beh, run_params, kernel,experience_level='Familiar',savefig out2[2].plot(.75,ylim[0],'ko',zorder=10,clip_on=False) out2[2].plot(1.5,ylim[0],'o',color='gray',zorder=10,clip_on=False) - # Make 2D plot - ax = plot_perturbation(weights_beh, run_params, kernel,experience_level=experience_level, - savefig=savefig,lims=lims) - if savefig: filepath = run_params['figure_dir']+\ '/strategy/'+kernel+'_visual_comparison_{}.svg'.format(experience_level) @@ -80,14 +58,6 @@ def analysis(weights_beh, run_params, kernel,experience_level='Familiar',savefig out2[1].savefig(filepath) filepath = run_params['figure_dir']+\ '/strategy/images_visual_comparison_{}.svg'.format(experience_level) - print('Figure saved to: '+filepath) - out3[1].savefig(filepath) - filepath = run_params['figure_dir']+\ - '/strategy/images_timing_comparison_{}.svg'.format(experience_level) - print('Figure saved to: '+filepath) - out4[1].savefig(filepath) - - return ax def get_kernel_averages(weights_df, run_params, kernel, drop_threshold=0, session_filter=['Familiar','Novel 1','Novel >1'],equipment_filter="all", diff --git a/visual_behavior_glm/strategy_paper_figure_script.py b/visual_behavior_glm/strategy_paper_figure_script.py index ac92e0d3..15cf5bbf 100644 --- a/visual_behavior_glm/strategy_paper_figure_script.py +++ b/visual_behavior_glm/strategy_paper_figure_script.py @@ -140,7 +140,14 @@ GLM_VERSION = '24_events_all_L2_optimize_by_session' run_params, results, results_pivoted, weights_df = gfd.get_analysis_dfs(GLM_VERSION) weights_beh = gst.add_behavior_session_metrics(weights_df, summary_df) + +# Plot kernels over time, compare cell types +gpt.analysis(weights_beh, run_params, 'all-images') gpt.analysis(weights_beh, run_params, 'omissions') +gpt.analysis(weights_beh, run_params, 'hits') +gpt.analysis(weights_beh, run_params, 'misses') + +# Plot kernels over time, compare strategies gst.kernels_by_cre(weights_beh, run_params)