Skip to content

Commit

Permalink
feat: add some style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
schmidni committed Aug 7, 2023
1 parent 05b3438 commit be4a092
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 25 deletions.
3 changes: 2 additions & 1 deletion catalog_tools/analysis/estimate_beta.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""This module contains functions for the estimation of beta and the b-value.
"""
from typing import Optional, Tuple, Union

import numpy as np


Expand Down Expand Up @@ -252,7 +253,7 @@ def shi_bolt_confidence(
"""
# standard deviation in Shi and Bolt is calculated with 1/(N*(N-1)), which
# is by a factor of sqrt(N) different to the std(x, ddof=1) estimator
assert b_value is not None or beta is not None,\
assert b_value is not None or beta is not None, \
'please specify b-value or beta'
assert b_value is None or beta is None, \
'please only specify either b-value or beta'
Expand Down
21 changes: 9 additions & 12 deletions catalog_tools/analysis/tests/test_estimate_beta.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
import pytest
import numpy as np
import pytest

# import functions to be tested
from catalog_tools.analysis.estimate_beta import (differences, estimate_b_elst,
estimate_b_laplace,
estimate_b_tinti,
estimate_b_utsu,
estimate_beta_tinti,
shi_bolt_confidence)
from catalog_tools.utils.binning import bin_to_precision
# import functions from other modules
from catalog_tools.utils.simulate_distributions import simulate_magnitudes
from catalog_tools.utils.binning import bin_to_precision

# import functions to be tested
from catalog_tools.analysis.estimate_beta import\
estimate_beta_tinti,\
estimate_b_tinti,\
estimate_b_utsu,\
estimate_b_elst,\
estimate_b_laplace,\
differences,\
shi_bolt_confidence


def simulate_magnitudes_w_offset(n: int, beta: float, mc: float,
Expand Down
15 changes: 9 additions & 6 deletions catalog_tools/download/tests/test_download_catalogs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import datetime as dt
import os
from unittest import mock
import numpy as np

from numpy.testing import assert_equal, assert_allclose, assert_array_less

from catalog_tools.download.download_catalogs import apply_edwards, \
download_catalog_sed, prepare_sed_catalog, download_catalog_1,\
download_catalog_scedc, prepare_scedc_catalog
import numpy as np
from numpy.testing import assert_allclose, assert_array_less, assert_equal

from catalog_tools.download.download_catalogs import (apply_edwards,
download_catalog_1,
download_catalog_scedc,
download_catalog_sed,
prepare_scedc_catalog,
prepare_sed_catalog)


def test_apply_edwards():
Expand Down
9 changes: 5 additions & 4 deletions catalog_tools/plots/basics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Optional, Union

# Own functions
from catalog_tools.utils.binning import get_cum_fmd, get_fmd
Expand Down Expand Up @@ -246,9 +247,9 @@ def plot_mags_in_time(
c="b", linewidth=0.5, alpha=0.8, edgecolor='k')

if mc_change_times is not None and mcs is not None:
if not year_only and type(mc_change_times[0]) == int:
if not year_only and isinstance(mc_change_times[0], int):
mc_change_times = [dt.datetime(x, 1, 1) for x in mc_change_times]
if year_only and type(mc_change_times[0]) != int:
if year_only and not isinstance(mc_change_times[0], int):
mc_change_times = [x.year for x in mc_change_times]

mc_change_times.append(np.max(times))
Expand Down
5 changes: 3 additions & 2 deletions catalog_tools/utils/binning.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import decimal
from typing import Union

import numpy as np


def normal_round_to_int(x: float) -> int:
"""
Expand Down Expand Up @@ -51,7 +52,7 @@ def bin_to_precision(x: Union[np.ndarray, list], delta_x: float = 0.1
Returns:
Value rounded to the given precision.
"""
if type(x) == list:
if isinstance(x, list):
x = np.array(x)
d = decimal.Decimal(str(delta_x))
decimal_places = abs(d.as_tuple().exponent)
Expand Down

0 comments on commit be4a092

Please sign in to comment.