Skip to content

Commit

Permalink
mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
tronikos committed Aug 16, 2023
1 parent a73b047 commit d983dca
Show file tree
Hide file tree
Showing 13 changed files with 127 additions and 69 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
Expand All @@ -30,7 +30,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install .
python -m pip install flake8 pytest ruff
python -m pip install flake8 pytest ruff mypy pydantic
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
Expand All @@ -40,6 +40,9 @@ jobs:
- name: Lint with ruff
run: |
ruff .
- name: Static typing with mypy
run: |
mypy --install-types --non-interactive --no-warn-unused-ignores .
- name: Test with pytest
run: |
pytest
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,12 @@ python -m pip install flake8 ruff
flake8 .
ruff . --fix

# Run formatter and lint
isort . ; black . ; flake8 . ; ruff . --fix
# Run type checking
python -m pip install mypy pydantic
mypy .

# Run formatter, lint, and type checking
isort . ; black . ; flake8 . ; ruff . --fix ; mypy .

# Run tests
python -m pip install pytest
Expand Down
31 changes: 31 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
[mypy]
exclude = (venv|build)
python_version = 3.9
plugins = pydantic.mypy
show_error_codes = true
follow_imports = silent
ignore_missing_imports = true
local_partial_types = true
strict_equality = true
no_implicit_optional = true
warn_incomplete_stub = true
warn_redundant_casts = true
warn_unused_configs = true
warn_unused_ignores = true
enable_error_code = ignore-without-code, redundant-self, truthy-iterable
disable_error_code = annotation-unchecked
extra_checks = false
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true

[pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true
warn_untyped_fields = true
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ authors = [
]
description = "A Python library for getting historical and forecasted usage/cost from utilities that use opower.com such as PG&E"
readme = "README.md"
requires-python = ">=3.7"
requires-python = ">=3.9"
dependencies = [
"aiohttp>=3.8",
"arrow>=1.2",
Expand Down
6 changes: 3 additions & 3 deletions src/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from datetime import datetime, timedelta
from getpass import getpass
import logging
from typing import Optional

import aiohttp

from opower import AggregateType, Opower, ReadResolution, get_supported_utilities


async def _main():
async def _main() -> None:
supported_utilities = [
utility.__name__.lower()
for utility in get_supported_utilities(supports_mfa=True)
Expand Down Expand Up @@ -101,14 +102,14 @@ async def _main():
"end_date=",
args.end_date,
)
prev_end: Optional[datetime] = None
if args.usage_only:
usage_data = await opower.async_get_usage_reads(
account,
aggregate_type,
args.start_date,
args.end_date,
)
prev_end = None
print(
"start_time\tend_time\tconsumption"
"\tstart_minus_prev_end\tend_minus_prev_end"
Expand All @@ -135,7 +136,6 @@ async def _main():
args.start_date,
args.end_date,
)
prev_end = None
print(
"start_time\tend_time\tconsumption\tprovided_cost"
"\tstart_minus_prev_end\tend_minus_prev_end"
Expand Down
53 changes: 27 additions & 26 deletions src/opower/opower.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from enum import Enum
import json
import logging
from typing import Any, Optional
from typing import Any, Optional, Union
from urllib.parse import urlencode

import aiohttp
Expand All @@ -26,7 +26,7 @@ class MeterType(Enum):
ELEC = "ELEC"
GAS = "GAS"

def __str__(self):
def __str__(self) -> str:
"""Return the value of the enum."""
return self.value

Expand All @@ -38,7 +38,7 @@ class UnitOfMeasure(Enum):
THERM = "THERM"
CCF = "CCF"

def __str__(self):
def __str__(self) -> str:
"""Return the value of the enum."""
return self.value

Expand All @@ -53,7 +53,7 @@ class AggregateType(Enum):
# Home Assistant only has hourly data in the energy dashboard and
# some utilities (e.g. PG&E) claim QUARTER_HOUR but they only provide HOUR.

def __str__(self):
def __str__(self) -> str:
"""Return the value of the enum."""
return self.value

Expand All @@ -67,7 +67,7 @@ class ReadResolution(Enum):
HALF_HOUR = "HALF_HOUR"
QUARTER_HOUR = "QUARTER_HOUR"

def __str__(self):
def __str__(self) -> str:
"""Return the value of the enum."""
return self.value

Expand Down Expand Up @@ -144,14 +144,14 @@ class UsageRead:


# TODO: remove supports_mfa and accepts_mfa from all files after ConEd is released to Home Assistant
def get_supported_utilities(supports_mfa=False) -> list[type["UtilityBase"]]:
def get_supported_utilities(supports_mfa: bool = False) -> list[type["UtilityBase"]]:
"""Return a list of all supported utilities."""
return [
cls for cls in UtilityBase.subclasses if supports_mfa or not cls.accepts_mfa()
]


def get_supported_utility_names(supports_mfa=False) -> list[str]:
def get_supported_utility_names(supports_mfa: bool = False) -> list[str]:
"""Return a sorted list of names of all supported utilities."""
return sorted(
[
Expand Down Expand Up @@ -184,13 +184,13 @@ def __init__(
"""Initialize."""
# Note: Do not modify default headers since Home Assistant that uses this library needs to use
# a default session for all integrations. Instead specify the headers for each request.
self.session = session
self.session: aiohttp.ClientSession = session
self.utility: type[UtilityBase] = _select_utility(utility)
self.username = username
self.password = password
self.optional_mfa_secret = optional_mfa_secret
self.access_token = None
self.customers = []
self.username: str = username
self.password: str = password
self.optional_mfa_secret: Optional[str] = optional_mfa_secret
self.access_token: Optional[str] = None
self.customers: list[Any] = []

async def async_login(self) -> None:
"""Login to the utility website and authorize opower.com for access.
Expand Down Expand Up @@ -310,8 +310,8 @@ async def async_get_cost_reads(
self,
account: Account,
aggregate_type: AggregateType,
start_date: datetime | None = None,
end_date: datetime | None = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
usage_only: bool = False,
) -> list[CostRead]:
"""Get usage and cost data for the selected account in the given date range aggregated by bill/day/hour.
Expand All @@ -328,7 +328,9 @@ async def async_get_cost_reads(
CostRead(
start_time=datetime.fromisoformat(read["startTime"]),
end_time=datetime.fromisoformat(read["endTime"]),
consumption=read["value"] if "value" in read else read["consumption"]["value"],
consumption=read["value"]
if "value" in read
else read["consumption"]["value"],
provided_cost=read.get("providedCost", 0) or 0,
)
)
Expand All @@ -351,8 +353,8 @@ async def async_get_usage_reads(
self,
account: Account,
aggregate_type: AggregateType,
start_date: datetime | None = None,
end_date: datetime | None = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
) -> list[UsageRead]:
"""Get usage data for the selected account in the given date range aggregated by bill/day/hour.
Expand All @@ -377,16 +379,15 @@ async def _async_get_dated_data(
self,
account: Account,
aggregate_type: AggregateType,
start_date: datetime | None = None,
end_date: datetime | None = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
usage_only: bool = False,
) -> list[Any]:
"""Wrap _async_fetch by breaking requests for big date ranges to smaller ones to satisfy opower imposed limits."""
# TODO: remove not None check after a Home Assistant release
if (
account.read_resolution is not None
and aggregate_type
not in SUPPORTED_AGGREGATE_TYPES.get(account.read_resolution)
and aggregate_type not in SUPPORTED_AGGREGATE_TYPES[account.read_resolution]
):
raise ValueError(
f"Requested aggregate_type: {aggregate_type} "
Expand Down Expand Up @@ -433,8 +434,8 @@ async def _async_fetch(
self,
account: Account,
aggregate_type: AggregateType,
start_date: datetime | arrow.Arrow | None = None,
end_date: datetime | arrow.Arrow | None = None,
start_date: Union[datetime, arrow.Arrow, None] = None,
end_date: Union[datetime, arrow.Arrow, None] = None,
usage_only: bool = False,
) -> list[Any]:
if usage_only:
Expand Down Expand Up @@ -476,15 +477,15 @@ async def _async_fetch(
result = await resp.json()
if DEBUG_LOG_RESPONSE:
_LOGGER.debug("Fetched: %s", json.dumps(result, indent=2))
return result["reads"]
return list(result["reads"])
except ClientResponseError as err:
# Ignore server errors for BILL requests
# that can happen if end_date is before account activation
if err.status == 500 and aggregate_type == AggregateType.BILL:
return []
raise err

def _get_headers(self):
def _get_headers(self) -> dict[str, str]:
headers = {"User-Agent": USER_AGENT}
if self.access_token:
headers["authorization"] = f"Bearer {self.access_token}"
Expand Down
12 changes: 7 additions & 5 deletions src/opower/utilities/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Base class that each utility needs to extend."""


from typing import Optional
from typing import Any, Optional

import aiohttp

Expand All @@ -11,7 +11,7 @@ class UtilityBase:

subclasses: list[type["UtilityBase"]] = []

def __init_subclass__(cls, **kwargs) -> None:
def __init_subclass__(cls, **kwargs: Any) -> None:
"""Keep track of all subclass implementations."""
super().__init_subclass__(**kwargs)
cls.subclasses.append(cls)
Expand All @@ -35,7 +35,7 @@ def timezone() -> str:
raise NotImplementedError

@staticmethod
def accepts_mfa() -> str:
def accepts_mfa() -> bool:
"""Check if Utility implementations supports MFA."""
return False

Expand All @@ -45,8 +45,10 @@ async def async_login(
username: str,
password: str,
optional_mfa_secret: Optional[str],
) -> str | None:
"""Login to the utility website and authorize opower.
) -> Optional[str]:
"""Login to the utility website.
Return the Opower access token or None if this function authorizes with Opower in other ways.
:raises InvalidAuth: if login information is incorrect
"""
Expand Down
15 changes: 9 additions & 6 deletions src/opower/utilities/coned.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def timezone() -> str:
return "America/New_York"

@staticmethod
def accepts_mfa() -> str:
def accepts_mfa() -> bool:
"""Check if Utility implementations supports MFA."""
return True

Expand All @@ -46,7 +46,7 @@ async def async_login(
username: str,
password: str,
optional_mfa_secret: Optional[str],
) -> None:
) -> str:
"""Login to the utility website."""
# Double-logins are somewhat broken if cookies stay around.
# Let's clear everything except device tokens (which allow skipping 2FA)
Expand Down Expand Up @@ -76,10 +76,12 @@ async def async_login(
redirectUrl = result["authRedirectUrl"]
else:
if result["newDevice"]:
if not result["noMfa"] and not optional_mfa_secret:
raise InvalidAuth("TOTP secret is required for MFA accounts")

if not result["noMfa"]:
if not optional_mfa_secret:
raise InvalidAuth(
"TOTP secret is required for MFA accounts"
)

mfaCode = TOTP(optional_mfa_secret).now()

async with session.post(
Expand All @@ -101,6 +103,7 @@ async def async_login(
else:
raise InvalidAuth("Login Failed")

assert redirectUrl
async with session.get(
redirectUrl,
headers={
Expand All @@ -116,4 +119,4 @@ async def async_login(
headers={"User-Agent": USER_AGENT},
raise_for_status=True,
) as resp:
return await resp.json()
return await resp.text()
Loading

0 comments on commit d983dca

Please sign in to comment.