diff --git a/evals/astarpa/evals.ipynb b/evals/astarpa/evals.ipynb index 236fb1d..e2e11ef 100644 --- a/evals/astarpa/evals.ipynb +++ b/evals/astarpa/evals.ipynb @@ -241,11 +241,28 @@ " if 'length' in df.columns and 'output_Ok_stats_expanded' in df.columns:\n", " df['band'] = df['output_Ok_stats_expanded'] / (df['stats_seqpairs']* df['length'])\n", "\n", + " def bad_k(row):\n", + " # Bad k only makes sense for A*PA with SH-based heuristic.\n", + " if row['algo_name'] != 'AstarPA':\n", + " return False\n", + " \n", + " h = row['job_algo_AstarPA_heuristic_type'].lower()\n", + " if h not in ['sh', 'csh', 'gcsh']:\n", + " return False\n", + " \n", + " # Too small k => too many matches\n", + " if row.k < math.log(row.length, 4) + (2 if row.r == 2 else 0):\n", + " return True\n", + " \n", + " # Too large k => not enough potential\n", + " if row.k > row.r/row.divergence:\n", + " return True\n", + " \n", + " # all is fine\n", + " return False\n", "\n", " if 'length' in df.columns:\n", - " df['bad_k'] = df.apply(lambda row: row.k < math.log(row.length, 4) + (2 if row.r == 2 else 0) or row.k > row.r/row.divergence, axis=1)\n", - " #else:\n", - " #df['stats_'] = df.apply(lambda row: row.k < math.log(row.length, 4) + (2 if row.r == 2 else 0) or row.k > row.r/row.divergence, axis=1)\n", + " df['bad_k'] = df.apply(bad_k, axis=1)\n", "\n", " def runtime_capped(row):\n", " if not math.isnan(row['runtime']):\n", @@ -311,7 +328,7 @@ " width=None,\n", " height=None,\n", " default_r=None,\n", - " bad_k=False,\n", + " bad_k=True,\n", " ):\n", " \n", " if df[y].isna().all():\n", @@ -971,7 +988,7 @@ "df['rn'] = df.apply(lambda row: f\"{row['r']}-{row['length']}\", axis=1)\n", "for d, g in df.groupby('errorrate'):\n", " plot(g, file=f'params_e{d}', x='k', y='s_per_pair', xlog=False, ylog=True, connect=True, line_labels=True,\n", - " hue='length', style='r', xlim=(4, 26), bad_k = True,\n", + " hue='length', style='r', xlim=(4, 26),\n", " width=1.3*4.4, height=1.3*3)\n", "plt.show()\n", "\n", @@ -995,7 +1012,7 @@ " subfig.suptitle(dataset_pretty[k], y=0)\n", " axs = subfig.subplots(1, 2, sharey=True)\n", " for (r, g), ax in zip(g.groupby(['r'], sort=False), axs): \n", - " plot(g, x='k', y='runtime_capped', hue='r', xlog=False, ylog=True, categorical=True, ylim=(0.3, 150), bad_k = True, ax=ax)\n", + " plot(g, x='k', y='runtime_capped', hue='r', xlog=False, ylog=True, categorical=True, ylim=(0.3, 150), ax=ax)\n", " ax.set_xlabel(f'$k$, {\"in\" if r==2 else \"\"}exact matches ($r={r}$)')\n", "\n", " plt.savefig(f\"plots/{file}.pdf\", dpi=300, bbox_inches='tight')\n",