Clean type annotation
franckalbinet committed Oct 3, 2024
1 parent 9745292 commit 1c32153
Showing 4 changed files with 91 additions and 54 deletions.
2 changes: 1 addition & 1 deletion marisco/__init__.py
@@ -1 +1 @@
-__version__ = "0.2.0"
+__version__ = "0.3.0"
64 changes: 39 additions & 25 deletions marisco/callbacks.py
@@ -15,7 +15,7 @@
import pandas as pd
from .configs import cfg, cdl_cfg, grp_names
from functools import partial
-from typing import List, Dict, Callable, Tuple
+from typing import List, Dict, Callable, Tuple, Any, Optional
from pathlib import Path

from .configs import get_lut, nuc_lut_path
@@ -28,7 +28,10 @@ class Callback():
order = 0

# %% ../nbs/api/callbacks.ipynb 7
-def run_cbs(cbs, obj=None):
+def run_cbs(
+cbs: List[Callback], # List of callbacks to run
+obj: Any # Object to pass to the callbacks
+):
"Run the callbacks in the order they are specified."
for cb in sorted(cbs, key=attrgetter('order')):
if cb.__doc__: obj.logs.append(cb.__doc__)
@@ -37,16 +40,16 @@ def run_cbs(cbs, obj=None):
# %% ../nbs/api/callbacks.ipynb 8
class Transformer():
def __init__(self,
-dfs:pd.DataFrame, # Dictionary of DataFrames to transform
-cbs:list=None, # List of callbacks to run
-inplace:bool=False # Whether to modify the dataframes in place
+dfs: Dict[str, pd.DataFrame], # Dictionary of DataFrames to transform
+cbs: Optional[List[Callback]]=None, # List of callbacks to run
+inplace: bool=False # Whether to modify the dataframes in place
):
"Transform the dataframes according to the specified callbacks."
fc.store_attr()
self.dfs = dfs if inplace else {k: v.copy() for k, v in dfs.items()}
self.logs = []

-def unique(self, col_name):
+def unique(self, col_name: str) -> np.ndarray:
"Distinct values of a specific column present in all groups."
columns = [df.get(col_name) for df in self.dfs.values() if df.get(col_name) is not None]
values = np.concatenate(columns) if columns else []
@@ -60,8 +63,8 @@ def __call__(self):
# %% ../nbs/api/callbacks.ipynb 15
class SanitizeLonLatCB(Callback):
"Drop row when both longitude & latitude equal 0. Drop unrealistic longitude & latitude values. Convert longitude & latitude `,` separator to `.` separator."
-def __init__(self, verbose=False): fc.store_attr()
-def __call__(self, tfm):
+def __init__(self, verbose: bool=False): fc.store_attr()
+def __call__(self, tfm: Transformer):
for grp, df in tfm.dfs.items():
" Convert `,` separator to `.` separator"
df['lon'] = [float(str(x).replace(',', '.')) for x in df['lon']]
@@ -84,8 +87,8 @@ def __call__(self, tfm):
# %% ../nbs/api/callbacks.ipynb 20
class AddSampleTypeIdColumnCB(Callback):
def __init__(self,
-cdl_cfg:Callable=cdl_cfg, # Callable to get the CDL config dictionary
-col_name:str='samptype_id'
+cdl_cfg: Callable=cdl_cfg, # Callable to get the CDL config dictionary
+col_name: str='samptype_id' # Column name to store the sample type id
):
"Add a column with the sample type id as defined in the CDL."
fc.store_attr()
@@ -97,16 +100,16 @@ def __call__(self, tfm):
# %% ../nbs/api/callbacks.ipynb 23
class AddNuclideIdColumnCB(Callback):
def __init__(self,
-col_value:str, # Column name containing the nuclide name
-lut_fname_fn:callable=nuc_lut_path, # Function returning the lut path
-col_name:str='nuclide_id' # Column name to store the nuclide id
+col_value: str, # Column name containing the nuclide name
+lut_fname_fn: Callable=nuc_lut_path, # Function returning the lut path
+col_name: str='nuclide_id' # Column name to store the nuclide id
):
"Add a column with the nuclide id."
fc.store_attr()
self.lut = get_lut(lut_fname_fn().parent, lut_fname_fn().name,
key='nc_name', value='nuclide_id', reverse=False)

-def __call__(self, tfm):
+def __call__(self, tfm: Transformer):
for grp, df in tfm.dfs.items():
df[self.col_name] = df[self.col_value].map(self.lut)

@@ -148,9 +151,9 @@ def _remap_value(self, value: str) -> Any:
class LowerStripNameCB(Callback):
"Convert values to lowercase and strip any trailing spaces."
def __init__(self,
-col_src:str, # Source column name e.g. 'Nuclide'
-col_dst:str=None, # Destination column name
-fn_transform:Callable=lambda x: x.lower().strip() # Transformation function
+col_src: str, # Source column name e.g. 'Nuclide'
+col_dst: str=None, # Destination column name
+fn_transform: Callable=lambda x: x.lower().strip() # Transformation function
):
fc.store_attr()
self.__doc__ = f"Convert values from '{col_src}' to lowercase, strip spaces, and store in '{col_dst}'."
@@ -168,7 +171,7 @@ def __call__(self, tfm):
class RemoveAllNAValuesCB(Callback):
"Remove rows with all NA values."
def __init__(self,
-cols_to_check:dict # A dictionary with the sample type as key and the column name to check as value
+cols_to_check: Dict[str, str] # A dictionary with the sample type as key and the column name to check as value
):
fc.store_attr()

@@ -180,9 +183,13 @@ def __call__(self, tfm):

# %% ../nbs/api/callbacks.ipynb 34
class ReshapeLongToWide(Callback):
-def __init__(self, columns=['nuclide'], values=['value'],
-num_fill_value=-999, str_fill_value='STR FILL VALUE'):
-"Convert data from long to wide with renamed columns."
+"Convert data from long to wide with renamed columns."
+def __init__(self,
+columns: List[str]=['nuclide'], # Columns to use as index
+values: List[str]=['value'], # Columns to use as values
+num_fill_value: int=-999, # Fill value for numeric columns
+str_fill_value='STR FILL VALUE'
+):
fc.store_attr()
self.derived_cols = self._get_derived_cols()

@@ -234,8 +241,10 @@ def __call__(self, tfm):

# %% ../nbs/api/callbacks.ipynb 36
class CompareDfsAndTfmCB(Callback):
-def __init__(self, dfs: Dict[str, pd.DataFrame]):
-"Create a dataframe of dropped data. Data included in the `dfs` not in the `tfm`."
+"Create a dataframe of dropped data. Data included in the `dfs` not in the `tfm`."
+def __init__(self,
+dfs: Dict[str, pd.DataFrame] # Original dataframes
+):
fc.store_attr()

def __call__(self, tfm: Transformer) -> None:
@@ -271,8 +280,13 @@ def _compute_stats(self,

# %% ../nbs/api/callbacks.ipynb 41
class EncodeTimeCB(Callback):
"Encode time as `int` representing seconds since xxx"
def __init__(self, cfg , verbose=False): fc.store_attr()
"Encode time as `int` representing seconds since xxx."
def __init__(self,
cfg: dict, # Configuration dictionary
verbose: bool=False # Whether to print the number of invalid time entries
):
fc.store_attr()

def __call__(self, tfm):
def format_time(x):
return date2num(x, units=self.cfg['units']['time'])
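
To make the API change above concrete, here is a minimal usage sketch of the callback machinery this file defines, mirroring the notebook's `TestCB`. The data is hypothetical, and it assumes `Transformer.__call__` (only its signature `def __call__(self):` appears in these hunks) runs the registered callbacks via `run_cbs`, sorted by `order`:

import pandas as pd
from marisco.callbacks import Callback, Transformer

class AddOneToDepthCB(Callback):
    "A test callback to add 1 to the depth."
    def __call__(self, tfm: Transformer):
        for grp, df in tfm.dfs.items():
            df['depth'] = df['depth'].apply(lambda x: x + 1)

dfs = {'seawater': pd.DataFrame({'depth': [10.0, 20.0]})}  # hypothetical input
tfm = Transformer(dfs, cbs=[AddOneToDepthCB()], inplace=False)
tfm()                       # assumed to run the callbacks on tfm.dfs
print(tfm.dfs['seawater'])  # 'depth' incremented to 11.0 and 21.0
print(tfm.logs)             # docstrings of the callbacks that ran
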
77 changes: 50 additions & 27 deletions nbs/api/callbacks.ipynb
@@ -37,7 +37,7 @@
"import pandas as pd\n",
"from marisco.configs import cfg, cdl_cfg, grp_names\n",
"from functools import partial \n",
"from typing import List, Dict, Callable, Tuple\n",
"from typing import List, Dict, Callable, Tuple, Any, Optional\n",
"from pathlib import Path \n",
"\n",
"from marisco.configs import get_lut, nuc_lut_path\n",
@@ -50,7 +50,16 @@
"execution_count": null,
"id": "91324c91",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"#| hide\n",
"from marisco.configs import cdl_cfg, CONFIGS_CDL\n",
@@ -97,7 +106,10 @@
"outputs": [],
"source": [
"#| exports\n",
"def run_cbs(cbs, obj=None):\n",
"def run_cbs(\n",
" cbs: List[Callback], # List of callbacks to run\n",
" obj: Any # Object to pass to the callbacks\n",
" ):\n",
" \"Run the callbacks in the order they are specified.\"\n",
" for cb in sorted(cbs, key=attrgetter('order')):\n",
" if cb.__doc__: obj.logs.append(cb.__doc__)\n",
@@ -114,16 +126,16 @@
"#| exports\n",
"class Transformer():\n",
" def __init__(self, \n",
" dfs:pd.DataFrame, # Dictionary of DataFrames to transform\n",
" cbs:list=None, # List of callbacks to run\n",
" inplace:bool=False # Whether to modify the dataframes in place\n",
" dfs: Dict[str, pd.DataFrame], # Dictionary of DataFrames to transform\n",
" cbs: Optional[List[Callback]]=None, # List of callbacks to run\n",
" inplace: bool=False # Whether to modify the dataframes in place\n",
" ): \n",
" \"Transform the dataframes according to the specified callbacks.\"\n",
" fc.store_attr()\n",
" self.dfs = dfs if inplace else {k: v.copy() for k, v in dfs.items()}\n",
" self.logs = []\n",
" \n",
" def unique(self, col_name):\n",
" def unique(self, col_name: str) -> np.ndarray:\n",
" \"Distinct values of a specific column present in all groups.\"\n",
" columns = [df.get(col_name) for df in self.dfs.values() if df.get(col_name) is not None]\n",
" values = np.concatenate(columns) if columns else []\n",
@@ -153,7 +165,7 @@
"source": [
"class TestCB(Callback):\n",
" \"A test callback to add 1 to the depth.\"\n",
" def __call__(self, tfm):\n",
" def __call__(self, tfm: Transformer):\n",
" for grp, df in tfm.dfs.items(): \n",
" df['depth'] = df['depth'].apply(lambda x: x+1)"
]
@@ -209,8 +221,8 @@
"#| exports\n",
"class SanitizeLonLatCB(Callback):\n",
" \"Drop row when both longitude & latitude equal 0. Drop unrealistic longitude & latitude values. Convert longitude & latitude `,` separator to `.` separator.\"\n",
" def __init__(self, verbose=False): fc.store_attr()\n",
" def __call__(self, tfm):\n",
" def __init__(self, verbose: bool=False): fc.store_attr()\n",
" def __call__(self, tfm: Transformer):\n",
" for grp, df in tfm.dfs.items():\n",
" \" Convert `,` separator to `.` separator\"\n",
" df['lon'] = [float(str(x).replace(',', '.')) for x in df['lon']]\n",
@@ -298,8 +310,8 @@
"#| exports\n",
"class AddSampleTypeIdColumnCB(Callback):\n",
" def __init__(self, \n",
" cdl_cfg:Callable=cdl_cfg, # Callable to get the CDL config dictionary\n",
" col_name:str='samptype_id'\n",
" cdl_cfg: Callable=cdl_cfg, # Callable to get the CDL config dictionary\n",
" col_name: str='samptype_id' # Column name to store the sample type id\n",
" ): \n",
" \"Add a column with the sample type id as defined in the CDL.\"\n",
" fc.store_attr()\n",
@@ -343,16 +355,16 @@
"#| exports\n",
"class AddNuclideIdColumnCB(Callback):\n",
" def __init__(self, \n",
" col_value:str, # Column name containing the nuclide name\n",
" lut_fname_fn:callable=nuc_lut_path, # Function returning the lut path\n",
" col_name:str='nuclide_id' # Column name to store the nuclide id\n",
" col_value: str, # Column name containing the nuclide name\n",
" lut_fname_fn: Callable=nuc_lut_path, # Function returning the lut path\n",
" col_name: str='nuclide_id' # Column name to store the nuclide id\n",
" ): \n",
" \"Add a column with the nuclide id.\"\n",
" fc.store_attr()\n",
" self.lut = get_lut(lut_fname_fn().parent, lut_fname_fn().name, \n",
" key='nc_name', value='nuclide_id', reverse=False)\n",
" \n",
" def __call__(self, tfm):\n",
" def __call__(self, tfm: Transformer):\n",
" for grp, df in tfm.dfs.items(): \n",
" df[self.col_name] = df[self.col_value].map(self.lut)"
]
@@ -437,9 +449,9 @@
"class LowerStripNameCB(Callback):\n",
" \"Convert values to lowercase and strip any trailing spaces.\"\n",
" def __init__(self, \n",
" col_src:str, # Source column name e.g. 'Nuclide'\n",
" col_dst:str=None, # Destination column name\n",
" fn_transform:Callable=lambda x: x.lower().strip() # Transformation function\n",
" col_src: str, # Source column name e.g. 'Nuclide'\n",
" col_dst: str=None, # Destination column name\n",
" fn_transform: Callable=lambda x: x.lower().strip() # Transformation function\n",
" ):\n",
" fc.store_attr()\n",
" self.__doc__ = f\"Convert values from '{col_src}' to lowercase, strip spaces, and store in '{col_dst}'.\"\n",
@@ -510,7 +522,7 @@
"class RemoveAllNAValuesCB(Callback):\n",
" \"Remove rows with all NA values.\"\n",
" def __init__(self, \n",
" cols_to_check:dict # A dictionary with the sample type as key and the column name to check as value\n",
" cols_to_check: Dict[str, str] # A dictionary with the sample type as key and the column name to check as value\n",
" ):\n",
" fc.store_attr()\n",
"\n",
@@ -538,9 +550,13 @@
"source": [
"#| exports\n",
"class ReshapeLongToWide(Callback):\n",
" def __init__(self, columns=['nuclide'], values=['value'], \n",
" num_fill_value=-999, str_fill_value='STR FILL VALUE'):\n",
" \"Convert data from long to wide with renamed columns.\"\n",
" \"Convert data from long to wide with renamed columns.\"\n",
" def __init__(self, \n",
" columns: List[str]=['nuclide'], # Columns to use as index\n",
" values: List[str]=['value'], # Columns to use as values\n",
" num_fill_value: int=-999, # Fill value for numeric columns\n",
" str_fill_value='STR FILL VALUE'\n",
" ):\n",
" fc.store_attr()\n",
" self.derived_cols = self._get_derived_cols()\n",
" \n",
@@ -616,8 +632,10 @@
"source": [
"#| exports\n",
"class CompareDfsAndTfmCB(Callback):\n",
" def __init__(self, dfs: Dict[str, pd.DataFrame]): \n",
" \"Create a dataframe of dropped data. Data included in the `dfs` not in the `tfm`.\"\n",
" \"Create a dataframe of dropped data. Data included in the `dfs` not in the `tfm`.\"\n",
" def __init__(self, \n",
" dfs: Dict[str, pd.DataFrame] # Original dataframes\n",
" ): \n",
" fc.store_attr()\n",
" \n",
" def __call__(self, tfm: Transformer) -> None:\n",
@@ -727,8 +745,13 @@
"source": [
"#| exports\n",
"class EncodeTimeCB(Callback):\n",
" \"Encode time as `int` representing seconds since xxx\" \n",
" def __init__(self, cfg , verbose=False): fc.store_attr()\n",
" \"Encode time as `int` representing seconds since xxx.\" \n",
" def __init__(self, \n",
" cfg: dict, # Configuration dictionary\n",
" verbose: bool=False # Whether to print the number of invalid time entries\n",
" ): \n",
" fc.store_attr()\n",
" \n",
" def __call__(self, tfm): \n",
" def format_time(x): \n",
" return date2num(x, units=self.cfg['units']['time'])\n",
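
A note on the convention: the signatures in this commit switch to one parameter per line, annotated as `name: Type, # description`. That matches fastcore's docments style, which nbdev renders as parameter documentation. A small sketch of the idea — `docments` here is fastcore API, not part of this diff, and the printed output is indicative:

from fastcore.docments import docments

def run_cbs(
    cbs: list,   # List of callbacks to run
    obj: object  # Object to pass to the callbacks
    ):
    "Run the callbacks in the order they are specified."
    pass

# docments() parses the inline `#` comments into a mapping of
# parameter name -> description (plus an entry for the return value).
print(docments(run_cbs))
# expected: {'cbs': 'List of callbacks to run',
#            'obj': 'Object to pass to the callbacks', 'return': None}
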
2 changes: 1 addition & 1 deletion settings.ini
@@ -1,7 +1,7 @@
[DEFAULT]
repo = marisco
lib_name = marisco
-version = 0.2.0
+version = 0.3.0
min_python = 3.7
license = apache2
doc_path = _docs
