Skip to content

Commit

Permalink
[FIX] Minor coding style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dutta-alankar committed Dec 9, 2023
1 parent 0780cbc commit 7a6441d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 37 deletions.
15 changes: 10 additions & 5 deletions astro_plasma/core/datasift.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,22 @@
import numpy as np
from itertools import product
from pathlib import Path
from typing import Callable, Optional, Union, Tuple, List, Set
from typing import Protocol, Callable, Optional, Union, Tuple, List, Set
import h5py
import sys

_warn = False


class inherited_from_DataSift(Protocol):
# This must be compulsorily implemented by any class inheriting from DataSift
_check_and_download: Callable


class DataSift(ABC):
def __init__(
self: "DataSift",
child_obj: "inherited_from_DataSift",
data: h5py.File,
) -> None:
"""
Expand All @@ -36,7 +42,7 @@ def __init__(

self.batch_size = np.prod(np.array(data["header/batch_dim"]))
self.total_size = np.prod(np.array(data["header/total_size"]))
self._check_and_download: Optional[Callable] = None # mandatory to be implemented by inheriting classes
self._check_and_download = child_obj._check_and_download

def _identify_batch(self: "DataSift", i: int, j: int, k: int, m: int) -> int:
batch_id = self._get_counter(i, j, k, m) // self.batch_size
Expand Down Expand Up @@ -450,9 +456,8 @@ def _interpolate(
_argument.append(argument_collection[arg_pos][0])
else:
_argument.append(argument_collection[arg_pos][indx])
if not (_dummy) and _array_argument[arg_pos]:
_reference = np.zeros_like(argument_collection[arg_pos]).reshape(_input_shape)
i_vals, j_vals, k_vals, m_vals = self._identify_pos_in_each_dim(*_argument)
nH_this, temperature_this, metallicity_this, redshift_this = _argument
i_vals, j_vals, k_vals, m_vals = self._identify_pos_in_each_dim(nH_this, temperature_this, metallicity_this, redshift_this)

"""
The trick is to take the floor value for interpolation only if it is the
Expand Down
18 changes: 2 additions & 16 deletions astro_plasma/core/ionization.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def __init__(
None.
"""
self._check_and_download = download_ionization_data
self.base_url_template = BASE_URL_TEMPLATE
self.file_name_template = FILE_NAME_TEMPLATE
self.base_dir = base_dir
self._check_and_download = download_ionization_data

@property
def base_dir(self):
Expand All @@ -57,23 +57,9 @@ def base_dir(

fetch(urls=DOWNLOAD_IN_INIT, base_dir=self._base_dir)
data = h5py.File(self._base_dir / DOWNLOAD_IN_INIT[0][1], "r")
super().__init__(data)
super().__init__(self, data)
data.close()

"""
def _fetch_data(self: "Ionization", batch_ids: Set[int]) -> None:
urls = []
for batch_id in batch_ids:
urls.append(
(
self.base_url_template.format(batch_id),
Path(self.file_name_template.format(batch_id)),
)
)
fetch(urls=urls, base_dir=self.base_dir)
"""

def _get_file_path(self: "Ionization", batch_id: int) -> Path:
return self.base_dir / self.file_name_template.format(batch_id)

Expand Down
18 changes: 2 additions & 16 deletions astro_plasma/core/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def __init__(
None.
"""
self._check_and_download = download_emission_data
self.base_url_template = BASE_URL_TEMPLATE
self.file_name_template = FILE_NAME_TEMPLATE
self.base_dir = base_dir
self._check_and_download = download_emission_data

@property
def base_dir(self):
Expand All @@ -59,24 +59,10 @@ def base_dir(

fetch(urls=DOWNLOAD_IN_INIT, base_dir=self._base_dir)
data = h5py.File(self._base_dir / DOWNLOAD_IN_INIT[0][1], "r")
super().__init__(data)
super().__init__(self, data)
self._energy = np.array(data["output/energy"])
data.close()

"""
def _fetch_data(self: "EmissionSpectrum", batch_ids: Set[int]) -> None:
urls = []
for batch_id in batch_ids:
urls.append(
(
self.base_url_template.format(batch_id),
Path(self.file_name_template.format(batch_id)),
)
)
fetch(urls=urls, base_dir=self.base_dir)
"""

def _get_file_path(self: "EmissionSpectrum", batch_id: int) -> Path:
return self.base_dir / self.file_name_template.format(batch_id)

Expand Down

0 comments on commit 7a6441d

Please sign in to comment.