diff --git a/requirements.txt b/requirements.txt index 4145020..278ff9e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,10 @@ wheel pybids scikit-image icecream +pyarrow +ruamel.yaml +pytest +pytest-cov +pytest-xdist +pytest-env +umap-learn diff --git a/src/mngs/__init__.py b/src/mngs/__init__.py index e898149..7495166 100755 --- a/src/mngs/__init__.py +++ b/src/mngs/__init__.py @@ -1,14 +1,14 @@ #!/usr/bin/env python3 -# Time-stamp: "2024-02-03 16:06:22 (ywatanabe)" +# Time-stamp: "2024-03-08 20:48:06 (ywatanabe)" from . import dsp from . import general from . import general as gen -from . import gists, io, linalg, ml, nn, plt, resource, stats +from . import gists, io, linalg, ml, nn, os, plt, resource, stats from .general.debug import * -__copyright__ = "Copyright (C) 2021 Yusuke Watanabe" -__version__ = "1.0.1" +__copyright__ = "Copyright (C) 2024 Yusuke Watanabe" +__version__ = "1.1.0" __license__ = "GPL3.0" __author__ = "ywatanabe1989" __author_email__ = "ywata1989@gmail.com" diff --git a/src/mngs/dsp/.#fft.py b/src/mngs/dsp/.#fft.py new file mode 120000 index 0000000..44cd1c5 --- /dev/null +++ b/src/mngs/dsp/.#fft.py @@ -0,0 +1 @@ +ywatanabe@ywata-note-win.2006005053169024153 \ No newline at end of file diff --git a/src/mngs/dsp/__init__.py b/src/mngs/dsp/__init__.py index 01ee836..04ad094 100755 --- a/src/mngs/dsp/__init__.py +++ b/src/mngs/dsp/__init__.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # from .wavelet import wavelet +from ._psd import psd_torch from .demo_sig import demo_sig_np, demo_sig_torch from .feature_extractors import ( FeatureExtractorTorch, diff --git a/src/mngs/dsp/_psd.py b/src/mngs/dsp/_psd.py new file mode 100755 index 0000000..62041f9 --- /dev/null +++ b/src/mngs/dsp/_psd.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Time-stamp: "2024-02-21 08:58:35 (ywatanabe)" + +import torch + + +def psd_torch(signal, samp_rate): + """ + # Example usage: + samp_rate = 480 # Sampling rate in Hz + signal = torch.randn(480) # Example signal with 480 samples + freqs, psd = psd_torch(signal, samp_rate) + + # Plot the PSD (if you have matplotlib installed) + import matplotlib.pyplot as plt + + plt.plot(freqs.numpy(), psd.numpy()) + plt.xlabel('Frequency (Hz)') + plt.ylabel('Power/Frequency (V^2/Hz)') + plt.title('Power Spectral Density') + plt.show() + """ + # Apply window function to the signal (e.g., Hanning window) + window = torch.hann_window(signal.size(-1)) + signal = signal * window + + # Perform the FFT + fft_output = torch.fft.fft(signal) + + # Compute the power spectrum (magnitude squared of the FFT output) + power_spectrum = torch.abs(fft_output) ** 2 + + # Normalize the power spectrum to get the PSD + # Usually, we divide by the length of the signal and the sum of the window squared + # to get the power in terms of physical units (e.g., V^2/Hz) + psd = power_spectrum / (samp_rate * (window**2).sum()) + + # Since the signal is real, we only need the non-negative half of the FFT output + # The factor of 2 accounts for the energy in the negative frequencies that we're discarding + n = signal.size(-1) + psd = psd[: n // 2 + 1] * 2 + + # The DC component (0 Hz) and, for even-length signals, the Nyquist component + # have no negative-frequency counterpart, so they must not be doubled + psd[0] /= 2 + if n % 2 == 0: # Even length: the last bin is the Nyquist component + psd[-1] /= 2 + + # Frequency axis + freqs = torch.fft.rfftfreq(n, 1 / samp_rate) + + return freqs, psd diff --git a/src/mngs/dsp/demo_sig.py b/src/mngs/dsp/demo_sig.py index a4ae215..5677707 100755 --- a/src/mngs/dsp/demo_sig.py +++ 
b/src/mngs/dsp/demo_sig.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Time-stamp: "2024-01-20 13:58:43 (ywatanabe)"#!/usr/bin/env python3 +# Time-stamp: "2024-02-14 20:04:18 (ywatanabe)"#!/usr/bin/env python3 import numpy as np import torch @@ -53,14 +53,59 @@ def demo_sig_np( return data +# def demo_sig_torch( +# batch_size=64, n_chs=19, samp_rate=1000, len_sec=10, freqs_hz=[2, 5, 10] +# ): +# time = torch.arange(0, len_sec, 1 / samp_rate) +# sig = torch.vstack( +# [torch.sin(f * 2 * torch.pi * time) for f in freqs_hz] +# ).sum(dim=0) +# sig = sig.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_chs, 1) +# return sig + + def demo_sig_torch( - batch_size=64, n_chs=19, samp_rate=1000, len_sec=10, freqs_hz=[2, 5, 10] + batch_size=64, n_chs=19, samp_rate=128, len_sec=1, freqs_hz=[2, 5, 10] ): + """ + Generate a batch of demo signals with varying phase shifts and amplitudes, + and add some noise to simulate more realistic data. + + Parameters: + - batch_size: Number of samples in the batch. + - n_chs: Number of channels per sample. + - samp_rate: Sampling rate in Hz. + - len_sec: Length of the signal in seconds. + - freqs_hz: List of frequencies in Hz to include in the signal. + + Returns: + - sig: Tensor of shape (batch_size, n_chs, samp_rate * len_sec) containing the generated signals. + """ time = torch.arange(0, len_sec, 1 / samp_rate) - sig = torch.vstack( - [torch.sin(f * 2 * torch.pi * time) for f in freqs_hz] - ).sum(dim=0) - sig = sig.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_chs, 1) + # Initialize an empty signal tensor + sig = torch.zeros(batch_size, n_chs, len(time)) + + # Loop over each frequency and add a sinusoid with a random phase shift and amplitude + for f in freqs_hz: + # Generate a random phase shift for each sample in the batch + phase_shifts = torch.rand(batch_size, 1, 1) * 2 * torch.pi + # Generate a random amplitude for each sample in the batch + amplitudes = ( + torch.rand(batch_size, 1, 1) * 0.5 + 0.5 + ) # Amplitudes between 0.5 and 1.0 + # Create the sinusoid with the phase shift and amplitude + sinusoid = amplitudes * torch.sin( + f * 2 * torch.pi * time + phase_shifts + ) + # Repeat the sinusoid across channels + sinusoid = sinusoid.repeat(1, n_chs, 1) + # Add the sinusoid to the signal + sig += sinusoid + + # Add some Gaussian noise to the signal + noise = torch.randn_like(sig) * 0.1 # Noise level can be adjusted + sig += noise + return sig diff --git a/src/mngs/general/__init__.py b/src/mngs/general/__init__.py index 3e1523d..f920557 100755 --- a/src/mngs/general/__init__.py +++ b/src/mngs/general/__init__.py @@ -39,10 +39,13 @@ merge_dicts_wo_overlaps, partial_at, pop_keys, + print_block, search, squeeze_spaces, suppress_output, take_the_closest, + unique, + uq, wait_key, ) from .pandas import ( diff --git a/src/mngs/general/_close.py b/src/mngs/general/_close.py index a10db35..3368bbb 100755 --- a/src/mngs/general/_close.py +++ b/src/mngs/general/_close.py @@ -1,7 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Time-stamp: "2024-01-29 15:44:06 (ywatanabe)" +# Time-stamp: "2024-03-08 20:13:59 (ywatanabe)" +import os +import shutil from datetime import datetime, timedelta from glob import glob from time import sleep @@ -23,27 +25,43 @@ def format_diff_time(diff_time): return diff_time_str -def close(CONFIG, message=":)", show=True): +def close(CONFIG, message=":)", notify=True, show=True): try: - end_time = datetime.now() - diff_time = format_diff_time(end_time - CONFIG["START_TIME"]) - CONFIG["TimeSpent"] = diff_time - del 
CONFIG["START_TIME"] + CONFIG["END_TIME"] = datetime.now() + CONFIG["SPENT_TIME"] = format_diff_time( + CONFIG["END_TIME"] - CONFIG["START_TIME"] + ) + if show: + print(f"\nEND TIME: {CONFIG['END_TIME']}") + print(f"\nSPENT TIME: {CONFIG['SPENT_TIME']}") + except Exception as e: print(e) mngs.io.save(CONFIG, CONFIG["SDIR"] + "CONFIG.pkl") + mngs.io.save(CONFIG, CONFIG["SDIR"] + "CONFIG.yaml") try: if CONFIG.get("DEBUG", False): message = f"[DEBUG]\n" + message sleep(3) - mngs.gen.notify( - message=message, - ID=CONFIG["ID"], - log_paths=glob(CONFIG["SDIR"] + "*.log"), - show=show, - ) + if notify: + mngs.gen.notify( + message=message, + ID=CONFIG["ID"], + log_paths=glob(CONFIG["SDIR"] + "*.log"), + show=show, + ) + except Exception as e: + print(e) + + # RUNNING to FINISHED + src_dir = CONFIG["SDIR"] + dest_dir = src_dir.replace("RUNNING", "FINISHED") + os.makedirs(dest_dir, exist_ok=True) + try: + os.rename(src_dir, dest_dir) + print(f"\nRenamed from: {src_dir} to {dest_dir}") except Exception as e: print(e) @@ -53,8 +71,11 @@ def close(CONFIG, message=":)", show=True): import matplotlib.pyplot as plt import mngs + from icecream import ic - CONFIG, sys.stdout, sys.stderr, plt = mngs.gen.start(sys, plt, show=False) + CONFIG, sys.stdout, sys.stderr, plt, CC = mngs.gen.start( + sys, plt, show=False + ) ic("aaa") ic("bbb") diff --git a/src/mngs/general/_start.py b/src/mngs/general/_start.py index 6d5132a..669543f 100755 --- a/src/mngs/general/_start.py +++ b/src/mngs/general/_start.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Time-stamp: "2024-01-29 15:41:58 (ywatanabe)" +# Time-stamp: "2024-03-08 19:55:48 (ywatanabe)" import inspect import os as _os @@ -25,23 +25,24 @@ def start( tf=None, seed=42, # matplotlib - dpi=100, - save_dpi=300, - figsize=(16.2, 10), - figscale=1.0, - fontsize=16, - labelsize="same", - legendfontsize="xx-small", - tick_size="auto", - tick_width="auto", - hide_spines=False, + fig_size_mm=(160, 100), + fig_scale=1.0, + dpi_display=100, + dpi_save=300, + font_size_base=8, + font_size_title=8, + font_size_axis_label=8, + font_size_tick_label=7, + font_size_legend=6, + hide_top_right_spines=True, + alpha=0.75, ): """ import sys import matplotlib.pyplot as plt import mngs - CONFIG, sys.stdout, sys.stderr, plt = mngs.gen.start(sys, plt) + CONFIG, sys.stdout, sys.stderr, plt, cc = mngs.gen.start(sys, plt) # YOUR CODE HERE @@ -52,13 +53,15 @@ def start( # Debug mode check try: - is_debug = mngs.io.load("./config/is_debug.yaml").get("DEBUG", False) + IS_DEBUG = mngs.io.load("./config/IS_DEBUG.yaml").get( + "IS_DEBUG", False + ) except Exception as e: - is_debug = False + IS_DEBUG = False # ID - ID = mngs.gen.gen_ID() - ID = ID if not is_debug else "[DEBUG] " + ID + ID = mngs.gen.gen_ID(N=4) + ID = ID if not IS_DEBUG else "DEBUG_" + ID print(f"\n{'#'*40}\n## {ID}\n{'#'*40}\n") sleep(1) @@ -69,20 +72,22 @@ def start( __file__ = "/tmp/fake.py" spath = __file__ _sdir, sfname, _ = mngs.general.split_fpath(spath) - sdir = _sdir + sfname + "/" + ID + "/" # " # + "/log/" + sdir = ( + _sdir + sfname + "/" + "RUNNING" + "/" + ID + "/" + ) # " # + "/log/" _os.makedirs(sdir, exist_ok=True) # CONFIGs - CONFIGS = mngs.io.load_configs(is_debug) + CONFIGS = mngs.io.load_configs(IS_DEBUG) CONFIGS["ID"] = ID CONFIGS["START_TIME"] = start_time - CONFIGS["SDIR"] = sdir + CONFIGS["SDIR"] = sdir.replace("/./", "/") if show: print(f"\n{'-'*40}\n") print(f"CONFIG:") for k, v in CONFIGS.items(): print(f"\n{k}:\n{v}\n") - sleep(0.1) + # sleep(0.1) print(f"\n{'-'*40}\n") # 
Logging (tee) @@ -109,22 +114,23 @@ def start( # Matplotlib configuration if plt is not None: - plt = mngs.plt.configure_mpl( + plt, cc = mngs.plt.configure_mpl( plt, - dpi=dpi, - save_dpi=save_dpi, - figsize=figsize, - figscale=figscale, - fontsize=fontsize, - labelsize=labelsize, - legendfontsize=legendfontsize, - tick_size=tick_size, - tick_width=tick_width, - hide_spines=hide_spines, + fig_size_mm=fig_size_mm, + fig_scale=fig_scale, + dpi_display=dpi_display, + dpi_save=dpi_save, + font_size_base=font_size_base, + font_size_title=font_size_title, + font_size_axis_label=font_size_axis_label, + font_size_tick_label=font_size_tick_label, + font_size_legend=font_size_legend, + hide_top_right_spines=hide_top_right_spines, + alpha=alpha, show=show, ) - return CONFIGS, sys.stdout, sys.stderr, plt + return CONFIGS, sys.stdout, sys.stderr, plt, cc if __name__ == "__main__": @@ -134,7 +140,7 @@ def start( import mngs # --------------------------------------------------------------------------- # - CONFIG, sys.stdout, sys.stderr, plt = mngs.gen.start(sys, plt) + CONFIG, sys.stdout, sys.stderr, plt, cc = mngs.gen.start(sys, plt) # --------------------------------------------------------------------------- # # YOUR CODE HERE diff --git a/src/mngs/general/load.py b/src/mngs/general/load.py index 160ff42..425e8de 100755 --- a/src/mngs/general/load.py +++ b/src/mngs/general/load.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import json +import logging import os import pickle import warnings @@ -26,109 +27,134 @@ def load(lpath, show=False, **kwargs): - extension = "." + lpath.split(".")[-1] # [REVISED] - - # csv - if extension == ".csv": - obj = pd.read_csv(lpath, **kwargs) - obj = obj.loc[:, ~obj.columns.str.contains("^Unnamed")] # [REVISED] - # tsv - elif extension == ".tsv": - obj = pd.read_csv(lpath, sep="\t", **kwargs) - - # excel - elif extension in [".xls", ".xlsx", ".xlsm", ".xlsb"]: # [REVISED] - obj = pd.read_excel(lpath, **kwargs) - - # numpy - elif extension == ".npy": - obj = np.load(lpath, allow_pickle=True) # [REVISED] - # pkl - elif extension == ".pkl": - with open(lpath, "rb") as l: - obj = pickle.load(l) - # joblib - elif extension == ".joblib": - with open(lpath, "rb") as l: - obj = joblib.load(l) - # hdf5 - elif extension == ".hdf5": - obj = {} - with h5py.File(lpath, "r") as hf: - for name in hf: # [REVISED] - obj[name] = hf[name][:] - # json - elif extension == ".json": - with open(lpath, "r") as f: - obj = json.load(f) - # png - elif extension == ".png": - pass - # tiff - elif extension in [".tiff", ".tif"]: - pass - # yaml - elif extension == ".yaml": - with open(lpath) as f: - obj = yaml.safe_load(f) # [REVISED] - # txt - elif extension in [".txt", ".log"]: - with open(lpath, "r") as f: # [REVISED] - obj = f.read().splitlines() # [REVISED] - # pth - elif extension in [".pth", ".pt"]: - obj = torch.load(lpath) - - # mat - elif extension == ".mat": # [REVISED] - from pymatreader import read_mat # [REVISED] - - obj = read_mat(lpath) # [REVISED] - # xml - elif extension == ".xml": # [REVISED] - from xml2dict import xml2dict # [REVISED] - - obj = xml2dict(lpath) # [REVISED] - # # edf - # elif extension == ".edf": # [REVISED] - # obj = mne.io.read_raw_edf(lpath, preload=True) # [REVISED] - # con - elif extension == ".con": # [REVISED] - obj = mne.io.read_raw_fif(lpath, preload=True) # [REVISED] - obj = obj.to_data_frame() # [REVISED] - obj["samp_rate"] = obj.info["sfreq"] # [REVISED] - # # mrk - # elif extension == ".mrk": # [REVISED] - # obj = mne.io.read_mrk(lpath) # [REVISED] - - # 
catboost model - elif extension == ".cbm": # [REVISED] - from catboost import CatBoostModel # [REVISED] - - obj = CatBoostModel.load_model(lpath) # [REVISED] - - # EEG data - elif extension in [ - ".vhdr", - ".vmrk", - ".edf", - ".bdf", - ".gdf", - ".cnt", - ".egi", - ".eeg", - ".set", - ]: - obj = load_eeg_data(lpath, **kwargs) - - else: - print(f"\nNot loaded from: {lpath}\n") - return None + try: + extension = "." + lpath.split(".")[-1] # [REVISED] + + # csv + if extension == ".csv": + obj = pd.read_csv(lpath, **kwargs) + obj = obj.loc[ + :, ~obj.columns.str.contains("^Unnamed") + ] # [REVISED] + # tsv + elif extension == ".tsv": + obj = pd.read_csv(lpath, sep="\t", **kwargs) + + # excel + elif extension in [".xls", ".xlsx", ".xlsm", ".xlsb"]: # [REVISED] + obj = pd.read_excel(lpath, **kwargs) + + # numpy + elif extension == ".npy": + obj = np.load(lpath, allow_pickle=True, **kwargs) # [REVISED] + # pkl + elif extension == ".pkl": + with open(lpath, "rb") as l: + obj = pickle.load(l, **kwargs) + # joblib + elif extension == ".joblib": + with open(lpath, "rb") as l: + obj = joblib.load(l, **kwargs) + # hdf5 + elif extension == ".hdf5": + obj = {} + with h5py.File(lpath, "r") as hf: + for name in hf: # [REVISED] + obj[name] = hf[name][:] + # json + elif extension == ".json": + with open(lpath, "r") as f: + obj = json.load(f) + # png + elif extension == ".png": + pass + # tiff + elif extension in [".tiff", ".tif"]: + pass + # yaml + elif extension == ".yaml": + # from ruamel.yaml import YAML + + # yaml = YAML() + # yaml.preserve_quotes = ( + # True # Optional: if you want to preserve quotes + # ) + # yaml.indent( + # mapping=2, sequence=4, offset=2 + # ) # Optional: set indentation + + lower = kwargs.pop("lower", False) + + with open(lpath) as f: + obj = yaml.safe_load(f, **kwargs) # [REVISED] + + if lower: + obj = {k.lower(): v for k, v in obj.items()} + + # txt + elif extension in [".txt", ".log", ".event"]: + with open(lpath, "r") as f: # [REVISED] + obj = f.read().splitlines() # [REVISED] + # pth + elif extension in [".pth", ".pt"]: + obj = torch.load(lpath, **kwargs) + + # mat + elif extension == ".mat": # [REVISED] + from pymatreader import read_mat # [REVISED] + + obj = read_mat(lpath, **kwargs) # [REVISED] + # xml + elif extension == ".xml": # [REVISED] + from xml2dict import xml2dict # [REVISED] + + obj = xml2dict(lpath, **kwargs) # [REVISED] + # # edf + # elif extension == ".edf": # [REVISED] + # obj = mne.io.read_raw_edf(lpath, preload=True) # [REVISED] + # con + elif extension == ".con": # [REVISED] + obj = mne.io.read_raw_fif( + lpath, preload=True, **kwargs + ) # [REVISED] + obj = obj.to_data_frame() # [REVISED] + obj["samp_rate"] = obj.info["sfreq"] # [REVISED] + # # mrk + # elif extension == ".mrk": # [REVISED] + # obj = mne.io.read_mrk(lpath) # [REVISED] + + # catboost model + elif extension == ".cbm": # [REVISED] + from catboost import CatBoostModel # [REVISED] + + obj = CatBoostModel.load_model(lpath, **kwargs) # [REVISED] + + # EEG data + elif extension in [ + ".vhdr", + ".vmrk", + ".edf", + ".bdf", + ".gdf", + ".cnt", + ".egi", + ".eeg", + ".set", + ]: + obj = load_eeg_data(lpath, **kwargs) + + else: + print(f"\nNot loaded from: {lpath}\n") + return None - if show: - print(f"\nLoaded from: {lpath}\n") + if show: + print(f"\nLoaded from: {lpath}\n") - return obj + return obj + + except Exception as e: + logging.error(f"\n{lpath} was not loaded:\n{e}") + return None def load_eeg_data(filename, **kwargs): @@ -262,28 +288,32 @@ def load_study_rdb(study_name, 
rdb_raw_bytes_url): return study -def load_configs(is_debug=None): - def update_debug(config, is_debug): - if is_debug: +def load_configs(IS_DEBUG=None, show=False): + if os.getenv("CI") == "true": + IS_DEBUG = True + + def update_debug(config, IS_DEBUG): + if IS_DEBUG: debug_keys = mngs.gen.search("^DEBUG_", list(config.keys()))[1] for dk in debug_keys: dk_wo_debug_prefix = dk.split("DEBUG_")[1] config[dk_wo_debug_prefix] = config[dk] - print(f"\n{dk} -> {dk_wo_debug_prefix}\n") + if show: + print(f"\n{dk} -> {dk_wo_debug_prefix}\n") return config - # Check ./config/is_debug.yaml file if is_debug argument is not passed - if is_debug is None: + # Check ./config/IS_DEBUG.yaml file if IS_DEBUG argument is not passed + if IS_DEBUG is None: try: - is_debug = mngs.io.load("./config/is_debug.yaml")["DEBUG"] + IS_DEBUG = mngs.io.load("./config/IS_DEBUG.yaml")["IS_DEBUG"] except Exception as e: print(e) - is_debug = False + IS_DEBUG = False # Main CONFIGS = {} for lpath in glob("./config/*.yaml"): - CONFIG = update_debug(mngs.io.load(lpath), is_debug) + CONFIG = update_debug(mngs.io.load(lpath), IS_DEBUG) CONFIGS.update(CONFIG) return CONFIGS diff --git a/src/mngs/general/misc.py b/src/mngs/general/misc.py index e01c9d0..1833c5a 100755 --- a/src/mngs/general/misc.py +++ b/src/mngs/general/misc.py @@ -574,3 +574,122 @@ def suppress_output(): with contextlib.redirect_stderr(fnull): # Yield control back to the context block yield + + +def unique(data, axis=None): + """ + Identifies unique elements in the data along the specified axis and their counts, returning a DataFrame. + + Parameters: + - data (array-like): The input data to analyze for unique elements. + - axis (int, optional): The axis along which to find the unique elements. Defaults to None. + + Returns: + - df (pandas.DataFrame): DataFrame with unique elements and their counts. 
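+ + Example (illustrative; exact spacing follows pandas display defaults): + >>> unique(["a", "a", "b"]) + Unique Elements Counts + 0 a 2 + 1 b 1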
+ """ + # Find unique elements and their counts + if axis is None: + uqs, counts = np.unique(data, return_counts=True) + df = pd.DataFrame({"Unique Elements": uqs, "Counts": counts}) + else: + uqs, counts = np.unique(data, axis=axis, return_counts=True) + # Create a DataFrame with unique elements + df = pd.DataFrame( + uqs, + columns=[f"Unique Elements Axis {i}" for i in range(uqs.shape[1])], + ) + # Add a column for counts + df["Counts"] = counts + + # Format the 'Counts' column with commas for thousands + df["Counts"] = df["Counts"].apply(lambda x: f"{x:,}") + + return df + + +def uq(*args, **kwargs): + return unique(*args, **kwargs) + + +# def uq(data, axis=None): +# def _uq(data): +# uqs, counts = np.unique(data, return_counts=True) +# df = pd.DataFrame({"uq": uqs, "n": counts}) +# # Format the 'Counts' column with commas for thousands +# df["n"] = df["n"].apply(lambda x: f"{x:,}") +# return df + +# data = pd.DataFrame(data) + +# if axis == 1: +# dfs = {} +# for col in data.columns: +# df = _uq(data[col]) +# dfs[col] = df +# return dfs + +# if axis == 0: +# dfs = {} +# for col in data.T.columns: +# df = _uq(data.T[col]) +# dfs[col] = df +# return dfs + +# if axis is None: +# return _uq(data) + + +# def unique(data, axis=None): +# """ +# Identifies unique elements in the data and their counts, returning a DataFrame. + +# Parameters: +# - data (array-like): The input data to analyze for unique elements. +# - show (bool, optional): If True, prints the DataFrame. Defaults to True. + +# Returns: +# - df (pandas.DataFrame): DataFrame with unique elements and their counts. +# """ +# uqs, counts = np.unique(data, return_counts=True) # [REVISED] +# df = pd.DataFrame( +# np.vstack([uqs, counts]).T, columns=["uq", "n"] # [REVISED] +# ).set_index( +# "uq" +# ) # [REVISED] + +# df_show = df.copy() +# df_show["n"] = df_show["n"].apply(lambda x: f"{int(x):,}") # [REVISED] + +# return df_show +def print_block(message, char="-", n=40): + border = char * n + print(f"\n{border}\n{message}\n{border}\n") diff --git a/src/mngs/general/path.py b/src/mngs/general/path.py index 8e58719..8335d97 100755 --- a/src/mngs/general/path.py +++ b/src/mngs/general/path.py @@ -1,12 +1,15 @@ #!/usr/bin/env python3 +import fnmatch +import glob import inspect import os - -import mngs +import re +import subprocess import warnings from glob import glob +import mngs if "general" in __file__: with warnings.catch_warnings(): @@ -68,11 +71,119 @@ def split_fpath(fpath): fname, ext = os.path.splitext(base) return dirname, fname, ext + def touch(fpath): import pathlib + return pathlib.Path(fpath).touch() - -def find(directory, pattern): - search_pattern = os.path.join(directory, '**', pattern) - return glob(search_pattern, recursive=True) - + + +def find(rootdir, type="f", exp=["*"]): + """ + Mimicks the Unix find command. 
+ + Example: + # rootdir = + # type = 'f' # 'f' for files, 'd' for directories, None for both + # exp = '*.txt' # Pattern to match, or None to match all + find('/path/to/search', "f", "*.txt") + """ + if isinstance(exp, str): + exp = [exp] + + matches = [] + for _exp in exp: + for root, dirs, files in os.walk(rootdir): + # Depending on the type, choose the list to iterate over + if type == "f": # Files only + names = files + elif type == "d": # Directories only + names = dirs + else: # All entries + names = files + dirs + + for name in names: + # Construct the full path + path = os.path.join(root, name) + + # If an _exp is provided, use fnmatch to filter names + if _exp and not fnmatch.fnmatch(name, _exp): + continue + + # If type is set, ensure the type matches + if type == "f" and not os.path.isfile(path): + continue + if type == "d" and not os.path.isdir(path): + continue + + # Add the matching path to the results + matches.append(path) + + return matches + + +def find_latest(dirname, fname, ext, version_prefix="_v"): + version_pattern = re.compile( + rf"({re.escape(fname)}{re.escape(version_prefix)})(\d+)({re.escape(ext)})$" + ) + + glob_pattern = os.path.join(dirname, f"{fname}{version_prefix}*{ext}") + files = glob(glob_pattern) + + highest_version = 0 + latest_file = None + + for file in files: + filename = os.path.basename(file) + match = version_pattern.search(filename) + if match: + version_num = int(match.group(2)) + if version_num > highest_version: + highest_version = version_num + latest_file = file + + return latest_file + + +def increment_version(dirname, fname, ext, version_prefix="_v"): + # Create a regex pattern to match the version number in the filename + version_pattern = re.compile( + rf"({re.escape(fname)}{re.escape(version_prefix)})(\d+)({re.escape(ext)})$" + ) + + # Construct the glob pattern to find all files that match the pattern + glob_pattern = os.path.join(dirname, f"{fname}{version_prefix}*{ext}") + + # Use glob to find all files that match the pattern + files = glob(glob_pattern) + + # Initialize the highest version number + highest_version = 0 + base, suffix = None, None + + # Loop through the files to find the highest version number + for file in files: + filename = os.path.basename(file) + match = version_pattern.search(filename) + if match: + base, version_str, suffix = match.groups() + version_num = int(version_str) + if version_num > highest_version: + highest_version = version_num + + # If no versioned files were found, use the provided filename and extension + if base is None or suffix is None: + base = f"{fname}{version_prefix}" + suffix = ext + highest_version = 0 # No previous versions + + # Increment the highest version number + next_version_number = highest_version + 1 + + # Format the next version number with the same number of digits as the original + next_version_str = f"{base}{next_version_number:03d}{suffix}" + + # Combine the directory and new filename to create the full path + next_filepath = os.path.join(dirname, next_version_str) + + return next_filepath diff --git a/src/mngs/general/repro.py b/src/mngs/general/repro.py index ad765d2..19f06a4 100755 --- a/src/mngs/general/repro.py +++ b/src/mngs/general/repro.py @@ -46,13 +46,15 @@ def fix_seeds( print(f"\n{'-'*40}\n") -def gen_ID(N=8): +def gen_ID(time_format="%YY-%mM-%dD-%Hh%Mm%Ss", N=8): import random import string from datetime import datetime now = datetime.now() - now_str = now.strftime("%Y-%m-%d-%H-%M") + # now_str = now.strftime("%Y-%m-%d-%H-%M") + now_str = 
now.strftime(time_format) + # today_str = now.strftime("%Y-%m%d") randlst = [ random.choice(string.ascii_letters + string.digits) for i in range(N) diff --git a/src/mngs/general/save.py b/src/mngs/general/save.py index 49c55a3..2557533 100755 --- a/src/mngs/general/save.py +++ b/src/mngs/general/save.py @@ -1,15 +1,13 @@ #!/usr/bin/env python3 import csv +import warnings -import pandas as pd import mngs import numpy as np +import pandas as pd import scipy -import warnings - - if "general" in __file__: with warnings.catch_warnings(): warnings.simplefilter("always") @@ -28,6 +26,7 @@ def save(obj, sfname_or_spath, makedirs=True, show=True, **kwargs): save(serializable, 'serializable.pkl') """ import inspect + import json import os import pickle @@ -65,9 +64,9 @@ def save(obj, sfname_or_spath, makedirs=True, show=True, **kwargs): ## Saves try: ## copy files - is_copying_files = (isinstance(obj, str) or is_listed_X(obj, str)) and ( - isinstance(spath, str) or is_listed_X(spath, str) - ) + is_copying_files = ( + isinstance(obj, str) or is_listed_X(obj, str) + ) and (isinstance(spath, str) or is_listed_X(spath, str)) if is_copying_files: mngs.general.copy_files(obj, spath) @@ -117,10 +116,27 @@ def save(obj, sfname_or_spath, makedirs=True, show=True, **kwargs): elif spath.endswith(".mp4"): mk_mp4(obj, spath) # obj is matplotlib.pyplot.figure object del obj + # yaml elif spath.endswith(".yaml"): + from ruamel.yaml import YAML + + yaml = YAML() + yaml.preserve_quotes = ( + True # Optional: if you want to preserve quotes + ) + yaml.indent( + mapping=4, sequence=4, offset=4 + ) # Optional: set indentation + with open(spath, "w") as f: yaml.dump(obj, f) + + # json + elif spath.endswith(".json"): + with open(spath, "w") as f: + json.dump(obj, f, indent=4) + # hdf5 elif spath.endswith(".hdf5"): name_list, obj_list = [] @@ -212,7 +228,9 @@ def save_listed_scalars_as_csv( if overwrite == True: mv_to_tmp(spath_csv, L=2) - indi_suffix = np.arange(len(listed_scalars)) if indi_suffix is None else indi_suffix + indi_suffix = ( + np.arange(len(listed_scalars)) if indi_suffix is None else indi_suffix + ) df = pd.DataFrame( {"{}".format(column_name): listed_scalars}, index=indi_suffix ).round(round) @@ -244,7 +262,9 @@ def save_listed_dfs_as_csv( if overwrite == True: mv_to_tmp(spath_csv, L=2) - indi_suffix = np.arange(len(listed_dfs)) if indi_suffix is None else indi_suffix + indi_suffix = ( + np.arange(len(listed_dfs)) if indi_suffix is None else indi_suffix + ) for i, df in enumerate(listed_dfs): with open(spath_csv, mode="a") as f: f_writer = csv.writer(f) @@ -275,7 +295,9 @@ def animate(i): fig, animate, init_func=init, frames=360, interval=20, blit=True ) - writermp4 = animation.FFMpegWriter(fps=60, extra_args=["-vcodec", "libx264"]) + writermp4 = animation.FFMpegWriter( + fps=60, extra_args=["-vcodec", "libx264"] + ) anim.save(spath_mp4, writer=writermp4) print("\nSaving to: {}\n".format(spath_mp4)) @@ -302,12 +324,16 @@ def save_optuna_study_as_csv_and_pngs(study, sdir): ## Figures hparams_keys = list(study.best_params.keys()) slice_plot = optuna.visualization.plot_slice(study, params=hparams_keys) - contour_plot = optuna.visualization.plot_contour(study, params=hparams_keys) + contour_plot = optuna.visualization.plot_contour( + study, params=hparams_keys + ) optim_hist_plot = optuna.visualization.plot_optimization_history(study) parallel_coord_plot = optuna.visualization.plot_parallel_coordinate( study, params=hparams_keys ) - hparam_importances_plot = optuna.visualization.plot_param_importances(study) 
+ hparam_importances_plot = optuna.visualization.plot_param_importances( + study + ) figs_dict = dict( slice_plot=slice_plot, contour_plot=contour_plot, diff --git a/src/mngs/io/__init__.py b/src/mngs/io/__init__.py index 4a97e7d..daf7a10 100755 --- a/src/mngs/io/__init__.py +++ b/src/mngs/io/__init__.py @@ -9,8 +9,10 @@ ) from .path import ( find, + find_latest, find_the_git_root_dir, get_this_fpath, + increment_version, mk_spath, split_fpath, touch, diff --git a/src/mngs/ml/__init__.py b/src/mngs/ml/__init__.py index 8ad8996..5c90946 100755 --- a/src/mngs/ml/__init__.py +++ b/src/mngs/ml/__init__.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from . import act, layer, optim, plt, utils +from . import act, clustering, layer, metrics, optim, plt, sk, utils from .ClassificationReporter import ( ClassificationReporter, MultiClassificationReporter, diff --git a/src/mngs/ml/clustering/_UMAP.py b/src/mngs/ml/clustering/_UMAP.py new file mode 100755 index 0000000..e1f18bc --- /dev/null +++ b/src/mngs/ml/clustering/_UMAP.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Time-stamp: "2024-02-11 22:10:09 (ywatanabe)" + +# Assuming the existence of an 'mngs/ml/cluster.py' module, add the following: + +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +import umap.umap_ as umap # Ensure to import UMAP correctly based on your installation + + +class UMAP: + """ + A class to encapsulate UMAP clustering within the mngs.ml.cluster module. + """ + + @staticmethod + def cluster(data, labels=None, supervised=False): + """ + Performs UMAP clustering on the given data, with options for supervised or unsupervised learning, + and visualizes the result. + + Parameters: + - data (np.ndarray): The input data for clustering. + - labels (np.array, optional): Labels for each data point, used in supervised mode. + - supervised (bool, default=False): If True, performs supervised clustering using the provided labels. + + Returns: + - fig (matplotlib.figure.Figure): The figure object for the UMAP visualization. + - embedding (np.ndarray): The 2D embedding of the input data after UMAP reduction. 
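+ + Example (illustrative; assumes the umap-learn package is installed): + >>> data = np.random.rand(100, 8) + >>> fig, embedding = UMAP.cluster(data) + >>> embedding.shape + (100, 2)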
+ """ + if supervised and labels is None: + raise ValueError("Labels are required for supervised learning.") + + umap_model = umap.UMAP() + + if supervised: + embedding = umap_model.fit_transform(data, y=labels) + else: + embedding = umap_model.fit_transform(data) + + fig = plt.figure(figsize=(10, 8)) + if labels is not None: + unique_labels = np.unique(labels) + palette = sns.color_palette("hsv", len(unique_labels)) + for i, label in enumerate(unique_labels): + indices = labels == label + plt.scatter( + embedding[indices, 0], + embedding[indices, 1], + label=label, + s=5, + color=palette[i], + ) + plt.legend(markerscale=3.0, title="Labels") + else: + plt.scatter(embedding[:, 0], embedding[:, 1], s=5) + + plt.title("UMAP Clustering") + plt.xlabel("UMAP 1") + plt.ylabel("UMAP 2") + + return fig, embedding + + +# Example usage in your library context: +# from mngs.ml.cluster import UMAP + +# data, labels = , +# fig, clustered = UMAP.cluster(data, labels, supervised=False) +# plt.show() # Display the clustering result visualization diff --git a/src/mngs/ml/clustering/__init__.py b/src/mngs/ml/clustering/__init__.py new file mode 100755 index 0000000..6febed7 --- /dev/null +++ b/src/mngs/ml/clustering/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Time-stamp: "2024-03-02 23:20:10 (ywatanabe)" + +from ._umap import umap diff --git a/src/mngs/ml/clustering/_umap.py b/src/mngs/ml/clustering/_umap.py new file mode 100755 index 0000000..eceab80 --- /dev/null +++ b/src/mngs/ml/clustering/_umap.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Time-stamp: "2024-03-03 02:12:25 (ywatanabe)" + +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +import umap.umap_ as umap_orig +from sklearn.preprocessing import LabelEncoder + + +def umap( + data_all, + labels_all, + axes_titles=None, + supervised=False, + title="UMAP Clustering", + alpha=0.1, + s=3, +): + + assert len(data_all) == len(labels_all) + + if isinstance(data_all, list): + data_all = list(data_all) + labels_all = list(labels_all) + + le = LabelEncoder() + + labels_all_orig = [np.array(labels) for labels in labels_all] + + le.fit(np.hstack(labels_all)) + labels_all = [le.transform(labels) for labels in labels_all] + + umap_model = umap_orig.UMAP(random_state=42) + + if supervised: + _umap = umap_model.fit(data_all[0], y=labels_all[0]) + title = f"(Supervised) {title}" + else: + _umap = umap_model.fit(data_all[0]) + title = f"(Unsupervised) {title}" + + fig, axes = plt.subplots(ncols=len(data_all) + 1, sharex=True, sharey=True) + + for ii, (data, labels) in enumerate(zip(data_all, labels_all)): + embedding = _umap.transform(data) + + ax = axes[ii + 1] + + axes[0].set_title("Superimposed") + axes[0].set_aspect("equal") + palette = "viridis" + + sns.scatterplot( + x=embedding[:, 0], + y=embedding[:, 1], + hue=le.inverse_transform(labels), + ax=axes[0], + palette=palette, + legend="full" if ii == 0 else False, + s=s, + alpha=alpha, + ) + + sns.scatterplot( + x=embedding[:, 0], + y=embedding[:, 1], + hue=le.inverse_transform(labels), + ax=ax, + palette=palette, + s=s, + alpha=alpha, + ) + ax.set_aspect("equal") + + if axes_titles is not None: + ax.set_title(axes_titles[ii]) + + legend_figs = [] + for i, ax in enumerate(axes): + legend = ax.get_legend() + if legend: + legend_fig = plt.figure(figsize=(3, 2)) + new_legend = legend_fig.gca().legend( + handles=legend.legendHandles, labels=legend.texts, loc="center" + ) + legend_fig.canvas.draw() + legend_filename = 
f"legend_{i}.png" + legend_fig.savefig(legend_filename, bbox_inches="tight") + legend_figs.append(legend_fig) + plt.close(legend_fig) + + for ax in axes: + ax.legend_ = None + # ax.remove_legend() + + fig.suptitle(title) + fig.supxlabel("UMAP 1") + fig.supylabel("UMAP 2") + + return fig, legend_figs, _umap + + +# def umap( +# data_all, +# labels_all, +# axes_titles=None, +# supervised=False, +# title="UMAP Clustering", +# alpha=0.1, +# s=3, +# ): +# """ +# Performs UMAP clustering on the given data and labels, and generates a plot with the results. + +# Parameters: +# - data_all (list of array-like): List of datasets to be used for UMAP clustering. +# - labels_all (list of array-like): List of label arrays corresponding to the datasets. +# - axes_titles (list of str, optional): Titles for each subplot axis. +# - supervised (bool, optional): Whether to use supervised dimensionality reduction. Defaults to False. +# - title (str, optional): Title for the entire plot. Defaults to "UMAP Clustering". +# - alpha (float, optional): Alpha value for the scatter plot points. Defaults to 0.1. +# - s (int, optional): Size of the scatter plot points. Defaults to 3. + +# Returns: +# - fig (matplotlib.figure.Figure): The main figure object containing the UMAP plots. +# - legend_figs (list of matplotlib.figure.Figure): List of figures containing the legends. +# - _umap (umap.UMAP): The fitted UMAP model. +# """ +# import matplotlib.pyplot as plt +# import numpy as np +# import seaborn as sns +# import umap.umap_ as umap_orig +# from sklearn.preprocessing import LabelEncoder + +# assert len(data_all) == len(labels_all) + +# if isinstance(data_all, list): +# data_all = list(data_all) +# labels_all = list(labels_all) + +# le = LabelEncoder() + +# labels_all_orig = [np.array(labels) for labels in labels_all] +# if labels_all is not None: +# labels_all = [le.fit_transform(labels) for labels in labels_all] + +# umap_model = umap_orig.UMAP(random_state=42) + +# if supervised: +# _umap = umap_model.fit(data_all[0], y=labels_all[0]) +# title = f"(Supervised) {title}" +# else: +# _umap = umap_model.fit(data_all[0]) +# title = f"(Unsupervised) {title}" + +# fig, axes = plt.subplots(ncols=len(data_all) + 1, sharex=True, sharey=True) + +# for ii, (data, labels) in enumerate(zip(data_all, labels_all)): +# embedding = _umap.transform(data) +# ax = axes[ii + 1] + +# axes[0].set_title("Superimposed") +# axes[0].set_aspect("equal") +# palette = "viridis" + +# sns.scatterplot( +# x=embedding[:, 0], +# y=embedding[:, 1], +# hue=le.inverse_transform(labels), +# ax=axes[0], +# palette=palette, +# legend="full" if ii == 0 else False, +# s=s, +# alpha=alpha, +# ) + +# sns.scatterplot( +# x=embedding[:, 0], +# y=embedding[:, 1], +# hue=le.inverse_transform(labels), +# ax=ax, +# palette=palette, +# s=s, +# alpha=alpha, +# ) +# ax.set_aspect("equal") + +# if axes_titles is not None: +# ax.set_title(axes_titles[ii]) + +# legend_figs = [] +# for i, ax in enumerate(axes): +# legend = ax.get_legend() +# if legend: +# legend_fig = plt.figure(figsize=(3, 2)) +# new_legend = legend_fig.gca().add_artist(legend) # [REVISED] +# legend.set_bbox_to_anchor((0, 0, 1, 1)) +# legend_fig.canvas.draw() +# legend_filename = f"legend_{i}.png" +# legend_fig.savefig(legend_filename, bbox_inches="tight") +# legend_figs.append(legend_fig) +# plt.close(legend_fig) + +# fig.suptitle(title) +# fig.supxlabel("UMAP 1") +# fig.supylabel("UMAP 2") + +# return fig, legend_figs, _umap + + +# # def umap( +# # data_all, +# # labels_all, +# # axes_titles=None, +# # 
supervised=False, +# # title="UMAP Clustering", +# # alpha=0.1, +# # s=3, +# # ): + +# # assert len(data_all) == len(labels_all) + +# # if isinstance(data_all, list): +# # data_all = list(data_all) +# # labels_all = list(labels_all) + +# # le = LabelEncoder() + +# # # Store original labels +# # labels_all_orig = [np.array(labels) for labels in labels_all] +# # if labels_all is not None: +# # labels_all = [le.fit_transform(labels) for labels in labels_all] + +# # umap_model = umap_orig.UMAP(random_state=42) + +# # # Process the primary dataset +# # if supervised: +# # _umap = umap_model.fit(data_all[0], y=labels_all[0]) +# # title = f"(Supervised) {title}" +# # else: +# # _umap = umap_model.fit(data_all[0]) +# # title = f"(Unsupervised) {title}" + +# # fig, axes = plt.subplots(ncols=len(data_all) + 1, sharex=True, sharey=True) + +# # for ii, (data, labels) in enumerate(zip(data_all, labels_all)): +# # embedding = _umap.transform(data) +# # ax = axes[ii + 1] + +# # # Superimposed +# # axes[0].set_title("Superimposed") +# # axes[0].set_aspect("equal") +# # palette = "viridis" + +# # sns.scatterplot( +# # x=embedding[:, 0], +# # y=embedding[:, 1], +# # hue=le.inverse_transform(labels), +# # ax=axes[0], +# # palette=palette, +# # legend="full" if ii == 0 else False, +# # s=s, +# # alpha=alpha, +# # ) + +# # # Each data +# # sns.scatterplot( +# # x=embedding[:, 0], +# # y=embedding[:, 1], +# # hue=le.inverse_transform(labels), +# # ax=ax, +# # palette=palette, +# # s=s, +# # alpha=alpha, +# # ) +# # ax.set_aspect("equal") + +# # if axes_titles is not None: +# # ax.set_title(axes_titles[ii]) + +# # # Save legends as separate figures and store them in a list +# # legend_figs = [] +# # for i, ax in enumerate(axes): +# # # Extract the legend from the current axis +# # legend = ax.get_legend() +# # if legend: +# # # Create a new figure for the legend +# # legend_fig = plt.figure(figsize=(3, 2)) +# # new_legend = legend_fig._get_axes().add_artist(legend) +# # legend.set_bbox_to_anchor((0, 0, 1, 1)) +# # legend_fig.canvas.draw() +# # # Save the legend as a PNG file +# # legend_filename = f"legend_{i}.png" +# # legend_fig.savefig(legend_filename, bbox_inches="tight") +# # # Store the legend figure in the list +# # legend_figs.append(legend_fig) +# # plt.close(legend_fig) # Close the figure to free memory + +# # fig.suptitle(title) +# # fig.supxlabel("UMAP 1") +# # fig.supylabel("UMAP 2") + +# # # Return the main figure, the list of legend figures, and the UMAP model +# # return fig, legend_figs, _umap + + +# # # from itertools import cycle + +# # # import matplotlib.pyplot as plt +# # # import mngs +# # # import numpy as np +# # # import seaborn as sns +# # # import umap.umap_ as umap_orig +# # # from sklearn.preprocessing import LabelEncoder + + +# # # def umap( +# # # data_all, +# # # labels_all, +# # # axes_titles=None, +# # # supervised=False, +# # # title="UMAP Clustering", +# # # alpha=0.1, +# # # s=3, +# # # # colors=None, +# # # ): + +# # # assert len(data_all) == len(labels_all) + +# # # if isinstance(data_all, list): +# # # data_all = list(data_all) +# # # labels_all = list(labels_all) + +# # # le = mngs.ml.utils.LabelEncoder() +# # # # le = LabelEncoder() + +# # # # Store original labels +# # # labels_all_orig = [np.array(labels) for labels in labels_all] +# # # if labels_all is not None: +# # # # labels_uq = np.unique(np.hstack(labels_all)) +# # # # le.fit(labels_uq) +# # # labels_all = [le.fit_transform(labels) for labels in labels_all] + +# # # umap_model = umap_orig.UMAP(random_state=42) 
+ +# # # # Process the primary dataset +# # # if supervised: +# # # _umap = umap_model.fit(data_all[0], y=labels_all[0]) +# # # title = f"(Supervised) {title}" +# # # else: +# # # _umap = umap_model.fit(data_all[0]) +# # # title = f"(Unsupervised) {title}" + +# # # fig, axes = plt.subplots(ncols=len(data_all) + 1, sharex=True, sharey=True) +# # # # # Create a color palette that maps each unique label to a color +# # # # unique_labels = np.unique(np.hstack(labels_all_orig)) +# # # # if colors is not None: +# # # # palette = dict(zip(unique_labels, cycle(colors))) +# # # # else: +# # # # palette = sns.color_palette("hsv", len(unique_labels)) +# # # # palette = dict(zip(unique_labels, palette)) + +# # # # if colors is not None: +# # # # color_cycle = cycle(colors) +# # # # else: +# # # # color_cycle = None + +# # # # for ii, (data, labels, labels_orig) in enumerate( +# # # # zip(data_all, labels_all, labels_all_orig) +# # # # ): +# # # for ii, (data, labels) in enumerate(zip(data_all, labels_all)): +# # # embedding = _umap.transform(data) +# # # ax = axes[ii + 1] + +# # # # Superimposed +# # # axes[0].set_title("Superimposed") +# # # axes[0].set_aspect("equal") +# # # palette = "viridis" +# # # # if color_cycle: +# # # # palette = sns.color_palette( +# # # # [next(color_cycle) for _ in range(len(np.unique(labels)))] +# # # # ) +# # # # else: +# # # # palette = "viridis" + +# # # sns.scatterplot( +# # # x=embedding[:, 0], +# # # y=embedding[:, 1], +# # # hue=le.inverse_transform(labels), +# # # ax=axes[0], +# # # palette=palette, +# # # legend="full" if ii == 0 else False, +# # # s=s, +# # # alpha=alpha, +# # # ) + +# # # # Each data +# # # sns.scatterplot( +# # # x=embedding[:, 0], +# # # y=embedding[:, 1], +# # # hue=le.inverse_transform(labels), +# # # ax=ax, +# # # palette=palette, +# # # s=s, +# # # alpha=alpha, +# # # ) +# # # ax.set_aspect("equal") + +# # # if axes_titles is not None: +# # # ax.set_title(axes_titles[ii]) + +# # # # Remove the legends from the individual axes +# # # for ax in axes: +# # # ax.legend(loc="upper left") +# # # legend = ax.get_legend() +# # # if legend: +# # # legend_fig = plt.figure(figsize=(3, 2)) + + +# # # fig.suptitle(title) +# # # fig.supxlabel("UMAP 1") +# # # fig.supylabel("UMAP 2") + +# # # return fig, _umap diff --git a/src/mngs/ml/metrics/__init__.py b/src/mngs/ml/metrics/__init__.py new file mode 100755 index 0000000..2ca340e --- /dev/null +++ b/src/mngs/ml/metrics/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Time-stamp: "2024-02-26 16:15:43 (ywatanabe)" + +from ._bACC import bACC diff --git a/src/mngs/ml/metrics/_bACC.py b/src/mngs/ml/metrics/_bACC.py new file mode 100755 index 0000000..f646b8b --- /dev/null +++ b/src/mngs/ml/metrics/_bACC.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Time-stamp: "2024-02-26 16:32:42 (ywatanabe)" + +import warnings + +import numpy as np +import torch +from sklearn.metrics import balanced_accuracy_score + + +def bACC(true_class, pred_class): + """ + Calculates the balanced accuracy score between predicted and true class labels. + + Parameters: + - true_class (array-like or torch.Tensor): True class labels. + - pred_class (array-like or torch.Tensor): Predicted class labels. + + Returns: + - bACC (float): The balanced accuracy score rounded to three decimal places. 
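+ + Example (illustrative): + >>> bACC(np.array([0, 1, 1, 0]), np.array([0, 1, 0, 0])) + 0.75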
+ """ + if isinstance(true_class, torch.Tensor): # [REVISED] + true_class = true_class.detach().cpu().numpy() # [REVISED] + if isinstance(pred_class, torch.Tensor): # [REVISED] + pred_class = pred_class.detach().cpu().numpy() # [REVISED] + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + bACC_score = balanced_accuracy_score( + true_class.reshape(-1), # [REVISED] + pred_class.reshape(-1), # [REVISED] + ) + return round(bACC_score, 3) # [REVISED] diff --git a/src/mngs/ml/plt/__init__.py b/src/mngs/ml/plt/__init__.py index 85b326d..8c94c17 100755 --- a/src/mngs/ml/plt/__init__.py +++ b/src/mngs/ml/plt/__init__.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 +from ._conf_mat import conf_mat from ._learning_curve import learning_curve from .aucs.pre_rec_auc import pre_rec_auc from .aucs.roc_auc import roc_auc -from .confusion_matrix import confusion_matrix diff --git a/src/mngs/ml/plt/confusion_matrix.py b/src/mngs/ml/plt/_conf_mat.py similarity index 66% rename from src/mngs/ml/plt/confusion_matrix.py rename to src/mngs/ml/plt/_conf_mat.py index 01dda8e..265ab60 100755 --- a/src/mngs/ml/plt/confusion_matrix.py +++ b/src/mngs/ml/plt/_conf_mat.py @@ -1,34 +1,45 @@ #!/usr/bin/env python3 import matplotlib +import mngs import numpy as np import pandas as pd import seaborn as sns -import mngs from matplotlib import ticker from mpl_toolkits.axes_grid1 import make_axes_locatable +from sklearn.metrics import confusion_matrix as sklearn_confusion_matrix -def confusion_matrix( +def conf_mat( plt, - cm, + cm=None, + y_true=None, + y_pred=None, + y_pred_proba=None, labels=None, + sorted_labels=None, pred_labels=None, + sorted_pred_labels=None, true_labels=None, - label_rotation_xy=(0, 0), + sorted_true_labels=None, + label_rotation_xy=(15, 15), title=None, colorbar=True, x_extend_ratio=1.0, - y_extend_ratio=1.0, + y_extend_ratio=1.0, ): """ Inverse the y-axis and plot the confusion matrix as a heatmap. The predicted labels (in x-axis) is symboled with hat (^). The plt object is passed to adjust the figure size + cm = sklearn.metrics.confusion_matrix(y_test, y_pred) + + cm = np.random.randint(low=0, high=10, size=[3,4]) x: predicted labels y: true_labels - + + kwargs: "extend_ratio": @@ -36,26 +47,48 @@ def confusion_matrix( in the vertical direction. 
""" - df = pd.DataFrame(data=cm).copy().T - vmax = np.array(df).max().astype(int) - if (labels is not None) and (pred_labels is None): - df.columns = [mngs.general.to_the_latex_style(l) for l in labels] # pred_labels + if (y_pred_proba is not None) and (y_pred is None): + y_pred = y_pred_proba.argmax(axis=-1) + + assert (cm is not None) or ((y_true is not None) and (y_pred is not None)) + + if not cm: + cm = sklearn_confusion_matrix(y_true, y_pred) + + # Dataframe + df = pd.DataFrame(data=cm).copy() + + # To LaTeX styles + if pred_labels is not None: + pred_labels = [mngs.general.to_the_latex_style(l) for l in pred_labels] + if true_labels is not None: + true_labels = [mngs.general.to_the_latex_style(l) for l in true_labels] + if labels is not None: + labels = [mngs.general.to_the_latex_style(l) for l in labels] + if sorted_labels is not None: + sorted_labels = [ + mngs.general.to_the_latex_style(l) for l in sorted_labels + ] + + # Prediction Labels: columns if pred_labels is not None: - df.columns = [mngs.general.to_the_latex_style(l) for l in pred_labels] + df.columns = pred_labels + elif (pred_labels is None) and (labels is not None): + df.columns = labels - if (labels is not None) and (true_labels is None): - df.index = [mngs.general.to_the_latex_style(l) for l in labels] # true_labels + # Ground Truth Labels: index if true_labels is not None: - df.index = [mngs.general.to_the_latex_style(l) for l in true_labels] - - # # x- and y-ticklabels - # if labels is not None: - - # df.columns = [ - # mngs.general.add_hat_in_the_latex_style(l) for l in labels - # ] # predicted labels - + df.index = true_labels + elif (true_labels is None) and (labels is not None): + df.index = labels + + # Sort based on sorted_labels here + if sorted_labels is not None: + assert set(sorted_labels) == set(labels) + df = df.reindex(index=sorted_labels, columns=sorted_labels) + + # Main fig, ax = plt.subplots() res = sns.heatmap( df, @@ -66,7 +99,7 @@ def confusion_matrix( cbar=False, ) # Here, don't plot color bar. - ## Adds comma separator for the annotated int texts + # Adds comma separator for the annotated int texts for t in ax.texts: t.set_text("{:,d}".format(int(t.get_text()))) @@ -75,50 +108,41 @@ def confusion_matrix( # Makes the frame visible for _, spine in res.spines.items(): - # spine.set_visible(True) - spine.set_visible(False) + spine.set_visible(False) + # Labels ax.set_xlabel("Predicted label") ax.set_ylabel("True label") ax.set_title(title) - ax = mngs.plt.ax_extend(ax, x_extend_ratio, y_extend_ratio) - + # Appearances + ax = mngs.plt.ax.extend(ax, x_extend_ratio, y_extend_ratio) if df.shape[0] == df.shape[1]: ax.set_box_aspect(1) - ax.set_xticklabels( ax.get_xticklabels(), rotation=label_rotation_xy[0], fontdict={"verticalalignment": "top"}, ) - ax.set_yticklabels( ax.get_yticklabels(), rotation=label_rotation_xy[1], fontdict={"horizontalalignment": "right"}, ) - - # The size of the confusion matrix - - # Calculates the dx + # The size bbox = ax.get_position() left_orig = bbox.x0 width_orig = bbox.x1 - bbox.x0 g_x_orig = left_orig + width_orig / 2.0 width_tgt = width_orig * x_extend_ratio # x_extend_ratio dx = width_orig - width_tgt - # print(dx) - """ - The axes objects of the confusion matrix and colorbar are different. - Here, their sizes are adjusted one by one. 
- """ + # Adjusts the sizes of the confusion matrix and colorbar if colorbar == True: # fixme divider = make_axes_locatable(ax) # Gets region from the ax cax = divider.append_axes("right", size="5%", pad=0.1) # cax = divider.new_horizontal(size="5%", pad=1, pack_start=True) - cax = mngs.plt.ax_set_position(fig, cax, -dx * 2.54, 0) + cax = mngs.plt.ax.set_pos(fig, cax, -dx * 2.54, 0) fig.add_axes(cax) """ @@ -144,6 +168,7 @@ def confusion_matrix( """ # Plots colorbar and adjusts the size + vmax = np.array(df).max().astype(int) norm = matplotlib.colors.Normalize(vmin=0, vmax=vmax) cbar = fig.colorbar( plt.cm.ScalarMappable(norm=norm, cmap="Blues"), @@ -153,9 +178,9 @@ def confusion_matrix( cbar.locator = ticker.MaxNLocator(nbins=4) # tick_locator cbar.update_ticks() # cbar.outline.set_edgecolor("#f9f2d7") - cbar.outline.set_edgecolor("white") + cbar.outline.set_edgecolor("white") - return fig + return fig, cm # def AddAxesBBoxRect(fig, ax, ec="k"): @@ -178,6 +203,13 @@ def confusion_matrix( if __name__ == "__main__": + + import mngs + + y_true, y_pred = mngs.io.load("/tmp/tmp.pkl") + + fig, cm = conf_mat(plt, y_true=y_true, y_pred=y_pred) + # https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py import sys @@ -185,7 +217,8 @@ def confusion_matrix( import numpy as np import sklearn from sklearn import datasets, svm - from sklearn.metrics import plot_confusion_matrix + + # from sklearn.metrics import plot_confusion_matrix from sklearn.model_selection import train_test_split sys.path.append(".") @@ -209,35 +242,35 @@ def confusion_matrix( cm = sklearn.metrics.confusion_matrix(y_test, y_pred) cm **= 3 - cm = np.random.randint(low=0, high=10, size=[3,4]) - + cm = np.random.randint(low=0, high=10, size=[3, 4]) + mngs.plt.configure_mpl( plt, # figsize=(4, 8), - figsize=(4, 8), - fontsize=6, - labelsize=8, - legendfontsize=7, - tick_size=0.8, - tick_width=0.2, + # figsize=(4, 8), + # fontsize=6, + # labelsize=8, + # legendfontsize=7, + # tick_size=0.8, + # tick_width=0.2, ) # labels = class_names pred_labels = ["A", "B", "C"] - true_labels = ["a", "b", "c", "d"] + true_labels = ["a", "b", "c", "d"] - fig = confusion_matrix( + fig, cm = conf_mat( plt, cm, # labels=class_names, pred_labels=pred_labels, true_labels=true_labels, label_rotation_xy=(60, 60), - x_extend_ratio=1., + x_extend_ratio=1.0, colorbar=True, ) - fig.axes[-1] = mngs.plt.ax_scientific_notation( + fig.axes[-1] = mngs.plt.ax.sci_note( fig.axes[-1], 3, fformat="%3.1f", diff --git a/src/mngs/ml/plt/_learning_curve.py b/src/mngs/ml/plt/_learning_curve.py index 5eac562..13ced53 100755 --- a/src/mngs/ml/plt/_learning_curve.py +++ b/src/mngs/ml/plt/_learning_curve.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Time-stamp: "2024-02-02 09:25:07 (ywatanabe)" +# Time-stamp: "2024-03-07 19:50:02 (ywatanabe)" import re import matplotlib import matplotlib.pyplot as plt import mngs +import numpy as np import pandas as pd @@ -33,7 +34,7 @@ def set_yaxis_for_acc(ax, key_plt): return ax -def plot_tra(ax, metrics_df, key_plt, lw=1): +def plot_tra(ax, metrics_df, key_plt, lw=1, color="blue"): indi_step = mngs.gen.search( "^[Tt]rain(ing)?", metrics_df.step, as_bool=True )[0] @@ -44,8 +45,7 @@ def plot_tra(ax, metrics_df, key_plt, lw=1): step_df.index, # i_global step_df[key_plt], label="Training", - # color=COLORS_DICT["tra"], - color=mngs.plt.colors.to_RGBA("blue", alpha=0.9), + color=color, linewidth=lw, ) ax.legend() @@ -53,7 +53,7 
@@ def plot_tra(ax, metrics_df, key_plt, lw=1): return ax -def scatter_val(ax, metrics_df, key_plt, s=3): +def scatter_val(ax, metrics_df, key_plt, s=3, color="green"): indi_step = mngs.gen.search( "^[Vv]alid(ation)?", metrics_df.step, as_bool=True )[0] @@ -63,7 +63,7 @@ def scatter_val(ax, metrics_df, key_plt, s=3): step_df.index, step_df[key_plt], label="Validation", - color=mngs.plt.colors.to_RGBA("green", alpha=0.9), + color=color, s=s, alpha=0.9, ) @@ -71,7 +71,7 @@ def scatter_val(ax, metrics_df, key_plt, s=3): return ax -def scatter_tes(ax, metrics_df, key_plt, s=3): +def scatter_tes(ax, metrics_df, key_plt, s=3, color="red"): indi_step = mngs.gen.search("^[Tt]est", metrics_df.step, as_bool=True)[0] step_df = metrics_df[indi_step] if len(step_df) != 0: @@ -79,8 +79,7 @@ def scatter_tes(ax, metrics_df, key_plt, s=3): step_df.index, step_df[key_plt], label="Test", - color=mngs.plt.colors.to_RGBA("red", alpha=0.9), - # color=COLORS_DICT["tes"], + color=color, s=s, alpha=0.9, ) @@ -88,7 +87,7 @@ def scatter_tes(ax, metrics_df, key_plt, s=3): return ax -def vline_at_epochs(ax, metrics_df): +def vline_at_epochs(ax, metrics_df, color="grey"): # Determine the global iteration values where new epochs start epoch_starts = metrics_df[metrics_df["i_batch"] == 0].index.values epoch_labels = metrics_df[metrics_df["i_batch"] == 0].index.values @@ -97,7 +96,7 @@ def vline_at_epochs(ax, metrics_df): ymin=-1e4, # ax.get_ylim()[0], ymax=1e4, # ax.get_ylim()[1], linestyle="--", - color=mngs.plt.colors.to_RGBA("gray", alpha=0.1), + color=color, ) return ax @@ -137,13 +136,34 @@ def learning_curve( linewidth=1, yscale="linear", ): + _plt, cc = mngs.plt.configure_mpl(plt, show=False) + """ + Example: + print(metrics_df) + # step i_global i_epoch i_batch loss + # 0 Training 0 0 0 0.717023 + # 1 Training 1 0 1 0.703844 + # 2 Training 2 0 2 0.696279 + # 3 Training 3 0 3 0.685384 + # 4 Training 4 0 4 0.670675 + # ... ... ... ... ... ... 
diff --git a/src/mngs/ml/sk/__init__.py b/src/mngs/ml/sk/__init__.py
new file mode 100755
index 0000000..ee7cc10
--- /dev/null
+++ b/src/mngs/ml/sk/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Time-stamp: "2024-03-05 13:17:17 (ywatanabe)"
+
+from ._clf import *
+from ._to_sktime import to_sktime_df
diff --git a/src/mngs/ml/sk/_clf.py b/src/mngs/ml/sk/_clf.py
new file mode 100755
index 0000000..3f33668
--- /dev/null
+++ b/src/mngs/ml/sk/_clf.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Time-stamp: "2024-03-05 16:27:11 (ywatanabe)"
+
+import numpy as np
+from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
+from sklearn.linear_model import RidgeClassifierCV
+from sklearn.pipeline import make_pipeline
+from sklearn.svm import SVC
+from sktime.classification.deep_learning.cnn import CNNClassifier
+from sktime.classification.deep_learning.inceptiontime import (
+    InceptionTimeClassifier,
+)
+from sktime.classification.deep_learning.lstmfcn import LSTMFCNClassifier
+from sktime.classification.dummy import DummyClassifier
+from sktime.classification.feature_based import TSFreshClassifier
+from sktime.classification.hybrid import HIVECOTEV2
+from sktime.classification.interval_based import TimeSeriesForestClassifier
+from sktime.classification.kernel_based import RocketClassifier, TimeSeriesSVC
+from sktime.transformations.panel.reduce import Tabularizer
+from sktime.transformations.panel.rocket import Rocket
+
+
+# rocket_pipeline = make_pipeline(
+#     Rocket(n_jobs=-1),
+#     RidgeClassifierCV(alphas=np.logspace(-3, 3, 10)),
+# )
+def rocket_pipeline(*args, **kwargs):
+    return make_pipeline(
+        # Rocket(n_jobs=-1),
+        Rocket(*args, **kwargs),
+        SVC(probability=True, kernel="linear"),
+    )
+
+
+GB_pipeline = make_pipeline(
+    Tabularizer(),
+    GradientBoostingClassifier(),
+)
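# ---- editor's example (illustrative, not part of the diff) ----
# The intended use of rocket_pipeline: ROCKET features feeding a linear SVC
# with predict_proba enabled. Whether Rocket accepts raw 3D numpy depends on
# the sktime version, so the nested conversion from _to_sktime.py is used.
# The mngs.ml.sk access path assumes mngs.ml re-exports the new sk subpackage.
import numpy as np
import mngs

X = np.random.rand(20, 3, 64)   # (n_samples, n_chs, seq_len)
y = np.random.randint(0, 2, 20)

clf = mngs.ml.sk.rocket_pipeline(num_kernels=500)  # kwargs forwarded to Rocket
X_df = mngs.ml.sk.to_sktime_df(X)
clf.fit(X_df, y)
proba = clf.predict_proba(X_df)  # available because SVC(probability=True)
print(proba.shape)  # (20, 2)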
diff --git a/src/mngs/ml/sk/_to_sktime.py b/src/mngs/ml/sk/_to_sktime.py
new file mode 100755
index 0000000..97cac4b
--- /dev/null
+++ b/src/mngs/ml/sk/_to_sktime.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Time-stamp: "2024-03-05 13:17:04 (ywatanabe)"
+
+# import warnings
+
+import numpy as np
+import pandas as pd
+import torch
+
+
+def to_sktime_df(X):
+    """
+    Converts a dataset to a format compatible with sktime, encapsulating each sample as a pandas DataFrame.
+
+    Arguments:
+    - X (numpy.ndarray or torch.Tensor or pandas.DataFrame): The input dataset with shape (n_samples, n_chs, seq_len).
+      It should be a 3D array-like structure containing the time series data.
+
+    Return:
+    - sktime_df (pandas.DataFrame): A DataFrame where each element is a pandas Series representing a univariate time series.
+
+    Data Types and Shapes:
+    - If X is a numpy.ndarray, it should have the shape (n_samples, n_chs, seq_len).
+    - If X is a torch.Tensor, it should have the shape (n_samples, n_chs, seq_len) and will be converted to a numpy array.
+    - If X is a pandas.DataFrame, it is assumed to already be in the correct format and will be returned as is.
+
+    References:
+    - sktime: https://github.com/alan-turing-institute/sktime
+
+    Examples:
+    --------
+    >>> X_np = np.random.rand(64, 160, 1024)
+    >>> sktime_df = to_sktime_df(X_np)
+    >>> type(sktime_df)
+    <class 'pandas.core.frame.DataFrame'>
+    """
+    if isinstance(X, pd.DataFrame):
+        return X
+    elif torch.is_tensor(X):
+        X = X.numpy()
+    elif not isinstance(X, np.ndarray):
+        raise ValueError(
+            "Input X must be a numpy.ndarray, torch.Tensor, or pandas.DataFrame"
+        )
+
+    X = X.astype(np.float64)
+
+    def _format_a_sample_for_sktime(x):
+        """
+        Formats a single sample for sktime compatibility.
+
+        Arguments:
+        - x (numpy.ndarray): A 2D array with shape (n_chs, seq_len) representing a single sample.
+
+        Return:
+        - dims (pandas.Series): A Series where each element is a pandas Series representing a univariate time series.
+        """
+        return pd.Series(
+            [pd.Series(x[d], name=f"dim_{d}") for d in range(x.shape[0])]
+        )
+
+    sktime_df = pd.DataFrame(
+        [_format_a_sample_for_sktime(X[i]) for i in range(X.shape[0])]
+    )
+    return sktime_df
+
+
+# # Obsolete warning for future compatibility
+# def to_sktime(*args, **kwargs):
+#     warnings.warn(
+#         "to_sktime is deprecated; use to_sktime_df instead.", FutureWarning
+#     )
+#     return to_sktime_df(*args, **kwargs)
+
+
+# import pandas as pd
+# import numpy as np
+# import torch
+
+# def to_sktime(X):
+#     """
+#     X.shape: (n_samples, n_chs, seq_len)
+#     """
+
+#     def _format_a_sample_for_sktime(x):
+#         """
+#         x.shape: (n_chs, seq_len)
+#         """
+#         dims = pd.Series(
+#             [pd.Series(x[d], name=f"dim_{d}") for d in range(len(x))],
+#             index=[f"dim_{i}" for i in np.arange(len(x))],
+#         )
+#         return dims
+
+#     if torch.is_tensor(X):
+#         X = X.numpy()
+#     X = X.astype(np.float64)
+
+#     return pd.DataFrame(
+#         [_format_a_sample_for_sktime(X[i]) for i in range(len(X))]
+#     )
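# ---- editor's example (illustrative, not part of the diff) ----
# What the nested layout produced by to_sktime_df looks like: one row per
# sample, one column per channel, each cell holding a whole pd.Series time
# series. Sizes are illustrative only; the mngs.ml.sk path assumes the
# subpackage is re-exported.
import numpy as np
import mngs

X = np.random.rand(4, 2, 16)              # (n_samples, n_chs, seq_len)
df = mngs.ml.sk.to_sktime_df(X)

print(df.shape)                # (4, 2) -> samples x channels
cell = df.iloc[0, 0]           # one univariate series
print(type(cell), len(cell))   # <class 'pandas.core.series.Series'> 16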
+ """ + return pd.Series( + [pd.Series(x[d], name=f"dim_{d}") for d in range(x.shape[0])] + ) + + sktime_df = pd.DataFrame( + [_format_a_sample_for_sktime(X[i]) for i in range(X.shape[0])] + ) + return sktime_df + + +# # Obsolete warning for future compatibility +# def to_sktime(*args, **kwargs): +# warnings.warn( +# "to_sktime is deprecated; use to_sktime_df instead.", FutureWarning +# ) +# return to_sktime_df(*args, **kwargs) + + +# import pandas as pd +# import numpy as np +# import torch + +# def to_sktime(X): +# """ +# X.shape: (n_samples, n_chs, seq_len) +# """ + +# def _format_a_sample_for_sktime(x): +# """ +# x.shape: (n_chs, seq_len) +# """ +# dims = pd.Series( +# [pd.Series(x[d], name=f"dim_{d}") for d in range(len(x))], +# index=[f"dim_{i}" for i in np.arange(len(x))], +# ) +# return dims + +# if torch.is_tensor(X): +# X = X.numpy() +# X = X.astype(np.float64) + +# return pd.DataFrame( +# [_format_a_sample_for_sktime(X[i]) for i in range(len(X))] +# ) diff --git a/src/mngs/ml/utils/_LabelEncoder.py b/src/mngs/ml/utils/_LabelEncoder.py new file mode 100755 index 0000000..76d0df9 --- /dev/null +++ b/src/mngs/ml/utils/_LabelEncoder.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Time-stamp: "2024-03-02 09:52:28 (ywatanabe)" + +from warnings import warn + +import numpy as np +import pandas as pd +import torch +from sklearn.preprocessing import LabelEncoder as SklearnLabelEncoder + + +class LabelEncoder(SklearnLabelEncoder): + """ + An extension of the sklearn.preprocessing.LabelEncoder that supports incremental learning. + This means it can handle new classes without forgetting the old ones. + + Attributes: + classes_ (np.ndarray): Holds the label for each class. + + Example usage: + encoder = IncrementalLabelEncoder() + encoder.fit(np.array(["apple", "banana"])) + encoded_labels = encoder.transform(["apple", "banana"]) # This will give you the encoded labels + + encoder.fit(["cherry"]) # Incrementally add "cherry" + encoder.transform(["apple", "banana", "cherry"]) # Now it works, including "cherry" + + # Now you can use inverse_transform with the encoded labels + print(encoder.classes_) + original_labels = encoder.inverse_transform(encoded_labels) + print(original_labels) # This should print ['apple', 'banana'] + """ + + def __init__(self): + super().__init__() + self.classes_ = np.array([]) + + def _check_input(self, y): + """ + Check and convert the input to a NumPy array if it is a list, tuple, pandas.Series, pandas.DataFrame, or torch.Tensor. + + Arguments: + y (list, tuple, pd.Series, pd.DataFrame, torch.Tensor): The input labels. + + Returns: + np.ndarray: The input labels converted to a NumPy array. + """ + if isinstance(y, (list, tuple)): + y = np.array(y) + elif isinstance(y, pd.Series): + y = y.values + elif isinstance(y, torch.Tensor): + y = y.numpy() + return y + + def fit(self, y): + """ + Fit the label encoder with an array of labels, incrementally adding new classes. + + Arguments: + y (list, tuple, np.ndarray, pd.Series, pd.DataFrame, torch.Tensor): The input labels. + + Returns: + IncrementalLabelEncoder: The instance itself. + """ + y = self._check_input(y) + new_unique_labels = np.unique(y) + unique_labels = np.unique( + np.concatenate((self.classes_, new_unique_labels)) + ) + self.classes_ = unique_labels + return self + + def transform(self, y): + """ + Transform labels to normalized encoding. + + Arguments: + y (list, tuple, np.ndarray, pd.Series, pd.DataFrame, torch.Tensor): The input labels. 
diff --git a/src/mngs/ml/utils/__init__.py b/src/mngs/ml/utils/__init__.py
index fa7e386..988a419 100644
--- a/src/mngs/ml/utils/__init__.py
+++ b/src/mngs/ml/utils/__init__.py
@@ -1,6 +1,7 @@
+from ._check_params import check_params
 from ._DefaultDataset import DefaultDataset
 from ._format_samples_for_sktime import format_samples_for_sktime
-from ._get_params import get_params
+from ._LabelEncoder import LabelEncoder
 from ._merge_labels import merge_labels
 from ._sliding_window_data_augmentation import sliding_window_data_augmentation
 from ._under_sample import under_sample
diff --git a/src/mngs/ml/utils/_check_params.py b/src/mngs/ml/utils/_check_params.py
new file mode 100755
index 0000000..1eda2e0
--- /dev/null
+++ b/src/mngs/ml/utils/_check_params.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+# Time-stamp: "2024-02-17 12:38:40 (ywatanabe)"
+
+from pprint import pprint
+from time import sleep
+
+# def get_params(model, tgt_name=None, sleep_sec=2, show=False):
+
+#     name_shape_dict = {}
+#     for name, param in model.named_parameters():
+#         learnable = "Learnable" if param.requires_grad else "Freezed"
+
+#         if (tgt_name is not None) & (name == tgt_name):
+#             return param
+#         if tgt_name is None:
+#             # print(f"\n{param}\n{param.shape}\nname: {name}\n")
+#             if show is True:
+#                 print(
+#                     f"\n{param}: {param.shape}\nname: {name}\nStatus: {learnable}\n"
+#                 )
+#                 sleep(sleep_sec)
+#             name_shape_dict[name] = list(param.shape)
+
+#     if tgt_name is None:
+#         print()
+#         pprint(name_shape_dict)
+#         print()
+
+
+def check_params(model, tgt_name=None, show=False):
+
+    out_dict = {}
+
+    for name, param in model.named_parameters():
+        learnable = "Learnable" if param.requires_grad else "Frozen"
+
+        if tgt_name is None or name == tgt_name:
+            out_dict[name] = (param.shape, learnable)
+
+    if show:
+        for k, v in out_dict.items():
+            print(f"\n{k}\n{v}")
+
+    return out_dict
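# ---- editor's example (illustrative, not part of the diff) ----
# check_params() on a tiny torch model: one entry per named parameter, each a
# (shape, "Learnable"/"Frozen") tuple; tgt_name narrows it to a single entry.
import torch.nn as nn
from mngs.ml.utils import check_params

model = nn.Linear(4, 2)
model.bias.requires_grad = False  # freeze one parameter for illustration

print(check_params(model))
# {'weight': (torch.Size([2, 4]), 'Learnable'),
#  'bias': (torch.Size([2]), 'Frozen')}

print(check_params(model, tgt_name="weight"))
# {'weight': (torch.Size([2, 4]), 'Learnable')}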
diff --git a/src/mngs/ml/utils/_get_params.py b/src/mngs/ml/utils/_get_params.py
deleted file mode 100755
index bbdeb95..0000000
--- a/src/mngs/ml/utils/_get_params.py
+++ /dev/null
@@ -1,24 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "2021-11-30 11:11:01 (ylab)"
-
-from time import sleep
-from pprint import pprint
-
-
-def get_params(model, tgt_name=None, sleep_sec=2, show=False):
-
-    name_shape_dict = {}
-    for name, param in model.named_parameters():
-        if (tgt_name is not None) & (name == tgt_name):
-            return param
-        if tgt_name is None:
-            # print(f"\n{param}\n{param.shape}\nname: {name}\n")
-            if show is True:
-                print(f"\n{param}: {param.shape}\nname: {name}\n")
-                sleep(sleep_sec)
-            name_shape_dict[name] = list(param.shape)
-
-    if tgt_name is None:
-        print()
-        pprint(name_shape_dict)
-        print()
diff --git a/src/mngs/ml/utils/_merge_labels.py b/src/mngs/ml/utils/_merge_labels.py
index d7166b7..3677c42 100755
--- a/src/mngs/ml/utils/_merge_labels.py
+++ b/src/mngs/ml/utils/_merge_labels.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3

-import numpy as np
 import mngs
+import numpy as np

 # y1, y2 = T_tra, M_tra
 # def merge_labels(y1, y2):
@@ -12,8 +12,11 @@

 def merge_labels(*ys, to_int=False):
-    y = [mngs.general.connect_nums(zs) for zs in zip(*ys)]
-    if to_int:
-        conv_d = {z: i for i, z in enumerate(np.unique(y))}
-        y = [conv_d[z] for z in y]
-    return np.array(y)
+    if len(ys) == 1:  # a single label array was passed; nothing to merge
+        return ys[0]
+    else:
+        y = [mngs.general.connect_nums(zs) for zs in zip(*ys)]
+        if to_int:
+            conv_d = {z: i for i, z in enumerate(np.unique(y))}
+            y = [conv_d[z] for z in y]
+        return np.array(y)
diff --git a/src/mngs/os/__init__.py b/src/mngs/os/__init__.py
new file mode 100755
index 0000000..399b6d8
--- /dev/null
+++ b/src/mngs/os/__init__.py
@@ -0,0 +1,5 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Time-stamp: "2024-03-01 10:25:41 (ywatanabe)"
+
+from ._mv import mv
diff --git a/src/mngs/os/_mv.py b/src/mngs/os/_mv.py
new file mode 100755
index 0000000..c9a0213
--- /dev/null
+++ b/src/mngs/os/_mv.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Time-stamp: "2024-03-02 21:35:41 (ywatanabe)"
+
+import os
+import shutil
+
+# def mv(src, tgt):
+#     successful = True
+#     os.makedirs(tgt, exist_ok=True)
+
+#     if os.path.isdir(src):
+#         # Iterate over the items in the directory
+#         for item in os.listdir(src):
+#             item_path = os.path.join(src, item)
+#             # Check if the item is a file
+#             if os.path.isfile(item_path):
+#                 try:
+#                     shutil.move(item_path, tgt)
+#                     print(f"\nMoved file from {item_path} to {tgt}")
+#                 except OSError as e:
+#                     print(f"\nError: {e}")
+#                     successful = False
+#             else:
+#                 print(f"\nSkipped directory {item_path}")
+#     else:
+#         # If src is a file, just move it
+#         try:
+#             shutil.move(src, tgt)
+#             print(f"\nMoved from {src} to {tgt}")
+#         except OSError as e:
+#             print(f"\nError: {e}")
+#             successful = False
+
+#     return successful
+
+
+def mv(src, tgt):
+    successful = True
+    os.makedirs(tgt, exist_ok=True)
+
+    try:
+        shutil.move(src, tgt)
+        print(f"\nMoved from {src} to {tgt}")
+    except OSError as e:
+        print(f"\nError: {e}")
+        successful = False
+
+    return successful
diff --git a/src/mngs/plt/_configure_mpl.py b/src/mngs/plt/_configure_mpl.py
index 24b58dc..c5b2a38 100755
--- a/src/mngs/plt/_configure_mpl.py
+++ b/src/mngs/plt/_configure_mpl.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Time-stamp: "2024-02-03 15:15:48 (ywatanabe)"
+# Time-stamp: "2024-03-08 02:07:41 (ywatanabe)"

 import matplotlib.pyplot as plt
 import numpy as np
@@ -112,6 +112,8 @@ def configure_mpl(
         "brown": (128, 0, 0, alpha),
         "darkblue": (0, 0, 100, alpha),
         "orange": (228, 94, 50, alpha),
+        "white": (255, 255, 255, alpha),
+        "black": (0, 0, 0, alpha),
     }
     COLORS_HEX = {k: rgba_to_hex(v) for k, v in COLORS_RGBA.items()}
     COLORS_RGBA_NORM = {c: normalize_rgba(v) for c, v in COLORS_RGBA.items()}
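# ---- editor's example (illustrative, not part of the diff) ----
# Using the new mngs.os.mv above: the target directory is created if missing
# and the boolean return reports success (assuming the return-value fix noted
# in _mv.py). Paths here are hypothetical.
import mngs

ok = mngs.os.mv("/tmp/results.csv", "/tmp/archive")  # -> /tmp/archive/results.csv
if not ok:
    print("move failed; see the printed OSError above")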
diff --git a/src/mngs/plt/_subplots.py b/src/mngs/plt/_subplots.py
index 6a90496..b8abe09 100755
--- a/src/mngs/plt/_subplots.py
+++ b/src/mngs/plt/_subplots.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Time-stamp: "2024-02-03 16:09:19 (ywatanabe)"
+# Time-stamp: "2024-02-04 13:05:40 (ywatanabe)"

 from collections import OrderedDict
diff --git a/src/mngs/plt/ax/__init__.py b/src/mngs/plt/ax/__init__.py
index 47eb75f..cdeb91c 100755
--- a/src/mngs/plt/ax/__init__.py
+++ b/src/mngs/plt/ax/__init__.py
@@ -1,10 +1,11 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Time-stamp: "2024-02-03 15:36:50 (ywatanabe)"
+# Time-stamp: "2024-03-07 23:38:32 (ywatanabe)"

 from ._circular_hist import circular_hist
 from ._extend import extend
 from ._fill_between import fill_between
+from ._hide_spines import hide_spines
 from ._map_ticks import map_ticks
 from ._panel import panel
 from ._sci_note import sci_note
diff --git a/src/mngs/plt/ax/_hide_spines.py b/src/mngs/plt/ax/_hide_spines.py
new file mode 100755
index 0000000..1b38c1b
--- /dev/null
+++ b/src/mngs/plt/ax/_hide_spines.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Time-stamp: "2024-03-07 23:44:26 (ywatanabe)"
+
+
+def hide_spines(ax, tgts=["top", "right", "bottom", "left"]):
+    """
+    Hides specified spines from a matplotlib Axes.
+
+    Parameters:
+    - ax (matplotlib.axes.Axes): The Axes object to modify.
+    - tgts (list of str): List of spines to hide from the Axes.
+
+    Returns:
+    - ax (matplotlib.axes.Axes): The modified Axes object with spines removed.
+    """
+    for tgt in tgts:
+        ax.spines[tgt].set_visible(False)
+
+    return ax
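# ---- editor's example (illustrative, not part of the diff) ----
# Typical use of the new hide_spines: drop the top/right spines for a cleaner
# look while keeping the data axes.
import matplotlib.pyplot as plt
import mngs

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])
ax = mngs.plt.ax.hide_spines(ax, tgts=["top", "right"])
plt.show()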
+ """ bbox = ax.get_position() - ## Calculates delta ratios + # Calculates delta ratios fig_width_inch, fig_height_inch = fig.get_size_inches() - x_delta_offset_inch = float(x_delta_offset_cm) / 2.54 - y_delta_offset_inch = float(y_delta_offset_cm) / 2.54 + x_delta_offset_inch = x_delta_offset_cm / 2.54 + y_delta_offset_inch = y_delta_offset_cm / 2.54 x_delta_offset_ratio = x_delta_offset_inch / fig_width_inch - y_delta_offset_ratio = y_delta_offset_inch / fig_width_inch + y_delta_offset_ratio = y_delta_offset_inch / fig_height_inch # [REVISED] - ## Determines updated bbox position + # Determines updated bbox position left = bbox.x0 + x_delta_offset_ratio bottom = bbox.y0 + y_delta_offset_ratio - width = bbox.x1 - bbox.x0 - height = bbox.y1 - bbox.y0 + width = bbox.width # [REVISED] + height = bbox.height # [REVISED] if dragh: width -= x_delta_offset_ratio @@ -33,13 +42,50 @@ def set_pos( if dragv: height -= y_delta_offset_ratio - ax.set_pos( - [ - left, - bottom, - width, - height, - ] - ) + ax.set_position([left, bottom, width, height]) # [REVISED] return ax + + +# def set_pos( +# fig, +# ax, +# x_delta_offset_cm, +# y_delta_offset_cm, +# dragh=False, +# dragv=False, +# ): + +# bbox = ax.get_position() + +# ## Calculates delta ratios +# fig_width_inch, fig_height_inch = fig.get_size_inches() + +# x_delta_offset_inch = float(x_delta_offset_cm) / 2.54 +# y_delta_offset_inch = float(y_delta_offset_cm) / 2.54 + +# x_delta_offset_ratio = x_delta_offset_inch / fig_width_inch +# y_delta_offset_ratio = y_delta_offset_inch / fig_width_inch + +# ## Determines updated bbox position +# left = bbox.x0 + x_delta_offset_ratio +# bottom = bbox.y0 + y_delta_offset_ratio +# width = bbox.x1 - bbox.x0 +# height = bbox.y1 - bbox.y0 + +# if dragh: +# width -= x_delta_offset_ratio + +# if dragv: +# height -= y_delta_offset_ratio + +# ax.set_pos( +# [ +# left, +# bottom, +# width, +# height, +# ] +# ) + +# return ax