Skip to content

Commit

Permalink
Allow fmt to be specified.
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep committed Jul 8, 2021
1 parent 9cced45 commit 02de6b3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
37 changes: 21 additions & 16 deletions monty/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
msgpack = None


def loadfn(fn, *args, **kwargs):
def loadfn(fn, *args, fmt=None, **kwargs):
r"""
Loads json/yaml/msgpack directly from a filename instead of a
File-like object. File may also be a BZ2 (".BZ2") or GZIP (".GZ", ".Z")
Expand All @@ -50,19 +50,23 @@ def loadfn(fn, *args, **kwargs):
Args:
fn (str/Path): filename or pathlib.Path.
*args: Any of the args supported by json/yaml.load.
fmt (string): If specified, the fmt specified would be used instead
of autodetection from filename. Supported formats right now are
"json", "yaml" or "mpk".
**kwargs: Any of the kwargs supported by json/yaml.load.
Returns:
(object) Result of json/yaml/msgpack.load.
"""

basename = os.path.basename(fn).lower()
if ".mpk" in basename:
fmt = "mpk"
elif any(ext in basename for ext in (".yaml", ".yml")):
fmt = "yaml"
else:
fmt = "json"
if fmt is None:
basename = os.path.basename(fn).lower()
if ".mpk" in basename:
fmt = "mpk"
elif any(ext in basename for ext in (".yaml", ".yml")):
fmt = "yaml"
else:
fmt = "json"

if fmt == "mpk":
if msgpack is None:
Expand All @@ -87,7 +91,7 @@ def loadfn(fn, *args, **kwargs):
raise TypeError("Invalid format: {}".format(fmt))


def dumpfn(obj, fn, *args, **kwargs):
def dumpfn(obj, fn, *args, fmt=None, **kwargs):
r"""
Dump to a json/yaml directly by filename instead of a
File-like object. File may also be a BZ2 (".BZ2") or GZIP (".GZ", ".Z")
Expand All @@ -107,13 +111,14 @@ def dumpfn(obj, fn, *args, **kwargs):
Returns:
(object) Result of json.load.
"""
basename = os.path.basename(fn).lower()
if ".mpk" in basename:
fmt = "mpk"
elif any(ext in basename for ext in (".yaml", ".yml")):
fmt = "yaml"
else:
fmt = "json"
if fmt is None:
basename = os.path.basename(fn).lower()
if ".mpk" in basename:
fmt = "mpk"
elif any(ext in basename for ext in (".yaml", ".yml")):
fmt = "yaml"
else:
fmt = "json"

if fmt == "mpk":
if msgpack is None:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,21 @@ def test_dumpfn_loadfn(self):
self.assertEqual(d, d2)
os.remove("monte_test.yaml")

# Check if fmt override works.
dumpfn(d, "monte_test.json", fmt="yaml")
with self.assertRaises(json.decoder.JSONDecodeError):
d2 = loadfn("monte_test.json")
d2 = loadfn("monte_test.json", fmt="yaml")
self.assertEqual(d, d2)
os.remove("monte_test.json")

with self.assertRaises(TypeError):
dumpfn(d, "monte_test.txt", fmt="garbage")
with self.assertRaises(TypeError):
loadfn("monte_test.txt", fmt="garbage")



@unittest.skipIf(msgpack is None, "msgpack-python not installed.")
def test_mpk(self):
d = {"hello": "world"}
Expand Down

1 comment on commit 02de6b3

@jdagdelen
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this! Forwarded on to my colleague.

Please sign in to comment.