Skip to content

Commit

Permalink
letting 2d plots by flexible by cell type and adding ellipses to errors
Browse files Browse the repository at this point in the history
  • Loading branch information
alexpiet committed Feb 23, 2023
1 parent e1b5759 commit 29d5305
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 36 deletions.
119 changes: 83 additions & 36 deletions visual_behavior_glm/GLM_perturbation_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,20 +277,52 @@ def plot_iterative_ch(ax,df,color,show_steps=False):
else:
ax.fill(points[hull.vertices,0],points[hull.vertices,1],color=color)

def get_points(df):
def get_points(df,ellipse=True):
'''
Return a 2D array of error points for two subsequent points in df
'''
x=[]
y=[]
for index, row in df.iterrows():
this_x= [row.x1,row.x,row.x2,row.x]
this_y= [row.y,row.y2,row.y,row.y1]
if ellipse:
r1 = (row.x2-row.x1)/2
r2 = (row.y2-row.y1)/2
n=20
this_x = [
r1*np.cos(theta)+row.x
for theta in (np.pi*2 * i/n for i in range(n))
]
this_y = [
r2*np.sin(theta)+row.y
for theta in (np.pi*2 * i/n for i in range(n))
]
else:
this_x= [row.x1,row.x,row.x2,row.x]
this_y= [row.y,row.y2,row.y,row.y1]
x=x+this_x
y=y+this_y
return np.array([x,y]).T

def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True):
def get_ellipse_points(df):
plt.figure()
for index, row in df.iterrows():
plt.plot(row.x,row.y,'ro')
plt.plot([row.x1,row.x2],[row.y,row.y],'r-')
plt.plot([row.x,row.x],[row.y1,row.y2],'r-')
r1 = (row.x2-row.x1)/2
r2 = (row.y2-row.y1)/2
n=20
xpoints = [
r1*np.cos(theta)+row.x
for theta in (np.pi*2 * i/n for i in range(n))
]
ypoints = [
r2*np.sin(theta)+row.y
for theta in (np.pi*2 * i/n for i in range(n))
]
plt.plot(xpoints, ypoints, 'b-')
def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True,
x='Slc17a7-IRES2-Cre',y='y'):
'''
A demonstration that shows the iterative convex hull solution
'''
Expand All @@ -306,19 +338,26 @@ def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True):
pi3 = len(time)

colors = gvt.project_colors()
fig, ax = plt.subplots(figsize=(4,3.5))
df = get_error(Fvisual)
fig, ax = plt.subplots(figsize=(5,3.5))
df = get_error(Fvisual,x=x,y=y)
df = df.loc[0:pi3]
plot_iterative_ch(ax,df,'lightgray',show_steps=show_steps)
if show_steps:
ax.errorbar(Fvisual['Slc17a7-IRES2-Cre'][0:pi3],Fvisual['y'][0:pi3],
xerr=Fvisual['Slc17a7-IRES2-Cre_sem'][0:pi3],
yerr=Fvisual['y_sem'][0:pi3],color='gray',alpha=.5)
ax.plot(Fvisual['Slc17a7-IRES2-Cre'][0:pi3], Fvisual['y'][0:pi3],
ax.errorbar(Fvisual[x][0:pi3],Fvisual[y][0:pi3],
xerr=Fvisual[x+'_sem'][0:pi3],
yerr=Fvisual[y+'_sem'][0:pi3],color='gray',alpha=.5)
ax.plot(Fvisual[x][0:pi3], Fvisual[y][0:pi3],
color=colors['visual'],label='Visual',linewidth=3)

ax.set_ylabel('Vip - Sst',fontsize=16)
ax.set_xlabel('Exc',fontsize=16)

mapper = {
'Slc17a7-IRES2-Cre':'Exc',
'Sst-IRES-Cre':'Sst',
'Vip-IRES-Cre':'Vip',
'y':'Vip - Sst',
}

ax.set_ylabel(mapper[y],fontsize=16)
ax.set_xlabel(mapper[x],fontsize=16)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.xaxis.set_tick_params(labelsize=12)
Expand All @@ -327,6 +366,7 @@ def demonstrate_iterative_ch(Fvisual,kernel='omissions',show_steps=True):
ax.legend()
ax.axhline(0,color='k',linestyle='--',alpha=.25)
ax.axvline(0,color='k',linestyle='--',alpha=.25)
plt.tight_layout()

def get_perturbation(weights_df, run_params, kernel,experience_level="Familiar"):
visual = get_kernel_averages(weights_df.query('visual_strategy_session'),
Expand All @@ -351,7 +391,8 @@ def plot_multiple(weights_df, run_params,savefig=False):
return

def plot_perturbation(weights_df, run_params, kernel,experience_level="Familiar",
savefig=False,lims = None,show_steps=False,ax=None,col1=False,row1=False,multi=False):
savefig=False,lims = None,show_steps=False,ax=None,col1=False,row1=False,multi=False,
x = 'Slc17a7-IRES2-Cre',y='Vip-IRES-Cre'):
visual, timing = get_perturbation(weights_df, run_params,kernel,experience_level)
time = visual['time']
offset = 0
Expand All @@ -365,48 +406,54 @@ def plot_perturbation(weights_df, run_params, kernel,experience_level="Familiar"
pi3 = len(time)

if show_steps:
demonstrate_iterative_ch(visual,kernel)
demonstrate_iterative_ch(visual,kernel,x=x,y=y)

colors = gvt.project_colors()
if ax is None:
fig, ax = plt.subplots(1,1,sharey=True,sharex=True,figsize=(4.5,4))

df = get_error(visual).loc[0:pi3]
df = get_error(visual,x=x,y=y).loc[0:pi3]
plot_iterative_ch(ax,df,'lightgray')

df = get_error(timing).loc[0:pi3]
df = get_error(timing,x=x,y=y).loc[0:pi3]
plot_iterative_ch(ax,df,'lightgray')

ax.plot(visual['Slc17a7-IRES2-Cre'][0:pi3], visual['y'][0:pi3],
ax.plot(visual[x][0:pi3], visual[y][0:pi3],
color=colors['visual'],label='Visual',linewidth=3)
ax.plot(timing['Slc17a7-IRES2-Cre'][0:pi3], timing['y'][0:pi3],
ax.plot(timing[x][0:pi3], timing[y][0:pi3],
color=colors['timing'],label='Timing',linewidth=3)
if kernel =='omissions':
ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'co')
ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'co')
ax.plot(visual[x][0],visual[y][0],'co')
ax.plot(timing[x][0],timing[y][0],'co')
elif kernel == 'hits':
ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'ro')
ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'ro')
ax.plot(visual[x][0],visual[y][0],'ro')
ax.plot(timing[x][0],timing[y][0],'ro')
elif kernel == 'misses':
ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'rx')
ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'rx')
ax.plot(visual[x][0],visual[y][0],'rx')
ax.plot(timing[x][0],timing[y][0],'rx')
else:
ax.plot(visual['Slc17a7-IRES2-Cre'][0],visual['y'][0],'ko')
ax.plot(timing['Slc17a7-IRES2-Cre'][0],timing['y'][0],'ko')
ax.plot(visual[x][0],visual[y][0],'ko')
ax.plot(timing[x][0],timing[y][0],'ko')
if multiimage:
ax.plot(visual['Slc17a7-IRES2-Cre'][pi],visual['y'][pi],'ko')
ax.plot(timing['Slc17a7-IRES2-Cre'][pi],timing['y'][pi],'ko')
ax.plot(visual['Slc17a7-IRES2-Cre'][pi2],visual['y'][pi2],'o',color='gray')
ax.plot(timing['Slc17a7-IRES2-Cre'][pi2],timing['y'][pi2],'o',color='gray')

ax.plot(visual[x][pi],visual[y][pi],'ko')
ax.plot(timing[x][pi],timing[y][pi],'ko')
ax.plot(visual[x][pi2],visual[y][pi2],'o',color='gray')
ax.plot(timing[x][pi2],timing[y][pi2],'o',color='gray')

mapper = {
'Slc17a7-IRES2-Cre':'Exc',
'Sst-IRES-Cre':'Sst',
'Vip-IRES-Cre':'Vip',
'y':'Vip - Sst',
}
if col1:
ax.set_ylabel(experience_level+'\nVip - Sst',fontsize=16)
ax.set_ylabel(experience_level+'\n'+mapper[y],fontsize=16)
elif not multi:
ax.set_ylabel('Vip - Sst',fontsize=16)
ax.set_ylabel(mapper[y],fontsize=16)
if row1:
ax.set_xlabel(kernel+'\nExc',fontsize=16)
ax.set_xlabel(kernel+'\n'+mapper[x],fontsize=16)
elif not multi:
ax.set_xlabel('Exc',fontsize=16)
ax.set_xlabel(mapper[x],fontsize=16)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.set_xticks([0,.001, .002])
Expand Down
8 changes: 8 additions & 0 deletions visual_behavior_glm/strategy_paper_figure_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@
gst.kernels_by_cre(weights_beh, run_params, 'hits')
gst.kernels_by_cre(weights_beh, run_params, 'misses')

# Plot state space plots
gpt.plot_perturbation(weights_beh, run_params. 'all-images')
gpt.plot_perturbation(weights_beh, run_params, 'omissions')
gpt.plot_perturbation(weights_beh, run_params, 'hits')
gpt.plot_perturbation(weights_beh, run_params, 'misses')



## Fig. 6 Engagement PSTHs
################################################################################
dfs = psth.get_figure_4_psth(data='events')
Expand Down

0 comments on commit 29d5305

Please sign in to comment.