Skip to content

Commit

Permalink
Support dataclass in MSONable protocol.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shyue Ping Ong committed Sep 7, 2022
1 parent 9f61607 commit f2e5c8b
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
exclude: ^(docs|.*test_files|cmd_line|dev_scripts)

default_language_version:
python: python3.8
python: python3.9

ci:
autoupdate_schedule: monthly
Expand Down
13 changes: 13 additions & 0 deletions monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@
except ImportError:
orjson = None # type: ignore

try:
import dataclasses
except ImportError:
dataclasses = None # type: ignore


__version__ = "3.0.0"


Expand Down Expand Up @@ -140,6 +146,10 @@ def recursive_as_dict(obj):
return {kk: recursive_as_dict(vv) for kk, vv in obj.items()}
if hasattr(obj, "as_dict"):
return obj.as_dict()
if dataclasses is not None and dataclasses.is_dataclass(obj):
d = dataclasses.asdict(obj)
d.update({"@module": obj.__class__.__module__, "@class": obj.__class__.__name__})
return d
return obj

for c in args:
Expand Down Expand Up @@ -407,13 +417,16 @@ def process_decoded(self, d):
return UUID(d["string"])

mod = __import__(modname, globals(), locals(), [classname], 0)

if hasattr(mod, classname):
cls_ = getattr(mod, classname)
data = {k: v for k, v in d.items() if not k.startswith("@")}
if hasattr(cls_, "from_dict"):
return cls_.from_dict(data)
if pydantic is not None and issubclass(cls_, pydantic.BaseModel): # pylint: disable=E1101
return cls_(**data)
if dataclasses is not None and dataclasses.is_dataclass(cls_):
return cls_(**data)
elif np is not None and modname == "numpy" and classname == "array":
if d["dtype"].startswith("complex"):
return np.array(
Expand Down
22 changes: 22 additions & 0 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import pathlib
import unittest
import dataclasses
from enum import Enum

import numpy as np
Expand Down Expand Up @@ -107,6 +108,18 @@ class ClassContainingNumpyArray(MSONable):
def __init__(self, np_a):
self.np_a = np_a

@dataclasses.dataclass
class Point:
x: float = 1
y: float = 2

class Coordinates(MSONable):

def __init__(self, points):
self.points = points

def __str__(self):
return str(self.points)

class MSONableTest(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -652,6 +665,15 @@ class ModelWithMSONable(BaseModel):
assert isinstance(obj.a, GoodMSONClass)
assert obj.a.b == 1

def test_dataclass(self):

c = Coordinates([Point(1, 2), Point(3, 4)])
d = c.as_dict()
c2 = Coordinates.from_dict(d)
self.assertEqual(d["points"][0]['x'], 1)
self.assertEqual(d["points"][1]['y'], 4)
self.assertIsInstance(c2, Coordinates)
self.assertIsInstance(c2.points[0], Point)

if __name__ == "__main__":
unittest.main()

0 comments on commit f2e5c8b

Please sign in to comment.