diff --git a/astro_plasma/core/datasift.py b/astro_plasma/core/datasift.py index 35d2ac4..920df5a 100644 --- a/astro_plasma/core/datasift.py +++ b/astro_plasma/core/datasift.py @@ -9,16 +9,22 @@ import numpy as np from itertools import product from pathlib import Path -from typing import Callable, Optional, Union, Tuple, List, Set +from typing import Protocol, Callable, Optional, Union, Tuple, List, Set import h5py import sys _warn = False +class inherited_from_DataSift(Protocol): + # This must be compulsorily implemented by any class inheriting from DataSift + _check_and_download: Callable + + class DataSift(ABC): def __init__( self: "DataSift", + child_obj: "inherited_from_DataSift", data: h5py.File, ) -> None: """ @@ -36,7 +42,7 @@ def __init__( self.batch_size = np.prod(np.array(data["header/batch_dim"])) self.total_size = np.prod(np.array(data["header/total_size"])) - self._check_and_download: Optional[Callable] = None # mandatory to be implemented by inheriting classes + self._check_and_download = child_obj._check_and_download def _identify_batch(self: "DataSift", i: int, j: int, k: int, m: int) -> int: batch_id = self._get_counter(i, j, k, m) // self.batch_size @@ -450,9 +456,8 @@ def _interpolate( _argument.append(argument_collection[arg_pos][0]) else: _argument.append(argument_collection[arg_pos][indx]) - if not (_dummy) and _array_argument[arg_pos]: - _reference = np.zeros_like(argument_collection[arg_pos]).reshape(_input_shape) - i_vals, j_vals, k_vals, m_vals = self._identify_pos_in_each_dim(*_argument) + nH_this, temperature_this, metallicity_this, redshift_this = _argument + i_vals, j_vals, k_vals, m_vals = self._identify_pos_in_each_dim(nH_this, temperature_this, metallicity_this, redshift_this) """ The trick is to take the floor value for interpolation only if it is the diff --git a/astro_plasma/core/ionization.py b/astro_plasma/core/ionization.py index c25368f..8414bd8 100644 --- a/astro_plasma/core/ionization.py +++ b/astro_plasma/core/ionization.py @@ -39,10 +39,10 @@ def __init__( None. """ + self._check_and_download = download_ionization_data self.base_url_template = BASE_URL_TEMPLATE self.file_name_template = FILE_NAME_TEMPLATE self.base_dir = base_dir - self._check_and_download = download_ionization_data @property def base_dir(self): @@ -57,23 +57,9 @@ def base_dir( fetch(urls=DOWNLOAD_IN_INIT, base_dir=self._base_dir) data = h5py.File(self._base_dir / DOWNLOAD_IN_INIT[0][1], "r") - super().__init__(data) + super().__init__(self, data) data.close() - """ - def _fetch_data(self: "Ionization", batch_ids: Set[int]) -> None: - urls = [] - for batch_id in batch_ids: - urls.append( - ( - self.base_url_template.format(batch_id), - Path(self.file_name_template.format(batch_id)), - ) - ) - - fetch(urls=urls, base_dir=self.base_dir) - """ - def _get_file_path(self: "Ionization", batch_id: int) -> Path: return self.base_dir / self.file_name_template.format(batch_id) diff --git a/astro_plasma/core/spectrum.py b/astro_plasma/core/spectrum.py index 941a269..1956e5d 100644 --- a/astro_plasma/core/spectrum.py +++ b/astro_plasma/core/spectrum.py @@ -41,10 +41,10 @@ def __init__( None. """ + self._check_and_download = download_emission_data self.base_url_template = BASE_URL_TEMPLATE self.file_name_template = FILE_NAME_TEMPLATE self.base_dir = base_dir - self._check_and_download = download_emission_data @property def base_dir(self): @@ -59,24 +59,10 @@ def base_dir( fetch(urls=DOWNLOAD_IN_INIT, base_dir=self._base_dir) data = h5py.File(self._base_dir / DOWNLOAD_IN_INIT[0][1], "r") - super().__init__(data) + super().__init__(self, data) self._energy = np.array(data["output/energy"]) data.close() - """ - def _fetch_data(self: "EmissionSpectrum", batch_ids: Set[int]) -> None: - urls = [] - for batch_id in batch_ids: - urls.append( - ( - self.base_url_template.format(batch_id), - Path(self.file_name_template.format(batch_id)), - ) - ) - - fetch(urls=urls, base_dir=self.base_dir) - """ - def _get_file_path(self: "EmissionSpectrum", batch_id: int) -> Path: return self.base_dir / self.file_name_template.format(batch_id)