Skip to content

Commit

Permalink
Merge pull request #585 from AllenInstitute/analysis
Browse files Browse the repository at this point in the history
GLM supplemental figure
  • Loading branch information
alexpiet authored Feb 23, 2023
2 parents a513ded + 18efb1f commit 9c1c256
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 80 deletions.
152 changes: 80 additions & 72 deletions visual_behavior_glm/GLM_perturbation_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,52 +10,34 @@

def analysis(weights_beh, run_params, kernel,experience_level='Familiar',savefig=False):

if kernel in ['hits','misses']:
lims=[[-.0005,.002],[-.01,.005]]
if kernel in ['omissions','all-images']:
lims=[[-.0005,.001],[-.015,.03]]

# Plot 1D summaries
out1 = gst.strategy_kernel_comparison(weights_beh.query('visual_strategy_session'),
run_params, kernel,session_filter=[experience_level])
out2 =gst.strategy_kernel_comparison(weights_beh.query('not visual_strategy_session'),
run_params, kernel,session_filter=[experience_level])

out3 = gst.strategy_kernel_comparison(weights_beh.query('visual_strategy_session'),
run_params, 'all-images',session_filter=[experience_level])
out4 =gst.strategy_kernel_comparison(weights_beh.query('not visual_strategy_session'),
run_params, 'all-images',session_filter=[experience_level])

# Unify ylimits and add time markers
ylim1 = out1[2].get_ylim()
ylim2 = out2[2].get_ylim()
ylim3 = out3[2].get_ylim()
ylim4 = out4[2].get_ylim()
ylim = [np.min([ylim1[0],ylim2[0],ylim3[0],ylim4[0]]), \
np.max([ylim1[1],ylim2[1],ylim3[1],ylim4[1]])]
ylim = [np.min([ylim1[0],ylim2[0]]), \
np.max([ylim1[1],ylim2[1]])]

out1[2].set_ylim(ylim)
out2[2].set_ylim(ylim)
out3[2].set_ylim(ylim)
out4[2].set_ylim(ylim)
out1[2].set_ylabel(kernel+' kernel\n(Ca$^{2+}$ events)',fontsize=16)
out2[2].set_ylabel(kernel+' kernel\n(Ca$^{2+}$ events)',fontsize=16)
out3[2].set_ylabel('image kernel\n(Ca$^{2+}$ events)',fontsize=16)
out4[2].set_ylabel('image kernel\n(Ca$^{2+}$ events)',fontsize=16)

out1[2].set_title('Visual strategy',fontsize=16)
out2[2].set_title('Timing strategy',fontsize=16)
out3[2].set_title('Visual strategy',fontsize=16)
out4[2].set_title('Timing strategy',fontsize=16)

if kernel =='omissions':
out1[2].plot(0,ylim[0],'co',zorder=10,clip_on=False)
out2[2].plot(0,ylim[0],'co',zorder=10,clip_on=False)
elif kernel =='hits':
out1[2].plot(0,ylim[0],'ro',zorder=10,clip_on=False)
out2[2].plot(0,ylim[0],'ro',zorder=10,clip_on=False)
elif kernel == 'misses':
out1[2].plot(0,ylim[0],'rx',zorder=10,clip_on=False)
out2[2].plot(0,ylim[0],'rx',zorder=10,clip_on=False)
out1[2].plot(0,ylim[0],'ro',zorder=10,clip_on=False)
out2[2].plot(0,ylim[0],'ro',zorder=10,clip_on=False)
else:
out1[2].plot(0,ylim[0],'ko',zorder=10,clip_on=False)
out2[2].plot(0,ylim[0],'ko',zorder=10,clip_on=False)
Expand All @@ -65,10 +47,6 @@ def analysis(weights_beh, run_params, kernel,experience_level='Familiar',savefig
out2[2].plot(.75,ylim[0],'ko',zorder=10,clip_on=False)
out2[2].plot(1.5,ylim[0],'o',color='gray',zorder=10,clip_on=False)

# Make 2D plot
ax = plot_perturbation(weights_beh, run_params, kernel,experience_level=experience_level,
savefig=savefig,lims=lims)

if savefig:
filepath = run_params['figure_dir']+\
'/strategy/'+kernel+'_visual_comparison_{}.svg'.format(experience_level)
Expand All @@ -80,14 +58,6 @@ def analysis(weights_beh, run_params, kernel,experience_level='Familiar',savefig
out2[1].savefig(filepath)
filepath = run_params['figure_dir']+\
'/strategy/images_visual_comparison_{}.svg'.format(experience_level)
print('Figure saved to: '+filepath)
out3[1].savefig(filepath)
filepath = run_params['figure_dir']+\
'/strategy/images_timing_comparison_{}.svg'.format(experience_level)
print('Figure saved to: '+filepath)
out4[1].savefig(filepath)

return ax

def get_kernel_averages(weights_df, run_params, kernel, drop_threshold=0,
session_filter=['Familiar','Novel 1','Novel >1'],equipment_filter="all",
Expand Down Expand Up @@ -307,20 +277,34 @@ def plot_iterative_ch(ax,df,color,show_steps=False):
else:
ax.fill(points[hull.vertices,0],points[hull.vertices,1],color=color)

def get_points(df):
def get_points(df,ellipse=True):
'''
Return a 2D array of error points for two subsequent points in df
'''
x=[]
y=[]
for index, row in df.iterrows():
this_x= [row.x1,row.x,row.x2,row.x]
this_y= [row.y,row.y2,row.y,row.y1]
if ellipse:
r1 = (row.x2-row.x1)/2
r2 = (row.y2-row.y1)/2
n=20
this_x = [
r1*np.cos(theta)+row.x
for theta in (np.pi*2 * i/n for i in range(n))
]
this_y = [
r2*np.sin(theta)+row.y
for theta in (np.pi*2 * i/n for i in range(n))
]
else:
this_x= [row.x1,row.x,row.x2,row.x]
this_y= [row.y,row.y2,row.y,row.y1]
x=x+this_x
y=y+this_y
return np.array([x,y]).T

def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True):
def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True,
x='Slc17a7-IRES2-Cre',y='y'):
'''
A demonstration that shows the iterative convex hull solution
'''
Expand All @@ -336,19 +320,26 @@ def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True):
pi3 = len(time)

colors = gvt.project_colors()
fig, ax = plt.subplots(figsize=(4,3.5))
df = get_error(Fvisual)
fig, ax = plt.subplots(figsize=(5,3.5))
df = get_error(Fvisual,x=x,y=y)
df = df.loc[0:pi3]
plot_iterative_ch(ax,df,'lightgray',show_steps=show_steps)
if show_steps:
ax.errorbar(Fvisual['Slc17a7-IRES2-Cre'][0:pi3],Fvisual['y'][0:pi3],
xerr=Fvisual['Slc17a7-IRES2-Cre_sem'][0:pi3],
yerr=Fvisual['y_sem'][0:pi3],color='gray',alpha=.5)
ax.plot(Fvisual['Slc17a7-IRES2-Cre'][0:pi3], Fvisual['y'][0:pi3],
ax.errorbar(Fvisual[x][0:pi3],Fvisual[y][0:pi3],
xerr=Fvisual[x+'_sem'][0:pi3],
yerr=Fvisual[y+'_sem'][0:pi3],color='gray',alpha=.5)
ax.plot(Fvisual[x][0:pi3], Fvisual[y][0:pi3],
color=colors['visual'],label='Visual',linewidth=3)

ax.set_ylabel('Vip - Sst',fontsize=16)
ax.set_xlabel('Exc',fontsize=16)

mapper = {
'Slc17a7-IRES2-Cre':'Exc',
'Sst-IRES-Cre':'Sst',
'Vip-IRES-Cre':'Vip',
'y':'Vip - Sst',
}

ax.set_ylabel(mapper[y],fontsize=16)
ax.set_xlabel(mapper[x],fontsize=16)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.xaxis.set_tick_params(labelsize=12)
Expand All @@ -357,6 +348,7 @@ def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True):
ax.legend()
ax.axhline(0,color='k',linestyle='--',alpha=.25)
ax.axvline(0,color='k',linestyle='--',alpha=.25)
plt.tight_layout()

def get_perturbation(weights_df, run_params, kernel,experience_level="Familiar"):
visual = get_kernel_averages(weights_df.query('visual_strategy_session'),
Expand All @@ -381,7 +373,17 @@ def plot_multiple(weights_df, run_params,savefig=False):
return

def plot_perturbation(weights_df, run_params, kernel,experience_level="Familiar",
savefig=False,lims = None,show_steps=False,ax=None,col1=False,row1=False,multi=False):
savefig=False,lims = None,show_steps=False,ax=None,col1=False,row1=False,multi=False,
x = 'Slc17a7-IRES2-Cre',y='Vip-IRES-Cre'):

limit_list = {
'Slc17a7-IRES2-Cre':[-.0005,0.002],
'Sst-IRES-Cre':[-0.006,0.0125],
'Vip-IRES-Cre':[-0.005,0.025],
'y':[0,0.01],
}
lims=[limit_list[x], limit_list[y]]

visual, timing = get_perturbation(weights_df, run_params,kernel,experience_level)
time = visual['time']
offset = 0
Expand All @@ -395,48 +397,54 @@ def plot_perturbation(weights_df, run_params, kernel,experience_level="Familiar"
pi3 = len(time)

if show_steps:
demonstrate_iterative_ch(visual,kernel)
demonstrate_iterative_ch(visual,kernel,x=x,y=y)

colors = gvt.project_colors()
if ax is None:
fig, ax = plt.subplots(1,1,sharey=True,sharex=True,figsize=(4.5,4))

df = get_error(visual).loc[0:pi3]
df = get_error(visual,x=x,y=y).loc[0:pi3]
plot_iterative_ch(ax,df,'lightgray')

df = get_error(timing).loc[0:pi3]
df = get_error(timing,x=x,y=y).loc[0:pi3]
plot_iterative_ch(ax,df,'lightgray')

ax.plot(visual['Slc17a7-IRES2-Cre'][0:pi3], visual['y'][0:pi3],
ax.plot(visual[x][0:pi3], visual[y][0:pi3],
color=colors['visual'],label='Visual',linewidth=3)
ax.plot(timing['Slc17a7-IRES2-Cre'][0:pi3], timing['y'][0:pi3],
ax.plot(timing[x][0:pi3], timing[y][0:pi3],
color=colors['timing'],label='Timing',linewidth=3)
if kernel =='omissions':
ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'co')
ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'co')
ax.plot(visual[x][0],visual[y][0],'co')
ax.plot(timing[x][0],timing[y][0],'co')
elif kernel == 'hits':
ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'ro')
ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'ro')
ax.plot(visual[x][0],visual[y][0],'ro')
ax.plot(timing[x][0],timing[y][0],'ro')
elif kernel == 'misses':
ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'rx')
ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'rx')
ax.plot(visual[x][0],visual[y][0],'rx')
ax.plot(timing[x][0],timing[y][0],'rx')
else:
ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'ko')
ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'ko')
ax.plot(visual[x][0],visual[y][0],'ko')
ax.plot(timing[x][0],timing[y][0],'ko')
if multiimage:
ax.plot(visual['Slc17a7-IRES2-Cre'][pi],visual['y'][pi],'ko')
ax.plot(timing['Slc17a7-IRES2-Cre'][pi],timing['y'][pi],'ko')
ax.plot(visual['Slc17a7-IRES2-Cre'][pi2],visual['y'][pi2],'o',color='gray')
ax.plot(timing['Slc17a7-IRES2-Cre'][pi2],timing['y'][pi2],'o',color='gray')

ax.plot(visual[x][pi],visual[y][pi],'ko')
ax.plot(timing[x][pi],timing[y][pi],'ko')
ax.plot(visual[x][pi2],visual[y][pi2],'o',color='gray')
ax.plot(timing[x][pi2],timing[y][pi2],'o',color='gray')

mapper = {
'Slc17a7-IRES2-Cre':'Exc',
'Sst-IRES-Cre':'Sst',
'Vip-IRES-Cre':'Vip',
'y':'Vip - Sst',
}
if col1:
ax.set_ylabel(experience_level+'\nVip - Sst',fontsize=16)
ax.set_ylabel(experience_level+'\n'+mapper[y],fontsize=16)
elif not multi:
ax.set_ylabel('Vip - Sst',fontsize=16)
ax.set_ylabel(mapper[y],fontsize=16)
if row1:
ax.set_xlabel(kernel+'\nExc',fontsize=16)
ax.set_xlabel(kernel+'\n'+mapper[x],fontsize=16)
elif not multi:
ax.set_xlabel('Exc',fontsize=16)
ax.set_xlabel(mapper[x],fontsize=16)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.set_xticks([0,.001, .002])
Expand All @@ -454,7 +462,7 @@ def plot_perturbation(weights_df, run_params, kernel,experience_level="Familiar"

if savefig:
filepath = run_params['figure_dir']+\
'/strategy/'+kernel+'_strategy_perturbation_{}.svg'.format(experience_level)
'/strategy/'+kernel+'_strategy_perturbation_{}_{}_{}.svg'.format(experience_level,x,y)
print('Figure saved to: '+filepath)
plt.savefig(filepath)

Expand Down
25 changes: 18 additions & 7 deletions visual_behavior_glm/GLM_strategy_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,29 @@ def kernels_by_cre(weights_beh, run_params, kernel='omissions',
savefig=False, sharey=False, depth_filter=[0,1000]):

cres = ['Vip-IRES-Cre','Sst-IRES-Cre','Slc17a7-IRES2-Cre']
limit_list = {
'Slc17a7-IRES2-Cre':[-.0005,0.002],
'Sst-IRES-Cre':[-.006,0.0125],
'Vip-IRES-Cre':[-0.005,0.025],
}
for dex, cre in enumerate(cres):
height = 4
width=8
pre_horz_offset = 1.5
post_horz_offset = 2.5
vertical_offset = .75
fig = plt.figure(figsize=(width,height))
if kernel == 'all_images':
h = [Size.Fixed(pre_horz_offset),\
Size.Fixed((width-pre_horz_offset-post_horz_offset)/3*.75)]
else:
h = [Size.Fixed(pre_horz_offset),\
Size.Fixed(width-pre_horz_offset-post_horz_offset)]

duration = run_params['kernels'][kernel]['length']
h = [Size.Fixed(pre_horz_offset),\
Size.Fixed((width-pre_horz_offset-post_horz_offset)/3*duration)]

#if kernel == 'all-images':
# h = [Size.Fixed(pre_horz_offset),\
# Size.Fixed((width-pre_horz_offset-post_horz_offset)/3*.75)]
#else:
# h = [Size.Fixed(pre_horz_offset),\
# Size.Fixed(width-pre_horz_offset-post_horz_offset)]
v = [Size.Fixed(vertical_offset),Size.Fixed(height-vertical_offset-.5)]
divider = Divider(fig, (0,0,1,1),h,v,aspect=False)
ax = fig.add_axes(divider.get_position(),\
Expand All @@ -124,7 +134,8 @@ def kernels_by_cre(weights_beh, run_params, kernel='omissions',
ax.set_title(string_mapper(cre),fontsize=16)
ax.set_ylabel(kernel+' weights\n(Ca$^{2+}$ events)',fontsize=16)

ylim=ax.get_ylim()
#ylim=ax.get_ylim()
ylim = limit_list[cre]
ax.set_ylim(ylim)
if kernel =='omissions':
out[2].plot(0,ylim[0],'co',zorder=10,clip_on=False)
Expand Down
24 changes: 23 additions & 1 deletion visual_behavior_glm/strategy_paper_figure_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import visual_behavior_glm.PSTH as psth
import visual_behavior_glm.GLM_params as glm_params
import visual_behavior_glm.GLM_fit_tools as gft
import visual_behavior_glm.GLM_fit_dev as gfd
import visual_behavior_glm.GLM_schematic_plots as gsm
import visual_behavior_glm.GLM_perturbation_tools as gpt
import visual_behavior_glm.GLM_strategy_tools as gst
import matplotlib.pyplot as plt
plt.ion()
from importlib import reload
Expand Down Expand Up @@ -135,8 +137,28 @@

## Fig S5 - GLM Supplement
################################################################################
GLM_VERSION = '24_events_all_L2_optimize_by_session'
run_params, results, results_pivoted, weights_df = gfd.get_analysis_dfs(GLM_VERSION)
weights_beh = gst.add_behavior_session_metrics(weights_df, summary_df)

# Plot kernels over time, compare cell types
gpt.analysis(weights_beh, run_params, 'all-images')
gpt.analysis(weights_beh, run_params, 'omissions')
gst.kernels_by_cre(weights_beh, run_params)
gpt.analysis(weights_beh, run_params, 'hits')
gpt.analysis(weights_beh, run_params, 'misses')

# Plot kernels over time, compare strategies
gst.kernels_by_cre(weights_beh, run_params, 'all-images')
gst.kernels_by_cre(weights_beh, run_params, 'omissions')
gst.kernels_by_cre(weights_beh, run_params, 'hits')
gst.kernels_by_cre(weights_beh, run_params, 'misses')

# Plot state space plots
gpt.plot_perturbation(weights_beh, run_params, 'all-images')
gpt.plot_perturbation(weights_beh, run_params, 'omissions')
gpt.plot_perturbation(weights_beh, run_params, 'hits')
gpt.plot_perturbation(weights_beh, run_params, 'misses')



## Fig. 6 Engagement PSTHs
Expand Down

0 comments on commit 9c1c256

Please sign in to comment.