Skip to content

Commit

Permalink
feat: add support for time-based MultiIndex in RateGrid
Browse files Browse the repository at this point in the history
  • Loading branch information
schmidni committed Sep 15, 2023
1 parent 3821c21 commit 6710002
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 0 deletions.
67 changes: 67 additions & 0 deletions catalog_tools/rategrid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import numpy as np
import pandas as pd

from catalog_tools.utils import _check_required_cols, require_cols
Expand Down Expand Up @@ -66,6 +67,12 @@ def __init__(self, data=None, *args, name=None,
self.endtime = endtime if isinstance(
endtime, pd.Timestamp) else pd.to_datetime(endtime)

if len(self.index.names) > 1:
try:
self.reindex_cell_id()
except AttributeError:
pass

# pandas subclassing hook: pandas reads ``_constructor`` to obtain the
# callable it uses to build the result of operations on this object.
# NOTE(review): ``_rategrid_constructor_with_fallback`` is defined outside
# this view; judging by its name it presumably falls back to a less specific
# type when the result no longer qualifies as this class - confirm.
@property
def _constructor(self):
return _rategrid_constructor_with_fallback
Expand All @@ -89,6 +96,66 @@ def strip(self, inplace: bool = False) -> RateGrid | None:
if not inplace:
return df

def add_time_index(self, endtime=True):
    """
    Create MultiIndex using starttime, optionally endtime and a cell
    number for each spatial block.

    The time value(s) become the outer index level(s); the existing index
    becomes the innermost level, named ``cell_id``. The grid the method is
    called on is left unchanged.

    Args:
        endtime : bool, optional
            If True, create MultiIndex with starttime and endtime.
            Otherwise, create MultiIndex with only starttime.

    Returns:
        RateGrid

    Raises:
        AttributeError: If ``starttime`` or ``endtime`` is not set.
    """
    if getattr(self, 'starttime', None) is None or \
            getattr(self, 'endtime', None) is None:
        raise AttributeError(
            'starttime and endtime must be set to use this method')

    if endtime:
        keys = (self.starttime, self.endtime)
        names = ['starttime', 'endtime']
    else:
        keys = self.starttime
        names = ['starttime']

    # The inner level of the new MultiIndex takes its name from the current
    # index, so label it ``cell_id`` for the duration of the concat only.
    # Restoring the name afterwards keeps the caller's grid (and any other
    # object sharing the same Index instance) untouched.
    original_name = self.index.name
    self.index.name = 'cell_id'
    try:
        df = pd.concat({keys: self}, names=names)
    finally:
        self.index.name = original_name

    # manually set the metadata attributes
    for arg in self._metadata:
        setattr(df, arg, getattr(self, arg))

    return df

@require_cols(require=_required_cols)
def reindex_cell_id(self):
    """
    If the RateGrid has an index which includes `cell_id` as a level,
    this method will update the RateGrid's index to use unique cell_id
    values: rows with identical longitude/latitude/depth bounds share
    one cell_id. The new `cell_id` level is placed last.

    If the index has a `starttime` level, `starttime` is set to its
    minimum and `endtime` to the maximum of the `endtime` level, or of
    the `starttime` level when there is no `endtime` level.

    The RateGrid is modified in place; nothing is returned.
    """
    if 'cell_id' in self.index.names:
        cell_bounds = self[['longitude_min', 'longitude_max',
                            'latitude_min', 'latitude_max',
                            'depth_min', 'depth_max']]

        # Row-wise unique: the inverse gives, for every row, the position
        # of its bounds among the unique rows.
        inverse = np.unique(
            cell_bounds, axis=0, return_inverse=True, equal_nan=True)[1]

        # NumPy 2.0.0 returned a 2-D inverse when ``axis`` is given;
        # flattening keeps the column assignment valid on every version.
        self['cell'] = inverse.reshape(-1)

        self.set_index('cell', append=True, drop=True, inplace=True)
        new_index = self.index.droplevel('cell_id')

        # Rename through ``names`` rather than ``set_names(level=...)``:
        # the latter raises on a plain Index, which is what remains when
        # ``cell_id`` was the only level before.
        new_index.names = ['cell_id' if name == 'cell' else name
                           for name in new_index.names]
        self.index = new_index

    if 'starttime' in self.index.names:
        starttimes = self.index.get_level_values('starttime')
        self.starttime = starttimes.min()
        if 'endtime' in self.index.names:
            self.endtime = self.index.get_level_values('endtime').max()
        else:
            self.endtime = starttimes.max()


class ForecastRateGrid(RateGrid):
"""
Expand Down
54 changes: 54 additions & 0 deletions catalog_tools/tests/test_rategrid.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from datetime import datetime

import pandas as pd
import pytest

from catalog_tools.rategrid import (REQUIRED_COLS_RATEGRID, ForecastRateGrid,
RateGrid)

Expand All @@ -15,6 +20,20 @@
'grid_id': [1, 2, 3, 4]
}

# Second fixture grid with four cells.
# NOTE(review): presumably the same four cells as RAW_DATA in a different row
# order (the cell_id assertion in test_rategrid_time_index relies on shared
# cells) - verify against RAW_DATA.
RAW_DATA_2 = {
    # spatial bounds of each cell
    'longitude_min': [-90, 0, 90, -180],
    'longitude_max': [0, 90, 180, -90],
    'latitude_min': [-45, 0, 45, -90],
    'latitude_max': [0, 45, 90, -45],
    'depth_min': [10, 20, 30, 0],
    'depth_max': [20, 30, 40, 10],
    # per-cell values
    'number_events': [100, 200, 300, 400],
    'a': [1.0, 1.5, 2.0, 2.5],
    'b': [0.5, 0.6, 0.7, 0.8],
    'mc': [4.0, 4.5, 5.0, 5.5],
    'grid_id': [1, 2, 3, 4],
}


def test_rategrid_init():
# Test initialization with data
Expand Down Expand Up @@ -66,3 +85,38 @@ def test_forecast_rategrid_strip():
# Test constructor fallback "downgrade"
dropped = rategrid.drop(columns=['grid_id'])
assert isinstance(dropped, RateGrid)


def test_rategrid_time_index():
    """Exercise ``add_time_index`` and concatenation of time-indexed grids."""
    starts = [datetime(2020, 1, 1), datetime(2020, 1, 3)]
    ends = [datetime(2020, 1, 2), datetime(2020, 1, 4)]

    # Build both grids, then give each a starttime-only index.
    first = ForecastRateGrid(
        RAW_DATA, starttime=starts[0], endtime=ends[0])
    second = ForecastRateGrid(
        RAW_DATA_2, starttime=starts[1], endtime=ends[1])
    first = first.add_time_index(endtime=False)
    second = second.add_time_index(endtime=False)

    # The indexed grids still report the times they were created with.
    assert first.starttime == starts[0]
    assert second.endtime == ends[1]

    combined = pd.concat([first, second], axis=0, sort=False)
    combined_index = combined.index

    seen_starts = combined_index.get_level_values('starttime').unique()
    assert list(seen_starts) == starts

    cell_ids = combined_index.get_level_values('cell_id')
    assert list(cell_ids) == [0, 1, 2, 3, 1, 2, 3, 0]

    # With no endtime level, endtime is expected to be the latest starttime.
    assert combined.starttime == starts[0]
    assert combined.endtime == starts[1]

    # The second input grid keeps its own endtime after the concat.
    assert second.endtime == ends[1]

    # A grid without starttime/endtime cannot be time-indexed.
    unset = ForecastRateGrid(RAW_DATA)
    with pytest.raises(AttributeError):
        unset.add_time_index()

0 comments on commit 6710002

Please sign in to comment.