Skip to content

Commit

Permalink
a*pa-evals: show red dots for bad_k everywhere
Browse files: browse the repository at this point in the history
  • Loading branch information
RagnarGrootKoerkamp committed Aug 4, 2023
1 parent b64773b commit 5440ed1
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions evals/astarpa/evals.ipynb
Original file line number | Diff line number | Diff line change
Expand Up @@ -241,11 +241,28 @@
" if 'length' in df.columns and 'output_Ok_stats_expanded' in df.columns:\n",
" df['band'] = df['output_Ok_stats_expanded'] / (df['stats_seqpairs']* df['length'])\n",
"\n",
" def bad_k(row):\n",
" # Bad k only makes sense for A*PA with SH-based heuristic.\n",
" if row['algo_name'] != 'AstarPA':\n",
" return False\n",
" \n",
" h = row['job_algo_AstarPA_heuristic_type'].lower()\n",
" if h not in ['sh', 'csh', 'gcsh']:\n",
" return False\n",
" \n",
" # Too small k => too many matches\n",
" if row.k < math.log(row.length, 4) + (2 if row.r == 2 else 0):\n",
" return True\n",
" \n",
" # Too large k => not enough potential\n",
" if row.k > row.r/row.divergence:\n",
" return True\n",
" \n",
" # all is fine\n",
" return False\n",
"\n",
" if 'length' in df.columns:\n",
" df['bad_k'] = df.apply(lambda row: row.k < math.log(row.length, 4) + (2 if row.r == 2 else 0) or row.k > row.r/row.divergence, axis=1)\n",
" #else:\n",
" #df['stats_'] = df.apply(lambda row: row.k < math.log(row.length, 4) + (2 if row.r == 2 else 0) or row.k > row.r/row.divergence, axis=1)\n",
" df['bad_k'] = df.apply(bad_k, axis=1)\n",
"\n",
" def runtime_capped(row):\n",
" if not math.isnan(row['runtime']):\n",
Expand Down Expand Up @@ -311,7 +328,7 @@
" width=None,\n",
" height=None,\n",
" default_r=None,\n",
" bad_k=False,\n",
" bad_k=True,\n",
" ):\n",
" \n",
" if df[y].isna().all():\n",
Expand Down Expand Up @@ -971,7 +988,7 @@
"df['rn'] = df.apply(lambda row: f\"{row['r']}-{row['length']}\", axis=1)\n",
"for d, g in df.groupby('errorrate'):\n",
" plot(g, file=f'params_e{d}', x='k', y='s_per_pair', xlog=False, ylog=True, connect=True, line_labels=True,\n",
" hue='length', style='r', xlim=(4, 26), bad_k = True,\n",
" hue='length', style='r', xlim=(4, 26),\n",
" width=1.3*4.4, height=1.3*3)\n",
"plt.show()\n",
"\n",
Expand All @@ -995,7 +1012,7 @@
" subfig.suptitle(dataset_pretty[k], y=0)\n",
" axs = subfig.subplots(1, 2, sharey=True)\n",
" for (r, g), ax in zip(g.groupby(['r'], sort=False), axs): \n",
" plot(g, x='k', y='runtime_capped', hue='r', xlog=False, ylog=True, categorical=True, ylim=(0.3, 150), bad_k = True, ax=ax)\n",
" plot(g, x='k', y='runtime_capped', hue='r', xlog=False, ylog=True, categorical=True, ylim=(0.3, 150), ax=ax)\n",
" ax.set_xlabel(f'$k$, {\"in\" if r==2 else \"\"}exact matches ($r={r}$)')\n",
"\n",
" plt.savefig(f\"plots/{file}.pdf\", dpi=300, bbox_inches='tight')\n",
Expand Down

0 comments on commit 5440ed1

Please sign in to comment.