Skip to content

Commit

Permalink
fix(pyright): add strict type definitions (#116)
Browse files Browse the repository at this point in the history
* fix(pyright): add some static typing

* fix(pyright): more typing

This commit adds even more typing to the module. Still not done though.

* fix(pyright): finishes typing library

Still missing tests though.

* fix(pyright): final fixes

This commit finally fixes all pyright errors.
  • Loading branch information
serramatutu authored Apr 30, 2024
1 parent 3651173 commit 1ad1ff7
Show file tree
Hide file tree
Showing 20 changed files with 606 additions and 517 deletions.
3 changes: 2 additions & 1 deletion jafgen/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Annotated

import typer

from jafgen.simulation import Simulation
Expand All @@ -15,7 +16,7 @@ def run(
str,
typer.Option(help="Optional prefix for the output file names."),
] = "raw",
):
) -> None:
sim = Simulation(years, pre)
sim.run_simulation()
sim.save_results()
93 changes: 25 additions & 68 deletions jafgen/curves.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
import datetime
from abc import ABC, abstractmethod

import numpy as np
import numpy.typing as npt
from typing_extensions import override

NumberArr = npt.NDArray[np.float64] | npt.NDArray[np.int32]



class Curve(ABC):
@property
@abstractmethod
def Domain(self):
def Domain(self) -> NumberArr:
raise NotImplementedError

@abstractmethod
def TranslateDomain(self, date):
def TranslateDomain(self, date: datetime.date) -> int:
raise NotImplementedError

@abstractmethod
def Expr(self, x):
def Expr(self, x: float) -> float:
raise NotImplementedError

@classmethod
def eval(cls, date):
def eval(cls, date: datetime.date) -> float:
instance = cls()
domain_value = instance.TranslateDomain(date)
domain_index = domain_value % len(instance.Domain)
Expand All @@ -28,25 +34,28 @@ def eval(cls, date):

class AnnualCurve(Curve):
@property
def Domain(self):
return np.linspace(0, 2 * np.pi, 365)
@override
def Domain(self) -> NumberArr:
return np.linspace(0, 2 * np.pi, 365, dtype=np.float64)

def TranslateDomain(self, date):
@override
def TranslateDomain(self, date: datetime.date) -> int:
return date.timetuple().tm_yday

def Expr(self, x):
@override
def Expr(self, x: float) -> float:
return (np.cos(x) + 1) / 10 + 0.8


class WeekendCurve(Curve):
@property
def Domain(self):
return tuple(range(6))
def Domain(self) -> NumberArr:
return np.array(range(6), dtype=np.float64)

def TranslateDomain(self, date):
def TranslateDomain(self, date: datetime.date) -> int:
return date.weekday() - 1

def Expr(self, x):
def Expr(self, x: float):
if x >= 6:
return 0.6
else:
Expand All @@ -55,65 +64,13 @@ def Expr(self, x):

class GrowthCurve(Curve):
@property
def Domain(self):
return tuple(range(500))
def Domain(self) -> NumberArr:
return np.arange(500, dtype=np.int32)

def TranslateDomain(self, date):
def TranslateDomain(self, date: datetime.date) -> int:
return (date.year - 2016) * 12 + date.month

def Expr(self, x):
def Expr(self, x: float) -> float:
# ~ aim for ~20% growth/year
return 1 + (x / 12) * 0.2


class Day(object):
EPOCH = datetime.datetime(year=2018, month=9, day=1)
SEASONAL_MONTHLY_CURVE = AnnualCurve()
WEEKEND_CURVE = WeekendCurve()
GROWTH_CURVE = GrowthCurve()

def __init__(self, date_index, minutes=0):
self.date_index = date_index
self.date = self.EPOCH + datetime.timedelta(days=date_index, minutes=minutes)
self.effects = [
self.SEASONAL_MONTHLY_CURVE.eval(self.date),
self.WEEKEND_CURVE.eval(self.date),
self.GROWTH_CURVE.eval(self.date),
]

def at_minute(self, minutes):
return Day(self.date_index, minutes=minutes)

def get_effect(self):
total = 1
for effect in self.effects:
total = total * effect
return total

# weekend_effect = 0.8 if date.is_weekend else 1
# summer_effect = 0.7 if date.season == 'summer' else 1

@property
def day_of_week(self):
return self.date.weekday()

@property
def is_weekend(self):
# 5 + 6 are weekends
return self.date.weekday() >= 5

@property
def season(self):
month_no = self.date.month
day_no = self.date.day

if month_no in (1, 2) or (month_no == 3 and day_no < 21):
return "winter"
elif month_no in (3, 4, 5) or (month_no == 6 and day_no < 21):
return "spring"
elif month_no in (6, 7, 8) or (month_no == 9 and day_no < 21):
return "summer"
elif month_no in (9, 10, 11) or (month_no == 12 and day_no < 21):
return "fall"
else:
return "winter"
Loading

0 comments on commit 1ad1ff7

Please sign in to comment.