Skip to content

Commit

Permalink
Merge pull request #17 from ywatanabe1989/develop
Browse files Browse the repository at this point in the history
v1.1.0
  • Loading branch information
ywatanabe1989 authored Mar 8, 2024
2 parents 0a395b6 + 1e31328 commit 0aa1e3d
Show file tree
Hide file tree
Showing 39 changed files with 1,780 additions and 328 deletions.
7 changes: 7 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,10 @@ wheel
pybids
scikit-image
icecream
Pyarrow
ruamel.yaml
pytest
pytest-cov
pytest-xdist
pytest-env
umap-learn
8 changes: 4 additions & 4 deletions src/mngs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#!/usr/bin/env python3
# Time-stamp: "2024-02-03 16:06:22 (ywatanabe)"
# Time-stamp: "2024-03-08 20:48:06 (ywatanabe)"

from . import dsp
from . import general
from . import general as gen
from . import gists, io, linalg, ml, nn, plt, resource, stats
from . import gists, io, linalg, ml, nn, os, plt, resource, stats
from .general.debug import *

__copyright__ = "Copyright (C) 2021 Yusuke Watanabe"
__version__ = "1.0.1"
__copyright__ = "Copyright (C) 2024 Yusuke Watanabe"
__version__ = "1.1.0"
__license__ = "GPL3.0"
__author__ = "ywatanabe1989"
__author_email__ = "[email protected]"
Expand Down
1 change: 1 addition & 0 deletions src/mngs/dsp/.#fft.py
1 change: 1 addition & 0 deletions src/mngs/dsp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

# from .wavelet import wavelet
from ._psd import psd_torch
from .demo_sig import demo_sig_np, demo_sig_torch
from .feature_extractors import (
FeatureExtractorTorch,
Expand Down
51 changes: 51 additions & 0 deletions src/mngs/dsp/_psd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-02-21 08:58:35 (ywatanabe)"

import torch


def psd_torch(signal, samp_rate):
"""
# Example usage:
samp_rate = 480 # Sampling rate in Hz
signal = torch.randn(480) # Example signal with 480 samples
freqs, psd = calculate_psd(signal, samp_rate)
# Plot the PSD (if you have matplotlib installed)
import matplotlib.pyplot as plt
plt.plot(freqs.numpy(), psd.numpy())
plt.xlabel('Frequency (Hz)')
plt.ylabel('Power/Frequency (V^2/Hz)')
plt.title('Power Spectral Density')
plt.show()
"""
# Apply window function to the signal (e.g., Hanning window)
window = torch.hann_window(signal.size(-1))
signal = signal * window

# Perform the FFT
fft_output = torch.fft.fft(signal)

# Compute the power spectrum (magnitude squared of the FFT output)
power_spectrum = torch.abs(fft_output) ** 2

# Normalize the power spectrum to get the PSD
# Usually, we divide by the length of the signal and the sum of the window squared
# to get the power in terms of physical units (e.g., V^2/Hz)
psd = power_spectrum / (samp_rate * (window**2).sum())

# Since the signal is real, we only need the positive half of the FFT output
# The factor of 2 accounts for the energy in the negative frequencies that we're discarding
psd = psd[: len(psd) // 2] * 2

# Adjust the DC component (0 Hz) and Nyquist component (if applicable)
psd[0] /= 2
if len(psd) % 2 == 0: # Even length, Nyquist freq component is included
psd[-1] /= 2

# Frequency axis
freqs = torch.fft.fftfreq(signal.size(-1), 1 / samp_rate)[: len(psd)]

return freqs, psd
57 changes: 51 additions & 6 deletions src/mngs/dsp/demo_sig.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-01-20 13:58:43 (ywatanabe)"#!/usr/bin/env python3
# Time-stamp: "2024-02-14 20:04:18 (ywatanabe)"#!/usr/bin/env python3

import numpy as np
import torch
Expand Down Expand Up @@ -53,14 +53,59 @@ def demo_sig_np(
return data


# def demo_sig_torch(
# batch_size=64, n_chs=19, samp_rate=1000, len_sec=10, freqs_hz=[2, 5, 10]
# ):
# time = torch.arange(0, len_sec, 1 / samp_rate)
# sig = torch.vstack(
# [torch.sin(f * 2 * torch.pi * time) for f in freqs_hz]
# ).sum(dim=0)
# sig = sig.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_chs, 1)
# return sig


def demo_sig_torch(
batch_size=64, n_chs=19, samp_rate=1000, len_sec=10, freqs_hz=[2, 5, 10]
batch_size=64, n_chs=19, samp_rate=128, len_sec=1, freqs_hz=[2, 5, 10]
):
"""
Generate a batch of demo signals with varying phase shifts and amplitudes,
and add some noise to simulate more realistic data.
Parameters:
- batch_size: Number of samples in the batch.
- n_chs: Number of channels per sample.
- samp_rate: Sampling rate in Hz.
- len_sec: Length of the signal in seconds.
- freqs_hz: List of frequencies in Hz to include in the signal.
Returns:
- sig: Tensor of shape (batch_size, n_chs, samp_rate * len_sec) containing the generated signals.
"""
time = torch.arange(0, len_sec, 1 / samp_rate)
sig = torch.vstack(
[torch.sin(f * 2 * torch.pi * time) for f in freqs_hz]
).sum(dim=0)
sig = sig.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_chs, 1)
# Initialize an empty signal tensor
sig = torch.zeros(batch_size, n_chs, len(time))

# Loop over each frequency and add a sinusoid with a random phase shift and amplitude
for f in freqs_hz:
# Generate a random phase shift for each sample in the batch
phase_shifts = torch.rand(batch_size, 1, 1) * 2 * torch.pi
# Generate a random amplitude for each sample in the batch
amplitudes = (
torch.rand(batch_size, 1, 1) * 0.5 + 0.5
) # Amplitudes between 0.5 and 1.0
# Create the sinusoid with the phase shift and amplitude
sinusoid = amplitudes * torch.sin(
f * 2 * torch.pi * time + phase_shifts
)
# Repeat the sinusoid across channels
sinusoid = sinusoid.repeat(1, n_chs, 1)
# Add the sinusoid to the signal
sig += sinusoid

# Add some Gaussian noise to the signal
noise = torch.randn_like(sig) * 0.1 # Noise level can be adjusted
sig += noise

return sig


Expand Down
3 changes: 3 additions & 0 deletions src/mngs/general/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,13 @@
merge_dicts_wo_overlaps,
partial_at,
pop_keys,
print_block,
search,
squeeze_spaces,
suppress_output,
take_the_closest,
unique,
uq,
wait_key,
)
from .pandas import (
Expand Down
47 changes: 34 additions & 13 deletions src/mngs/general/_close.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-01-29 15:44:06 (ywatanabe)"
# Time-stamp: "2024-03-08 20:13:59 (ywatanabe)"

import os
import shutil
from datetime import datetime, timedelta
from glob import glob
from time import sleep
Expand All @@ -23,27 +25,43 @@ def format_diff_time(diff_time):
return diff_time_str


def close(CONFIG, message=":)", show=True):
def close(CONFIG, message=":)", notify=True, show=True):
try:
end_time = datetime.now()
diff_time = format_diff_time(end_time - CONFIG["START_TIME"])
CONFIG["TimeSpent"] = diff_time
del CONFIG["START_TIME"]
CONFIG["END_TIME"] = datetime.now()
CONFIG["SPENT_TIME"] = format_diff_time(
CONFIG["END_TIME"] - CONFIG["START_TIME"]
)
if show:
print(f"\nEND TIME: {CONFIG['END_TIME']}")
print(f"\nSPENT TIME: {CONFIG['SPENT_TIME']}")

except Exception as e:
print(e)

mngs.io.save(CONFIG, CONFIG["SDIR"] + "CONFIG.pkl")
mngs.io.save(CONFIG, CONFIG["SDIR"] + "CONFIG.yaml")

try:
if CONFIG.get("DEBUG", False):
message = f"[DEBUG]\n" + message
sleep(3)
mngs.gen.notify(
message=message,
ID=CONFIG["ID"],
log_paths=glob(CONFIG["SDIR"] + "*.log"),
show=show,
)
if notify:
mngs.gen.notify(
message=message,
ID=CONFIG["ID"],
log_paths=glob(CONFIG["SDIR"] + "*.log"),
show=show,
)
except Exception as e:
print(e)

# RUNNING to FINISHED
src_dir = CONFIG["SDIR"]
dest_dir = src_dir.replace("RUNNING", "FINISHED")
os.makedirs(dest_dir, exist_ok=True)
try:
os.rename(src_dir, dest_dir)
print(f"\nRenamed from: {src_dir} to {dest_dir}")
except Exception as e:
print(e)

Expand All @@ -53,8 +71,11 @@ def close(CONFIG, message=":)", show=True):

import matplotlib.pyplot as plt
import mngs
from icecream import ic

CONFIG, sys.stdout, sys.stderr, plt = mngs.gen.start(sys, plt, show=False)
CONFIG, sys.stdout, sys.stderr, plt, CC = mngs.gen.start(
sys, plt, show=False
)

ic("aaa")
ic("bbb")
Expand Down
72 changes: 39 additions & 33 deletions src/mngs/general/_start.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-01-29 15:41:58 (ywatanabe)"
# Time-stamp: "2024-03-08 19:55:48 (ywatanabe)"

import inspect
import os as _os
Expand All @@ -25,23 +25,24 @@ def start(
tf=None,
seed=42,
# matplotlib
dpi=100,
save_dpi=300,
figsize=(16.2, 10),
figscale=1.0,
fontsize=16,
labelsize="same",
legendfontsize="xx-small",
tick_size="auto",
tick_width="auto",
hide_spines=False,
fig_size_mm=(160, 100),
fig_scale=1.0,
dpi_display=100,
dpi_save=300,
font_size_base=8,
font_size_title=8,
font_size_axis_label=8,
font_size_tick_label=7,
font_size_legend=6,
hide_top_right_spines=True,
alpha=0.75,
):
"""
import sys
import matplotlib.pyplot as plt
import mngs
CONFIG, sys.stdout, sys.stderr, plt = mngs.gen.start(sys, plt)
CONFIG, sys.stdout, sys.stderr, plt, cc = mngs.gen.start(sys, plt)
# YOUR CODE HERE
Expand All @@ -52,13 +53,15 @@ def start(

# Debug mode check
try:
is_debug = mngs.io.load("./config/is_debug.yaml").get("DEBUG", False)
IS_DEBUG = mngs.io.load("./config/IS_DEBUG.yaml").get(
"IS_DEBUG", False
)
except Exception as e:
is_debug = False
IS_DEBUG = False

# ID
ID = mngs.gen.gen_ID()
ID = ID if not is_debug else "[DEBUG] " + ID
ID = mngs.gen.gen_ID(N=4)
ID = ID if not IS_DEBUG else "DEBUG_" + ID
print(f"\n{'#'*40}\n## {ID}\n{'#'*40}\n")
sleep(1)

Expand All @@ -69,20 +72,22 @@ def start(
__file__ = "/tmp/fake.py"
spath = __file__
_sdir, sfname, _ = mngs.general.split_fpath(spath)
sdir = _sdir + sfname + "/" + ID + "/" # " # + "/log/"
sdir = (
_sdir + sfname + "/" + "RUNNING" + "/" + ID + "/"
) # " # + "/log/"
_os.makedirs(sdir, exist_ok=True)

# CONFIGs
CONFIGS = mngs.io.load_configs(is_debug)
CONFIGS = mngs.io.load_configs(IS_DEBUG)
CONFIGS["ID"] = ID
CONFIGS["START_TIME"] = start_time
CONFIGS["SDIR"] = sdir
CONFIGS["SDIR"] = sdir.replace("/./", "/")
if show:
print(f"\n{'-'*40}\n")
print(f"CONFIG:")
for k, v in CONFIGS.items():
print(f"\n{k}:\n{v}\n")
sleep(0.1)
# sleep(0.1)
print(f"\n{'-'*40}\n")

# Logging (tee)
Expand All @@ -109,22 +114,23 @@ def start(

# Matplotlib configuration
if plt is not None:
plt = mngs.plt.configure_mpl(
plt, cc = mngs.plt.configure_mpl(
plt,
dpi=dpi,
save_dpi=save_dpi,
figsize=figsize,
figscale=figscale,
fontsize=fontsize,
labelsize=labelsize,
legendfontsize=legendfontsize,
tick_size=tick_size,
tick_width=tick_width,
hide_spines=hide_spines,
fig_size_mm=(160, 100),
fig_scale=fig_scale,
dpi_display=dpi_display,
dpi_save=dpi_save,
font_size_base=font_size_base,
font_size_title=font_size_title,
font_size_axis_label=font_size_axis_label,
font_size_tick_label=font_size_tick_label,
font_size_legend=font_size_legend,
hide_top_right_spines=hide_top_right_spines,
alpha=alpha,
show=show,
)

return CONFIGS, sys.stdout, sys.stderr, plt
return CONFIGS, sys.stdout, sys.stderr, plt, cc


if __name__ == "__main__":
Expand All @@ -134,7 +140,7 @@ def start(
import mngs

# --------------------------------------------------------------------------- #
CONFIG, sys.stdout, sys.stderr, plt = mngs.gen.start(sys, plt)
CONFIG, sys.stdout, sys.stderr, plt, cc = mngs.gen.start(sys, plt)
# --------------------------------------------------------------------------- #

# YOUR CODE HERE
Expand Down
Loading

0 comments on commit 0aa1e3d

Please sign in to comment.