diff --git a/visual_behavior_glm/GLM_perturbation_tools.py b/visual_behavior_glm/GLM_perturbation_tools.py index b045c4cc..66c63c9b 100644 --- a/visual_behavior_glm/GLM_perturbation_tools.py +++ b/visual_behavior_glm/GLM_perturbation_tools.py @@ -10,43 +10,25 @@ def analysis(weights_beh, run_params, kernel,experience_level='Familiar',savefig=False): - if kernel in ['hits','misses']: - lims=[[-.0005,.002],[-.01,.005]] - if kernel in ['omissions','all-images']: - lims=[[-.0005,.001],[-.015,.03]] - # Plot 1D summaries out1 = gst.strategy_kernel_comparison(weights_beh.query('visual_strategy_session'), run_params, kernel,session_filter=[experience_level]) out2 =gst.strategy_kernel_comparison(weights_beh.query('not visual_strategy_session'), run_params, kernel,session_filter=[experience_level]) - out3 = gst.strategy_kernel_comparison(weights_beh.query('visual_strategy_session'), - run_params, 'all-images',session_filter=[experience_level]) - out4 =gst.strategy_kernel_comparison(weights_beh.query('not visual_strategy_session'), - run_params, 'all-images',session_filter=[experience_level]) - # Unify ylimits and add time markers ylim1 = out1[2].get_ylim() ylim2 = out2[2].get_ylim() - ylim3 = out3[2].get_ylim() - ylim4 = out4[2].get_ylim() - ylim = [np.min([ylim1[0],ylim2[0],ylim3[0],ylim4[0]]), \ - np.max([ylim1[1],ylim2[1],ylim3[1],ylim4[1]])] + ylim = [np.min([ylim1[0],ylim2[0]]), \ + np.max([ylim1[1],ylim2[1]])] out1[2].set_ylim(ylim) out2[2].set_ylim(ylim) - out3[2].set_ylim(ylim) - out4[2].set_ylim(ylim) out1[2].set_ylabel(kernel+' kernel\n(Ca$^{2+}$ events)',fontsize=16) out2[2].set_ylabel(kernel+' kernel\n(Ca$^{2+}$ events)',fontsize=16) - out3[2].set_ylabel('image kernel\n(Ca$^{2+}$ events)',fontsize=16) - out4[2].set_ylabel('image kernel\n(Ca$^{2+}$ events)',fontsize=16) - out1[2].set_title('Visual strategy',fontsize=16) out2[2].set_title('Timing strategy',fontsize=16) - out3[2].set_title('Visual strategy',fontsize=16) - out4[2].set_title('Timing strategy',fontsize=16) + if kernel =='omissions': out1[2].plot(0,ylim[0],'co',zorder=10,clip_on=False) out2[2].plot(0,ylim[0],'co',zorder=10,clip_on=False) @@ -54,8 +36,8 @@ def analysis(weights_beh, run_params, kernel,experience_level='Familiar',savefig out1[2].plot(0,ylim[0],'ro',zorder=10,clip_on=False) out2[2].plot(0,ylim[0],'ro',zorder=10,clip_on=False) elif kernel == 'misses': - out1[2].plot(0,ylim[0],'rx',zorder=10,clip_on=False) - out2[2].plot(0,ylim[0],'rx',zorder=10,clip_on=False) + out1[2].plot(0,ylim[0],'ro',zorder=10,clip_on=False) + out2[2].plot(0,ylim[0],'ro',zorder=10,clip_on=False) else: out1[2].plot(0,ylim[0],'ko',zorder=10,clip_on=False) out2[2].plot(0,ylim[0],'ko',zorder=10,clip_on=False) @@ -65,10 +47,6 @@ def analysis(weights_beh, run_params, kernel,experience_level='Familiar',savefig out2[2].plot(.75,ylim[0],'ko',zorder=10,clip_on=False) out2[2].plot(1.5,ylim[0],'o',color='gray',zorder=10,clip_on=False) - # Make 2D plot - ax = plot_perturbation(weights_beh, run_params, kernel,experience_level=experience_level, - savefig=savefig,lims=lims) - if savefig: filepath = run_params['figure_dir']+\ '/strategy/'+kernel+'_visual_comparison_{}.svg'.format(experience_level) @@ -80,14 +58,6 @@ def analysis(weights_beh, run_params, kernel,experience_level='Familiar',savefig out2[1].savefig(filepath) filepath = run_params['figure_dir']+\ '/strategy/images_visual_comparison_{}.svg'.format(experience_level) - print('Figure saved to: '+filepath) - out3[1].savefig(filepath) - filepath = run_params['figure_dir']+\ - '/strategy/images_timing_comparison_{}.svg'.format(experience_level) - print('Figure saved to: '+filepath) - out4[1].savefig(filepath) - - return ax def get_kernel_averages(weights_df, run_params, kernel, drop_threshold=0, session_filter=['Familiar','Novel 1','Novel >1'],equipment_filter="all", @@ -307,20 +277,34 @@ def plot_iterative_ch(ax,df,color,show_steps=False): else: ax.fill(points[hull.vertices,0],points[hull.vertices,1],color=color) -def get_points(df): +def get_points(df,ellipse=True): ''' Return a 2D array of error points for two subsequent points in df ''' x=[] y=[] for index, row in df.iterrows(): - this_x= [row.x1,row.x,row.x2,row.x] - this_y= [row.y,row.y2,row.y,row.y1] + if ellipse: + r1 = (row.x2-row.x1)/2 + r2 = (row.y2-row.y1)/2 + n=20 + this_x = [ + r1*np.cos(theta)+row.x + for theta in (np.pi*2 * i/n for i in range(n)) + ] + this_y = [ + r2*np.sin(theta)+row.y + for theta in (np.pi*2 * i/n for i in range(n)) + ] + else: + this_x= [row.x1,row.x,row.x2,row.x] + this_y= [row.y,row.y2,row.y,row.y1] x=x+this_x y=y+this_y return np.array([x,y]).T -def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True): +def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True, + x='Slc17a7-IRES2-Cre',y='y'): ''' A demonstration that shows the iterative convex hull solution ''' @@ -336,19 +320,26 @@ def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True): pi3 = len(time) colors = gvt.project_colors() - fig, ax = plt.subplots(figsize=(4,3.5)) - df = get_error(Fvisual) + fig, ax = plt.subplots(figsize=(5,3.5)) + df = get_error(Fvisual,x=x,y=y) df = df.loc[0:pi3] plot_iterative_ch(ax,df,'lightgray',show_steps=show_steps) if show_steps: - ax.errorbar(Fvisual['Slc17a7-IRES2-Cre'][0:pi3],Fvisual['y'][0:pi3], - xerr=Fvisual['Slc17a7-IRES2-Cre_sem'][0:pi3], - yerr=Fvisual['y_sem'][0:pi3],color='gray',alpha=.5) - ax.plot(Fvisual['Slc17a7-IRES2-Cre'][0:pi3], Fvisual['y'][0:pi3], + ax.errorbar(Fvisual[x][0:pi3],Fvisual[y][0:pi3], + xerr=Fvisual[x+'_sem'][0:pi3], + yerr=Fvisual[y+'_sem'][0:pi3],color='gray',alpha=.5) + ax.plot(Fvisual[x][0:pi3], Fvisual[y][0:pi3], color=colors['visual'],label='Visual',linewidth=3) - - ax.set_ylabel('Vip - Sst',fontsize=16) - ax.set_xlabel('Exc',fontsize=16) + + mapper = { + 'Slc17a7-IRES2-Cre':'Exc', + 'Sst-IRES-Cre':'Sst', + 'Vip-IRES-Cre':'Vip', + 'y':'Vip - Sst', + } + + ax.set_ylabel(mapper[y],fontsize=16) + ax.set_xlabel(mapper[x],fontsize=16) ax.spines['right'].set_visible(False) ax.spines['top'].set_visible(False) ax.xaxis.set_tick_params(labelsize=12) @@ -357,6 +348,7 @@ def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True): ax.legend() ax.axhline(0,color='k',linestyle='--',alpha=.25) ax.axvline(0,color='k',linestyle='--',alpha=.25) + plt.tight_layout() def get_perturbation(weights_df, run_params, kernel,experience_level="Familiar"): visual = get_kernel_averages(weights_df.query('visual_strategy_session'), @@ -381,7 +373,17 @@ def plot_multiple(weights_df, run_params,savefig=False): return def plot_perturbation(weights_df, run_params, kernel,experience_level="Familiar", - savefig=False,lims = None,show_steps=False,ax=None,col1=False,row1=False,multi=False): + savefig=False,lims = None,show_steps=False,ax=None,col1=False,row1=False,multi=False, + x = 'Slc17a7-IRES2-Cre',y='Vip-IRES-Cre'): + + limit_list = { + 'Slc17a7-IRES2-Cre':[-.0005,0.002], + 'Sst-IRES-Cre':[-0.006,0.0125], + 'Vip-IRES-Cre':[-0.005,0.025], + 'y':[0,0.01], + } + lims=[limit_list[x], limit_list[y]] + visual, timing = get_perturbation(weights_df, run_params,kernel,experience_level) time = visual['time'] offset = 0 @@ -395,48 +397,54 @@ def plot_perturbation(weights_df, run_params, kernel,experience_level="Familiar" pi3 = len(time) if show_steps: - demonstrate_iterative_ch(visual,kernel) + demonstrate_iterative_ch(visual,kernel,x=x,y=y) colors = gvt.project_colors() if ax is None: fig, ax = plt.subplots(1,1,sharey=True,sharex=True,figsize=(4.5,4)) - df = get_error(visual).loc[0:pi3] + df = get_error(visual,x=x,y=y).loc[0:pi3] plot_iterative_ch(ax,df,'lightgray') - df = get_error(timing).loc[0:pi3] + df = get_error(timing,x=x,y=y).loc[0:pi3] plot_iterative_ch(ax,df,'lightgray') - ax.plot(visual['Slc17a7-IRES2-Cre'][0:pi3], visual['y'][0:pi3], + ax.plot(visual[x][0:pi3], visual[y][0:pi3], color=colors['visual'],label='Visual',linewidth=3) - ax.plot(timing['Slc17a7-IRES2-Cre'][0:pi3], timing['y'][0:pi3], + ax.plot(timing[x][0:pi3], timing[y][0:pi3], color=colors['timing'],label='Timing',linewidth=3) if kernel =='omissions': - ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'co') - ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'co') + ax.plot(visual[x][0],visual[y][0],'co') + ax.plot(timing[x][0],timing[y][0],'co') elif kernel == 'hits': - ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'ro') - ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'ro') + ax.plot(visual[x][0],visual[y][0],'ro') + ax.plot(timing[x][0],timing[y][0],'ro') elif kernel == 'misses': - ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'rx') - ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'rx') + ax.plot(visual[x][0],visual[y][0],'rx') + ax.plot(timing[x][0],timing[y][0],'rx') else: - ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'ko') - ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'ko') + ax.plot(visual[x][0],visual[y][0],'ko') + ax.plot(timing[x][0],timing[y][0],'ko') if multiimage: - ax.plot(visual['Slc17a7-IRES2-Cre'][pi],visual['y'][pi],'ko') - ax.plot(timing['Slc17a7-IRES2-Cre'][pi],timing['y'][pi],'ko') - ax.plot(visual['Slc17a7-IRES2-Cre'][pi2],visual['y'][pi2],'o',color='gray') - ax.plot(timing['Slc17a7-IRES2-Cre'][pi2],timing['y'][pi2],'o',color='gray') - + ax.plot(visual[x][pi],visual[y][pi],'ko') + ax.plot(timing[x][pi],timing[y][pi],'ko') + ax.plot(visual[x][pi2],visual[y][pi2],'o',color='gray') + ax.plot(timing[x][pi2],timing[y][pi2],'o',color='gray') + + mapper = { + 'Slc17a7-IRES2-Cre':'Exc', + 'Sst-IRES-Cre':'Sst', + 'Vip-IRES-Cre':'Vip', + 'y':'Vip - Sst', + } if col1: - ax.set_ylabel(experience_level+'\nVip - Sst',fontsize=16) + ax.set_ylabel(experience_level+'\n'+mapper[y],fontsize=16) elif not multi: - ax.set_ylabel('Vip - Sst',fontsize=16) + ax.set_ylabel(mapper[y],fontsize=16) if row1: - ax.set_xlabel(kernel+'\nExc',fontsize=16) + ax.set_xlabel(kernel+'\n'+mapper[x],fontsize=16) elif not multi: - ax.set_xlabel('Exc',fontsize=16) + ax.set_xlabel(mapper[x],fontsize=16) ax.spines['right'].set_visible(False) ax.spines['top'].set_visible(False) ax.set_xticks([0,.001, .002]) @@ -454,7 +462,7 @@ def plot_perturbation(weights_df, run_params, kernel,experience_level="Familiar" if savefig: filepath = run_params['figure_dir']+\ - '/strategy/'+kernel+'_strategy_perturbation_{}.svg'.format(experience_level) + '/strategy/'+kernel+'_strategy_perturbation_{}_{}_{}.svg'.format(experience_level,x,y) print('Figure saved to: '+filepath) plt.savefig(filepath) diff --git a/visual_behavior_glm/GLM_strategy_tools.py b/visual_behavior_glm/GLM_strategy_tools.py index 009cf654..1856b852 100644 --- a/visual_behavior_glm/GLM_strategy_tools.py +++ b/visual_behavior_glm/GLM_strategy_tools.py @@ -95,6 +95,11 @@ def kernels_by_cre(weights_beh, run_params, kernel='omissions', savefig=False, sharey=False, depth_filter=[0,1000]): cres = ['Vip-IRES-Cre','Sst-IRES-Cre','Slc17a7-IRES2-Cre'] + limit_list = { + 'Slc17a7-IRES2-Cre':[-.0005,0.002], + 'Sst-IRES-Cre':[-.006,0.0125], + 'Vip-IRES-Cre':[-0.005,0.025], + } for dex, cre in enumerate(cres): height = 4 width=8 @@ -102,12 +107,17 @@ def kernels_by_cre(weights_beh, run_params, kernel='omissions', post_horz_offset = 2.5 vertical_offset = .75 fig = plt.figure(figsize=(width,height)) - if kernel == 'all_images': - h = [Size.Fixed(pre_horz_offset),\ - Size.Fixed((width-pre_horz_offset-post_horz_offset)/3*.75)] - else: - h = [Size.Fixed(pre_horz_offset),\ - Size.Fixed(width-pre_horz_offset-post_horz_offset)] + + duration = run_params['kernels'][kernel]['length'] + h = [Size.Fixed(pre_horz_offset),\ + Size.Fixed((width-pre_horz_offset-post_horz_offset)/3*duration)] + + #if kernel == 'all-images': + # h = [Size.Fixed(pre_horz_offset),\ + # Size.Fixed((width-pre_horz_offset-post_horz_offset)/3*.75)] + #else: + # h = [Size.Fixed(pre_horz_offset),\ + # Size.Fixed(width-pre_horz_offset-post_horz_offset)] v = [Size.Fixed(vertical_offset),Size.Fixed(height-vertical_offset-.5)] divider = Divider(fig, (0,0,1,1),h,v,aspect=False) ax = fig.add_axes(divider.get_position(),\ @@ -124,7 +134,8 @@ def kernels_by_cre(weights_beh, run_params, kernel='omissions', ax.set_title(string_mapper(cre),fontsize=16) ax.set_ylabel(kernel+' weights\n(Ca$^{2+}$ events)',fontsize=16) - ylim=ax.get_ylim() + #ylim=ax.get_ylim() + ylim = limit_list[cre] ax.set_ylim(ylim) if kernel =='omissions': out[2].plot(0,ylim[0],'co',zorder=10,clip_on=False) diff --git a/visual_behavior_glm/strategy_paper_figure_script.py b/visual_behavior_glm/strategy_paper_figure_script.py index fc3f3019..b9c3be8d 100644 --- a/visual_behavior_glm/strategy_paper_figure_script.py +++ b/visual_behavior_glm/strategy_paper_figure_script.py @@ -2,8 +2,10 @@ import visual_behavior_glm.PSTH as psth import visual_behavior_glm.GLM_params as glm_params import visual_behavior_glm.GLM_fit_tools as gft +import visual_behavior_glm.GLM_fit_dev as gfd import visual_behavior_glm.GLM_schematic_plots as gsm import visual_behavior_glm.GLM_perturbation_tools as gpt +import visual_behavior_glm.GLM_strategy_tools as gst import matplotlib.pyplot as plt plt.ion() from importlib import reload @@ -135,8 +137,28 @@ ## Fig S5 - GLM Supplement ################################################################################ +GLM_VERSION = '24_events_all_L2_optimize_by_session' +run_params, results, results_pivoted, weights_df = gfd.get_analysis_dfs(GLM_VERSION) +weights_beh = gst.add_behavior_session_metrics(weights_df, summary_df) + +# Plot kernels over time, compare cell types +gpt.analysis(weights_beh, run_params, 'all-images') gpt.analysis(weights_beh, run_params, 'omissions') -gst.kernels_by_cre(weights_beh, run_params) +gpt.analysis(weights_beh, run_params, 'hits') +gpt.analysis(weights_beh, run_params, 'misses') + +# Plot kernels over time, compare strategies +gst.kernels_by_cre(weights_beh, run_params, 'all-images') +gst.kernels_by_cre(weights_beh, run_params, 'omissions') +gst.kernels_by_cre(weights_beh, run_params, 'hits') +gst.kernels_by_cre(weights_beh, run_params, 'misses') + +# Plot state space plots +gpt.plot_perturbation(weights_beh, run_params, 'all-images') +gpt.plot_perturbation(weights_beh, run_params, 'omissions') +gpt.plot_perturbation(weights_beh, run_params, 'hits') +gpt.plot_perturbation(weights_beh, run_params, 'misses') + ## Fig. 6 Engagement PSTHs