Skip to content

Commit

Permalink
refactor: upgrade pydantic v2
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Sep 21, 2023
1 parent 07009a3 commit cc3de01
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 41 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ include = '\.pyi?$'

[tool.pytest.ini_options]
addopts = """
-n auto
-p no:ape_test
--cov-branch
--cov-report term
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ exclude =
.tox
docs
build
./tokenlists/version.py
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"pytest>=6.0", # Core testing package
"pytest-xdist", # multi-process runner
"pytest-cov", # Coverage analyzer plugin
"hypothesis>=6.2.0,<7", # Strategy-based fuzzer
"hypothesis>=6.86.2,<7", # Strategy-based fuzzer
"PyGithub>=1.54,<2", # Necessary to pull official schema from github
"hypothesis-jsonschema==0.19.0", # Fuzzes based on a json schema
],
Expand Down Expand Up @@ -70,7 +70,6 @@
"click>=8.1.3,<9",
"pydantic>=2.3.0,<3",
"pyyaml>=6.0,<7",
"semantic-version>=2.10.0,<3",
"requests>=2.28.1,<3",
],
entry_points={"console_scripts": ["tokenlists=tokenlists._cli:cli"]},
Expand Down
4 changes: 3 additions & 1 deletion tests/functional/test_schema_fuzzing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def clean_data(tl: dict) -> dict:
@settings(suppress_health_check=(HealthCheck.too_slow,))
def test_schema(token_list):
try:
assert TokenList.parse_obj(token_list).dict() == clean_data(token_list)
assert TokenList.model_validate(token_list).model_dump(mode="json") == clean_data(
token_list
)

except (ValidationError, ValueError):
pass # Expect these kinds of errors
50 changes: 44 additions & 6 deletions tests/functional/test_uniswap_examples.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
from typing import Any, Optional

import github
import pytest
import requests # type: ignore[import]
import requests
from pydantic import ValidationError

from tokenlists import TokenList
Expand All @@ -27,9 +28,46 @@ def test_uniswap_tokenlists(token_list_name):
# https://github.com/Uniswap/token-lists/pull/420
token_list.pop("tokenMap")

if "invalid" not in token_list_name:
assert TokenList.parse_obj(token_list).dict() == token_list

else:
if "invalid" in token_list_name:
with pytest.raises((ValidationError, ValueError)):
TokenList.parse_obj(token_list).dict()
TokenList.model_validate(token_list)
else:
actual = TokenList.model_validate(token_list).model_dump(mode="json")

def assert_tokenlists(_actual: Any, _expected: Any, parent_key: Optional[str] = None):
parent_key = parent_key or "__root__"
assert type(_actual) is type(_expected)

if isinstance(_actual, list):
for idx, (actual_item, expected_item) in enumerate(zip(_actual, _expected)):
assert_tokenlists(
actual_item, expected_item, parent_key=f"{parent_key}_index_{idx}"
)

elif isinstance(_actual, dict):
unexpected = {}
handled = set()
for key, actual_value in _actual.items():
if key not in _expected:
unexpected[key] = actual_value
continue

expected_value = _expected[key]
assert type(actual_value) is type(expected_value)
assert_tokenlists(actual_value, expected_value, parent_key=key)
handled.add(key)

handled_str = ", ".join(list(handled)) or "<Nothing handled>"
missing = {f"{x}" for x in _expected if x not in handled}
unexpected_str = ", ".join([f"{k}={v}" for k, v in unexpected.items()])
assert not unexpected, f"Unexpected keys: {unexpected_str}, Parent: {parent_key}"
assert not missing, (
f"Missing keys: '{', '.join(list(missing))}'; "
f"handled: '{handled_str}', "
f"Parent: {parent_key}."
)

else:
assert _actual == _expected

assert_tokenlists(actual, token_list)
8 changes: 4 additions & 4 deletions tests/integration/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ def test_empty_list(runner, cli):


def test_install(runner, cli):
result = runner.invoke(cli, ["list"])
assert result.exit_code == 0
assert "No tokenlists exist" in result.output
# result = runner.invoke(cli, ["list"])
# assert result.exit_code == 0
# assert "No tokenlists exist" in result.output

result = runner.invoke(cli, ["install", TEST_URI])
assert result.exit_code == 0
assert result.exit_code == 0, result.output

result = runner.invoke(cli, ["list"])
assert result.exit_code == 0
Expand Down
4 changes: 2 additions & 2 deletions tokenlists/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def list_tokens(search, tokenlist_name, chain_id):
lambda t: pattern.match(t.symbol),
manager.get_tokens(tokenlist_name, chain_id),
):
click.echo("{address} ({symbol})".format(**token_info.dict()))
click.echo("{address} ({symbol})".format(**token_info.model_dump(mode="json")))


@cli.command(short_help="Display the info for a particular token")
Expand All @@ -92,7 +92,7 @@ def token_info(symbol, tokenlist_name, chain_id, case_insensitive):
raise click.ClickException("No tokenlists available!")

token_info = manager.get_token_info(symbol, tokenlist_name, chain_id, case_insensitive)
token_info = token_info.dict()
token_info = token_info.model_dump(mode="json")

if "tags" not in token_info:
token_info["tags"] = ""
Expand Down
3 changes: 2 additions & 1 deletion tokenlists/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
DEFAULT_CACHE_PATH = Path.home().joinpath(".tokenlists")
DEFAULT_TOKENLIST: Optional[str] = None

UNISWAP_ENS_TOKENLISTS_HOST = "https://tokenlists.org/token-list?url=http://{}.link"
# UNISWAP_ENS_TOKENLISTS_HOST = "https://tokenlists.org/token-list?url=http://{}.link"
UNISWAP_ENS_TOKENLISTS_HOST = "https://wispy-bird-88a7.uniswap.workers.dev/?url=http://{}.link"
6 changes: 3 additions & 3 deletions tokenlists/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self):
# Load all the ones cached on disk
self.installed_tokenlists = {}
for path in self.cache_folder.glob("*.json"):
tokenlist = TokenList.parse_file(path)
tokenlist = TokenList.model_validate_json(path.read_text())
self.installed_tokenlists[tokenlist.name] = tokenlist

self.default_tokenlist = config.DEFAULT_TOKENLIST
Expand Down Expand Up @@ -48,13 +48,13 @@ def install_tokenlist(self, uri: str) -> str:
except JSONDecodeError as err:
raise ValueError(f"Invalid response: {response.text}") from err

tokenlist = TokenList.parse_obj(response_json)
tokenlist = TokenList.model_validate(response_json)
self.installed_tokenlists[tokenlist.name] = tokenlist

# Cache it on disk for later instances
self.cache_folder.mkdir(exist_ok=True)
token_list_file = self.cache_folder.joinpath(f"{tokenlist.name}.json")
token_list_file.write_text(tokenlist.json())
token_list_file.write_text(tokenlist.model_dump_json())

return tokenlist.name

Expand Down
52 changes: 31 additions & 21 deletions tokenlists/typing.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from datetime import datetime
from itertools import chain
from typing import Any, Dict, List, Optional

from pydantic import AnyUrl
from pydantic import BaseModel as _BaseModel
from pydantic import validator
from semantic_version import Version # type: ignore
from pydantic import PastDatetime, field_validator

ChainId = int

TagId = str

TokenAddress = str
TokenName = str
TokenDecimals = int
Expand All @@ -19,10 +15,13 @@

class BaseModel(_BaseModel):
def dict(self, *args, **kwargs):
return self.model_dump(*args, **kwargs)

def model_dump(self, *args, **kwargs):
if "exclude_unset" not in kwargs:
kwargs["exclude_unset"] = True

return super().dict(*args, **kwargs)
return super().model_dump(*args, **kwargs)

class Config:
froze = True
Expand All @@ -44,7 +43,7 @@ class TokenInfo(BaseModel):
tags: Optional[List[TagId]] = None
extensions: Optional[Dict[str, Any]] = None

@validator("logoURI")
@field_validator("logoURI")
def validate_uri(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
Expand All @@ -54,7 +53,7 @@ def validate_uri(cls, v: Optional[str]) -> Optional[str]:

return v

@validator("extensions", pre=True)
@field_validator("extensions", mode="before")
def parse_extensions(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
# 1. Check extension depth first
def extension_depth(obj: Optional[Dict[str, Any]]) -> int:
Expand All @@ -68,12 +67,17 @@ def extension_depth(obj: Optional[Dict[str, Any]]) -> int:

# 2. Parse valid extensions
if v and "bridgeInfo" in v:
raw_bridge_info = v.pop("bridgeInfo")
v["bridgeInfo"] = {int(k): BridgeInfo.parse_obj(v) for k, v in raw_bridge_info.items()}
# NOTE: Avoid modifying `v`.
return {
**v,
"bridgeInfo": {
int(k): BridgeInfo.model_validate(v) for k, v in v["bridgeInfo"].items()
},
}

return v

@validator("extensions")
@field_validator("extensions")
def extensions_must_contain_allowed_types(
cls, d: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
Expand All @@ -96,7 +100,7 @@ def bridge_info(self) -> Optional[BridgeInfo]:

return None

@validator("address")
@field_validator("address")
def address_must_hex(cls, v: str):
if not v.startswith("0x") or set(v) > set("x0123456789abcdefABCDEF") or len(v) % 2 != 0:
raise ValueError("Address is not hex")
Expand All @@ -108,7 +112,7 @@ def address_must_hex(cls, v: str):

return v

@validator("decimals")
@field_validator("decimals")
def decimals_must_be_uint8(cls, v: TokenDecimals):
if not (0 <= v < 256):
raise ValueError(f"Invalid token decimals: {v}")
Expand All @@ -121,12 +125,12 @@ class Tag(BaseModel):
description: str


class TokenListVersion(BaseModel, Version):
class TokenListVersion(BaseModel):
major: int
minor: int
patch: int

@validator("*")
@field_validator("*")
def no_negative_version_numbers(cls, v: int):
if v < 0:
raise ValueError("Invalid version number")
Expand All @@ -150,7 +154,7 @@ def __str__(self) -> str:

class TokenList(BaseModel):
name: str
timestamp: datetime
timestamp: PastDatetime
version: TokenListVersion
tokens: List[TokenInfo]
keywords: Optional[List[str]] = None
Expand Down Expand Up @@ -180,7 +184,7 @@ class Config:
# NOTE: Not frozen as we may need to dynamically modify this
froze = False

@validator("logoURI")
@field_validator("logoURI")
def validate_uri(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
Expand All @@ -190,8 +194,14 @@ def validate_uri(cls, v: Optional[str]) -> Optional[str]:

return v

def dict(self, *args, **kwargs) -> dict:
data = super().dict(*args, **kwargs)
# NOTE: This was the easiest way to make sure this property returns isoformat
data["timestamp"] = self.timestamp.isoformat()
def dict(self, *args, **kwargs):
return self.model_dump(*args, **kwargs)

def model_dump(self, *args, **kwargs) -> Dict:
data = super().model_dump(*args, **kwargs)

if kwargs.get("mode", "").lower() == "json":
# NOTE: This was the easiest way to make sure this property returns isoformat
data["timestamp"] = self.timestamp.isoformat()

return data

0 comments on commit cc3de01

Please sign in to comment.