Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rewriter] Fixing some issues with StrInCol #20

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions dias/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,25 +1025,26 @@ def rewrite_ast(cell_ast: ast.Module) -> Tuple[str, Dict]:


# Get string versions
the_str = str_in_col.the_str.value
the_str = astor.to_source(str_in_col.the_str)
assert isinstance(str_in_col.the_sub.value, ast.Name)
df = str_in_col.the_sub.value.id
col = None
if isinstance(str_in_col.the_sub.slice, ast.Name):
col = str_in_col.the_sub.slice.id
the_sub = f"{df}[{col}]"
else:
assert isinstance(str_in_col.the_sub.slice, ast.Constant)
col = str_in_col.the_sub.slice.value
the_sub = f"{df}[{col}]"
orig = f"'{the_str}' in {the_sub}.to_string()"
col = astor.to_source(str_in_col.the_sub.slice)
the_sub = f"{df}[{col}]"
orig = f"{the_str} in {the_sub}.to_string()"

# You need to be careful when handling strings like that because you might
# miss parentheses.

# We can specialize this for an index that is an int. Try to convert the string to int
# and if you fail,
contains_expr = f"astype(str).str.contains('{the_str}').any()"
new_expr = f"({the_sub}.{contains_expr} or _REWR_index_contains({the_sub}.index, '{the_str}')) if type({df}) == pd.DataFrame else ({orig})"
contains_expr = f"astype(str).str.contains({the_str}).any()"
new_expr = f"({the_sub}.{contains_expr} or _REWR_index_contains({the_sub}.index, {the_str})) if (type({df}) == pd.DataFrame and {the_sub}.index.dtype == np.int64) else ({orig})"
str_in_col.cmp_encl.set_enclosed_obj(ast.parse(new_expr, mode='eval'))
### END OF LOOP ###

Expand Down
105 changes: 105 additions & 0 deletions tests/str_in_col-index.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import dias.rewriter"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('./datasets/dataranch__supermarket-sales-prediction-xgboost-fastai__SampleSuperstore.csv')\n",
"df_multi_index = df.set_index(['Postal Code', 'Sub-Category'])"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "Setting a MultiIndex dtype to anything other than object is not supported",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/var/folders/sh/36vdgc9s2vqgz2sjl830ryvm0000gn/T/ipykernel_86399/1440505495.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcol\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdf_multi_index\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m if (df_multi_index[col].astype(str).str.contains('ou').any() or\n\u001b[0;32m---> 13\u001b[0;31m _REWR_index_contains(df_multi_index[col].index, 'ou') if type(\n\u001b[0m\u001b[1;32m 14\u001b[0m df_multi_index) == pd.DataFrame else 'ou' in df_multi_index[col].\n\u001b[1;32m 15\u001b[0m to_string()):\n",
"\u001b[0;32m/var/folders/sh/36vdgc9s2vqgz2sjl830ryvm0000gn/T/ipykernel_86399/1440505495.py\u001b[0m in \u001b[0;36m_REWR_index_contains\u001b[0;34m(index, s)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontains\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcol\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdf_multi_index\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m if (df_multi_index[col].astype(str).str.contains('ou').any() or\n",
"\u001b[0;32m~/opt/anaconda3/lib/python3.9/site-packages/pandas/core/indexes/multi.py\u001b[0m in \u001b[0;36mastype\u001b[0;34m(self, dtype, copy)\u001b[0m\n\u001b[1;32m 3754\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3755\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_object_dtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3756\u001b[0;31m raise TypeError(\n\u001b[0m\u001b[1;32m 3757\u001b[0m \u001b[0;34m\"Setting a MultiIndex dtype to anything other than object \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3758\u001b[0m \u001b[0;34m\"is not supported\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: Setting a MultiIndex dtype to anything other than object is not supported"
]
}
],
"source": [
"our = []\n",
"for col in df_multi_index.columns:\n",
" if 'ou' in df_multi_index[col].to_string():\n",
" our.append(col)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"# DIAS_DISABLE\n",
"defa = []\n",
"for col in df_multi_index.columns:\n",
" if 'ou' in df_multi_index[col].to_string():\n",
" defa.append(col)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([], ['City', 'State', 'Region'])"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"assert our == defa"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
24 changes: 24 additions & 0 deletions tests/str_in_col-index.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"cells": [
{
"raw": "\ndf = pd.read_csv('./datasets/dataranch__supermarket-sales-prediction-xgboost-fastai__SampleSuperstore.csv')\ndf_multi_index = df.set_index(['Postal Code', 'Sub-Category'])\n",
"modified": "df = pd.read_csv(\n './datasets/dataranch__supermarket-sales-prediction-xgboost-fastai__SampleSuperstore.csv'\n )\ndf_multi_index = df.set_index(['Postal Code', 'Sub-Category'])\n",
"patts-hit": {},
"rewritten-exec-time": 22.590208
},
{
"raw": "\nour = []\nfor col in df_multi_index.columns:\n if 'ou' in df_multi_index[col].to_string():\n our.append(col)\n",
"modified": "our = []\ndef _REWR_index_contains(index, s):\n if index.dtype == np.int64:\n try:\n i = int(s)\n return len(index.loc[i]) > 0\n except:\n return False\n else:\n return index.astype(str).str.contains(s).any()\nfor col in df_multi_index.columns:\n if (df_multi_index[col].astype(str).str.contains('ou').any() or\n _REWR_index_contains(df_multi_index[col].index, 'ou') if type(\n df_multi_index) == pd.DataFrame and df_multi_index[col].index.dtype ==\n np.int64 else 'ou' in df_multi_index[col].to_string()):\n our.append(col)\n",
"patts-hit": {
"MultipleStrInCol": 1
},
"rewritten-exec-time": 932.628458
},
{
"raw": "\nassert our == defa\n",
"modified": "assert our == defa\n",
"patts-hit": {},
"rewritten-exec-time": 0.211667
}
]
}
11 changes: 10 additions & 1 deletion tests/str_in_col.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@
" if '%' in df[col].to_string() or ',' in df[col].to_string():\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"'41' in df['Profit'].to_string()"
]
}
],
"metadata": {
Expand All @@ -112,7 +121,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.9.13"
}
},
"nbformat": 4,
Expand Down
10 changes: 9 additions & 1 deletion tests/str_in_col.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,19 @@
},
{
"raw":"\nfor col in df.columns:\n if '%' in df[col].to_string() or ',' in df[col].to_string():\n pass\n",
"modified":"def _REWR_index_contains(index, s):\n if index.dtype == np.int64:\n try:\n i = int(s)\n return len(index.loc[i]) > 0\n except:\n return False\n else:\n return index.astype(str).str.contains(s).any()\nfor col in df.columns:\n if (df[col].astype(str).str.contains('%').any() or _REWR_index_contains\n (df[col].index, '%') if type(df) == pd.DataFrame else '%' in df[col\n ].to_string()) or (df[col].astype(str).str.contains(',').any() or\n _REWR_index_contains(df[col].index, ',') if type(df) == pd.\n DataFrame else ',' in df[col].to_string()):\n pass\n",
"modified":"def _REWR_index_contains(index, s):\n if index.dtype == np.int64:\n try:\n i = int(s)\n return len(index.loc[i]) > 0\n except:\n return False\n else:\n return index.astype(str).str.contains(s).any()\nfor col in df.columns:\n if (df[col].astype(str).str.contains('%').any() or _REWR_index_contains\n (df[col].index, '%') if type(df) == pd.DataFrame and df[col].index.\n dtype == np.int64 else '%' in df[col].to_string()) or (df[col].\n astype(str).str.contains(',').any() or _REWR_index_contains(df[col]\n .index, ',') if type(df) == pd.DataFrame and df[col].index.dtype ==\n np.int64 else ',' in df[col].to_string()):\n pass\n",
"patts-hit":{
"MultipleStrInCol":1
},
"rewritten-exec-time":68.81415
},
{
"raw": "\n'41' in df['Profit'].to_string()\n",
"modified": "def _REWR_index_contains(index, s):\n if index.dtype == np.int64:\n try:\n i = int(s)\n return len(index.loc[i]) > 0\n except:\n return False\n else:\n return index.astype(str).str.contains(s).any()\n(df['Profit'].astype(str).str.contains('41').any() or _REWR_index_contains(\n df['Profit'].index, '41') if type(df) == pd.DataFrame and df['Profit'].\n index.dtype == np.int64 else '41' in df['Profit'].to_string())\n",
"patts-hit": {
"MultipleStrInCol": 1
},
"rewritten-exec-time": 129.444458
}
]
}