diff --git a/visual_behavior_glm/GLM_perturbation_tools.py b/visual_behavior_glm/GLM_perturbation_tools.py index f57c44da..07bc210a 100644 --- a/visual_behavior_glm/GLM_perturbation_tools.py +++ b/visual_behavior_glm/GLM_perturbation_tools.py @@ -277,20 +277,52 @@ def plot_iterative_ch(ax,df,color,show_steps=False): else: ax.fill(points[hull.vertices,0],points[hull.vertices,1],color=color) -def get_points(df): +def get_points(df,ellipse=True): ''' Return a 2D array of error points for two subsequent points in df ''' x=[] y=[] for index, row in df.iterrows(): - this_x= [row.x1,row.x,row.x2,row.x] - this_y= [row.y,row.y2,row.y,row.y1] + if ellipse: + r1 = (row.x2-row.x1)/2 + r2 = (row.y2-row.y1)/2 + n=20 + this_x = [ + r1*np.cos(theta)+row.x + for theta in (np.pi*2 * i/n for i in range(n)) + ] + this_y = [ + r2*np.sin(theta)+row.y + for theta in (np.pi*2 * i/n for i in range(n)) + ] + else: + this_x= [row.x1,row.x,row.x2,row.x] + this_y= [row.y,row.y2,row.y,row.y1] x=x+this_x y=y+this_y return np.array([x,y]).T -def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True): +def get_ellipse_points(df): + plt.figure() + for index, row in df.iterrows(): + plt.plot(row.x,row.y,'ro') + plt.plot([row.x1,row.x2],[row.y,row.y],'r-') + plt.plot([row.x,row.x],[row.y1,row.y2],'r-') + r1 = (row.x2-row.x1)/2 + r2 = (row.y2-row.y1)/2 + n=20 + xpoints = [ + r1*np.cos(theta)+row.x + for theta in (np.pi*2 * i/n for i in range(n)) + ] + ypoints = [ + r2*np.sin(theta)+row.y + for theta in (np.pi*2 * i/n for i in range(n)) + ] + plt.plot(xpoints, ypoints, 'b-') +def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True, + x='Slc17a7-IRES2-Cre',y='y'): ''' A demonstration that shows the iterative convex hull solution ''' @@ -306,19 +338,26 @@ def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True): pi3 = len(time) colors = gvt.project_colors() - fig, ax = plt.subplots(figsize=(4,3.5)) - df = get_error(Fvisual) + fig, ax = plt.subplots(figsize=(5,3.5)) + df = get_error(Fvisual,x=x,y=y) df = df.loc[0:pi3] plot_iterative_ch(ax,df,'lightgray',show_steps=show_steps) if show_steps: - ax.errorbar(Fvisual['Slc17a7-IRES2-Cre'][0:pi3],Fvisual['y'][0:pi3], - xerr=Fvisual['Slc17a7-IRES2-Cre_sem'][0:pi3], - yerr=Fvisual['y_sem'][0:pi3],color='gray',alpha=.5) - ax.plot(Fvisual['Slc17a7-IRES2-Cre'][0:pi3], Fvisual['y'][0:pi3], + ax.errorbar(Fvisual[x][0:pi3],Fvisual[y][0:pi3], + xerr=Fvisual[x+'_sem'][0:pi3], + yerr=Fvisual[y+'_sem'][0:pi3],color='gray',alpha=.5) + ax.plot(Fvisual[x][0:pi3], Fvisual[y][0:pi3], color=colors['visual'],label='Visual',linewidth=3) - - ax.set_ylabel('Vip - Sst',fontsize=16) - ax.set_xlabel('Exc',fontsize=16) + + mapper = { + 'Slc17a7-IRES2-Cre':'Exc', + 'Sst-IRES-Cre':'Sst', + 'Vip-IRES-Cre':'Vip', + 'y':'Vip - Sst', + } + + ax.set_ylabel(mapper[y],fontsize=16) + ax.set_xlabel(mapper[x],fontsize=16) ax.spines['right'].set_visible(False) ax.spines['top'].set_visible(False) ax.xaxis.set_tick_params(labelsize=12) @@ -327,6 +366,7 @@ def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True): ax.legend() ax.axhline(0,color='k',linestyle='--',alpha=.25) ax.axvline(0,color='k',linestyle='--',alpha=.25) + plt.tight_layout() def get_perturbation(weights_df, run_params, kernel,experience_level="Familiar"): visual = get_kernel_averages(weights_df.query('visual_strategy_session'), @@ -351,7 +391,8 @@ def plot_multiple(weights_df, run_params,savefig=False): return def plot_perturbation(weights_df, run_params, kernel,experience_level="Familiar", - savefig=False,lims = None,show_steps=False,ax=None,col1=False,row1=False,multi=False): + savefig=False,lims = None,show_steps=False,ax=None,col1=False,row1=False,multi=False, + x = 'Slc17a7-IRES2-Cre',y='Vip-IRES-Cre'): visual, timing = get_perturbation(weights_df, run_params,kernel,experience_level) time = visual['time'] offset = 0 @@ -365,48 +406,54 @@ def plot_perturbation(weights_df, run_params, kernel,experience_level="Familiar" pi3 = len(time) if show_steps: - demonstrate_iterative_ch(visual,kernel) + demonstrate_iterative_ch(visual,kernel,x=x,y=y) colors = gvt.project_colors() if ax is None: fig, ax = plt.subplots(1,1,sharey=True,sharex=True,figsize=(4.5,4)) - df = get_error(visual).loc[0:pi3] + df = get_error(visual,x=x,y=y).loc[0:pi3] plot_iterative_ch(ax,df,'lightgray') - df = get_error(timing).loc[0:pi3] + df = get_error(timing,x=x,y=y).loc[0:pi3] plot_iterative_ch(ax,df,'lightgray') - ax.plot(visual['Slc17a7-IRES2-Cre'][0:pi3], visual['y'][0:pi3], + ax.plot(visual[x][0:pi3], visual[y][0:pi3], color=colors['visual'],label='Visual',linewidth=3) - ax.plot(timing['Slc17a7-IRES2-Cre'][0:pi3], timing['y'][0:pi3], + ax.plot(timing[x][0:pi3], timing[y][0:pi3], color=colors['timing'],label='Timing',linewidth=3) if kernel =='omissions': - ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'co') - ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'co') + ax.plot(visual[x][0],visual[y][0],'co') + ax.plot(timing[x][0],timing[y][0],'co') elif kernel == 'hits': - ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'ro') - ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'ro') + ax.plot(visual[x][0],visual[y][0],'ro') + ax.plot(timing[x][0],timing[y][0],'ro') elif kernel == 'misses': - ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'rx') - ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'rx') + ax.plot(visual[x][0],visual[y][0],'rx') + ax.plot(timing[x][0],timing[y][0],'rx') else: - ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'ko') - ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'ko') + ax.plot(visual[x][0],visual[y][0],'ko') + ax.plot(timing[x][0],timing[y][0],'ko') if multiimage: - ax.plot(visual['Slc17a7-IRES2-Cre'][pi],visual['y'][pi],'ko') - ax.plot(timing['Slc17a7-IRES2-Cre'][pi],timing['y'][pi],'ko') - ax.plot(visual['Slc17a7-IRES2-Cre'][pi2],visual['y'][pi2],'o',color='gray') - ax.plot(timing['Slc17a7-IRES2-Cre'][pi2],timing['y'][pi2],'o',color='gray') - + ax.plot(visual[x][pi],visual[y][pi],'ko') + ax.plot(timing[x][pi],timing[y][pi],'ko') + ax.plot(visual[x][pi2],visual[y][pi2],'o',color='gray') + ax.plot(timing[x][pi2],timing[y][pi2],'o',color='gray') + + mapper = { + 'Slc17a7-IRES2-Cre':'Exc', + 'Sst-IRES-Cre':'Sst', + 'Vip-IRES-Cre':'Vip', + 'y':'Vip - Sst', + } if col1: - ax.set_ylabel(experience_level+'\nVip - Sst',fontsize=16) + ax.set_ylabel(experience_level+'\n'+mapper[y],fontsize=16) elif not multi: - ax.set_ylabel('Vip - Sst',fontsize=16) + ax.set_ylabel(mapper[y],fontsize=16) if row1: - ax.set_xlabel(kernel+'\nExc',fontsize=16) + ax.set_xlabel(kernel+'\n'+mapper[x],fontsize=16) elif not multi: - ax.set_xlabel('Exc',fontsize=16) + ax.set_xlabel(mapper[x],fontsize=16) ax.spines['right'].set_visible(False) ax.spines['top'].set_visible(False) ax.set_xticks([0,.001, .002]) diff --git a/visual_behavior_glm/strategy_paper_figure_script.py b/visual_behavior_glm/strategy_paper_figure_script.py index de2f3066..017b4d15 100644 --- a/visual_behavior_glm/strategy_paper_figure_script.py +++ b/visual_behavior_glm/strategy_paper_figure_script.py @@ -153,6 +153,14 @@ gst.kernels_by_cre(weights_beh, run_params, 'hits') gst.kernels_by_cre(weights_beh, run_params, 'misses') +# Plot state space plots +gpt.plot_perturbation(weights_beh, run_params. 'all-images') +gpt.plot_perturbation(weights_beh, run_params, 'omissions') +gpt.plot_perturbation(weights_beh, run_params, 'hits') +gpt.plot_perturbation(weights_beh, run_params, 'misses') + + + ## Fig. 6 Engagement PSTHs ################################################################################ dfs = psth.get_figure_4_psth(data='events')