Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for yaml_wrapper #107

Merged
merged 6 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 97 additions & 30 deletions moulin/yaml_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Wrappers for yaml.Node that provide better API
"""

from typing import Optional, List, Tuple, Iterator, Union
from typing import Optional, List, Tuple, Iterator, Union, Dict

from yaml.nodes import MappingNode, ScalarNode, SequenceNode, Node
from yaml.constructor import SafeConstructor
Expand All @@ -18,10 +18,10 @@
class _YamlDefaultValue:
"""
Helper class that have the same API as YamlValue, but is
constructed from a primitive type. It is used to provide default
constructed from a builtin type. It is used to provide default
value in YamlValue.get() method
"""
def __init__(self, val: Union[bool, str, int, float, None]):
def __init__(self, val: Union[bool, str, int, float, List, Dict, None]):
self._val = val

def __bool__(self):
Expand All @@ -45,16 +45,83 @@ def as_str(self) -> str:
def as_int(self) -> int:
"Get the integer value"
if not isinstance(self._val, int):
raise TypeError("Expected int value")
raise TypeError("Expected integer value")
return self._val

@property
def as_float(self) -> float:
"Get the floating point value"
if not isinstance(self._val, int):
raise TypeError("Expected float value")
if not isinstance(self._val, float):
raise TypeError("Expected floating point value")
return self._val

@property
def is_list(self) -> bool:
"""Check if this node represents a list"""
return isinstance(self._val, list)

def __iter__(self) -> Iterator["_YamlDefaultValue"]:
if not isinstance(self._val, list):
raise TypeError("Expected list value")
for item in self._val:
# We need to wrap the value in _YamlDefaultValue to provide the same API
yield _YamlDefaultValue(item)

def __len__(self) -> int:
if not isinstance(self._val, list):
raise TypeError("Expected list value")
return len(self._val)

def __getitem__(self, idx: Union[int, str]) -> "_YamlDefaultValue":
if isinstance(idx, int):
if not isinstance(self._val, list):
raise TypeError("Expected list value")
# We need to wrap the value in _YamlDefaultValue to provide the same API
return _YamlDefaultValue(self._val[idx])
elif isinstance(idx, str):
if not isinstance(self._val, dict):
raise TypeError("Expected dict value")
return _YamlDefaultValue(self._val[idx])
else:
raise KeyError("Key should have either type 'str' or 'int'")

def __setitem__(self, idx: Union[int, str], val: Union[str, int, bool, float]):
if isinstance(idx, int):
if not isinstance(self._val, list):
raise TypeError("Expected list value")
self._val[idx] = val
elif isinstance(idx, str):
if not isinstance(self._val, dict):
raise TypeError("Expected dict value")
self._val[idx] = val
else:
raise KeyError("Key should have either type 'str' or 'int")

def _get(self, name: str) -> Optional["_YamlDefaultValue"]:
if not isinstance(self._val, dict):
raise TypeError("Expected dict value")
if name in self._val:
return _YamlDefaultValue(self._val[name])
return None

def get(self, name: str, default) -> "_YamlDefaultValue":
val = self._get(name)
if val:
return val
return _YamlDefaultValue(default)

def keys(self) -> List[str]:
"""Get all keys for this mapping"""
if not isinstance(self._val, dict):
raise TypeError("Expected dict value")
return list(self._val.keys())

def items(self) -> List[Tuple[str, "_YamlDefaultValue"]]:
"""Get all items for this mapping"""
if not isinstance(self._val, dict):
raise TypeError("Expected dict value")
return [(key, _YamlDefaultValue(val)) for key, val in self._val.items()]


class YamlValue: # pylint: disable=too-few-public-methods
"""Wrapper for yaml.Node class. It provides type-safe access to YAML nodes"""
Expand Down Expand Up @@ -132,12 +199,6 @@ def items(self) -> List[Tuple[str, "YamlValue"]]:
raise YAMLProcessingError("Mapping node is expected", self.mark)
return [(key.value, YamlValue(val)) for key, val in self._node.value]

def replace_value(self, val: Union[str, int, bool, float]):
"Set a new value for a scalar node"
if not isinstance(self._node, ScalarNode):
raise YAMLProcessingError("Can't replace value for a non-scalar node", self.mark)
self._node.value = val

def __getitem__(self, idx: Union[str, int]) -> "YamlValue":
if isinstance(idx, int):
if not isinstance(self._node, SequenceNode):
Expand All @@ -150,28 +211,34 @@ def __getitem__(self, idx: Union[str, int]) -> "YamlValue":
return val
raise KeyError("Key should have either type 'str' or 'int'")

def _represent_value(self, val: Union[str, int, bool, float]) -> Node:
representer = SafeRepresenter()
if isinstance(val, str):
return representer.represent_str(val)
if isinstance(val, int):
return representer.represent_int(val)
if isinstance(val, bool):
return representer.represent_bool(val)
if isinstance(val, float):
return representer.represent_float(val)
raise TypeError(f"Unsupported type {type(val)}")

def __setitem__(self, idx: Union[str, int], val: Union[str, int, bool, float]):
# We need to make a copy of the node because yaml modules caches the nodes
# and will ignore the new value
if isinstance(idx, int):
if not isinstance(self._node, SequenceNode):
raise YAMLProcessingError("SequenceNode node is expected", self.mark)
self._node.value[idx].replace_value(val)
if isinstance(idx, str):
item = self._get(idx)
if item:
item.replace_value(val)
else:
representer = SafeRepresenter()
key_node = representer.represent_str(idx)
if isinstance(val, str):
val_node = representer.represent_str(val)
elif isinstance(val, int):
val_node = representer.represent_int(val)
elif isinstance(val, bool):
val_node = representer.represent_bool(val)
else:
val_node = representer.represent_float(val)
self._node.value.append((key_node, val_node))
raise KeyError("Key should have either type 'str' or 'int'")
self._node.value[idx] = self._represent_value(val)
elif isinstance(idx, str):
for k, v in self._node.value:
if k.value == idx:
self._node.value.remove((k, v))
key_node = self._represent_value(idx)
val_node = self._represent_value(val)
self._node.value.append((key_node, val_node))
else:
raise KeyError("Key should have either type 'str' or 'int'")

def __iter__(self) -> Iterator["YamlValue"]:
for item in self._node.value:
Expand Down
205 changes: 205 additions & 0 deletions tests/unittest/test_yaml_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
"""
Tests for Yaml wrapper
"""


import unittest
import yaml
from moulin.yaml_wrapper import YamlValue, _YamlDefaultValue
from moulin.yaml_helpers import YAMLProcessingError
from typing import Tuple


def gen_wrappers(val) -> Tuple[YamlValue, _YamlDefaultValue]:
mnode = yaml.compose(f"test: {val}")
_, node = mnode.value[0]
yval = YamlValue(node)
ydefval = _YamlDefaultValue(val)
return yval, ydefval


class TestMoulinYamlWrapper(unittest.TestCase):

def test_str(self):
val = "test"
intval = 42

yval, ydefval = gen_wrappers(val)

self.assertEqual(yval.as_str, val)
self.assertEqual(ydefval.as_str, val)

yval, ydefval = gen_wrappers(intval)

with self.assertRaisesRegex(YAMLProcessingError, "Expected string value"):
yval.as_str

with self.assertRaisesRegex(TypeError, "Expected string value"):
ydefval.as_str

def test_int(self):
val = 42
strval = "test"

yval, ydefval = gen_wrappers(val)

self.assertEqual(yval.as_int, val)
self.assertEqual(ydefval.as_int, val)

yval, ydefval = gen_wrappers(strval)

with self.assertRaisesRegex(YAMLProcessingError, "Expected integer value"):
yval.as_int

with self.assertRaisesRegex(TypeError, "Expected integer value"):
ydefval.as_int

def test_float(self):
val = 42.0
intval = 42

yval, ydefval = gen_wrappers(val)

self.assertEqual(yval.as_float, val)
self.assertEqual(ydefval.as_float, val)

yval, ydefval = gen_wrappers(intval)

with self.assertRaisesRegex(YAMLProcessingError, "Expected floating point value"):
yval.as_float

with self.assertRaisesRegex(TypeError, "Expected floating point value"):
ydefval.as_float

def test_boolean(self):
val = True
intval = 42

yval, ydefval = gen_wrappers(val)

self.assertEqual(yval.as_bool, val)
self.assertEqual(ydefval.as_bool, val)

yval, ydefval = gen_wrappers(intval)

with self.assertRaisesRegex(YAMLProcessingError, "Expected boolean value"):
yval.as_bool

with self.assertRaisesRegex(TypeError, "Expected boolean value"):
ydefval.as_bool

def test_list(self):
val = [1, 2, 3]
doc = """
test:
- 1
- 2
- 3
"""
mnode = yaml.compose(doc)
_, node = mnode.value[0]
yval = YamlValue(node)
ydefval = _YamlDefaultValue(val)

self.assertTrue(yval.is_list)
self.assertTrue(ydefval.is_list)
self.assertEqual(len(yval), len(val))
self.assertEqual(len(ydefval), len(val))

for i in range(len(val)):
self.assertEqual(yval[i].as_int, val[i])
self.assertEqual(ydefval[i].as_int, val[i])

for i, item in enumerate(yval):
self.assertEqual(item.as_int, val[i])

for i, item in enumerate(ydefval):
self.assertEqual(item.as_int, val[i])

yval[0] = 4
ydefval[0] = 4
self.assertEqual(yval[0].as_int, 4)
self.assertEqual(ydefval[0].as_int, 4)

with self.assertRaises(IndexError):
yval[5]
with self.assertRaises(IndexError):
ydefval[5]

def test_list_false(self):
val = 1

yval, ydefval = gen_wrappers(val)

self.assertFalse(yval.is_list)
self.assertFalse(ydefval.is_list)

with self.assertRaisesRegex(YAMLProcessingError, "SequenceNode node is expected"):
yval[0]

with self.assertRaisesRegex(TypeError, "Expected list value"):
ydefval[0]

def test_dict(self):
val = {"A": 1, "B": 2, "C": 3}
doc = """
test:
A: 1
B: 2
C: 3
"""
mnode = yaml.compose(doc)
_, node = mnode.value[0]
yval = YamlValue(node)
ydefval = _YamlDefaultValue(val)

for k in val.keys():
self.assertTrue(k in yval.keys())
self.assertEqual(yval[k].as_int, val[k])
self.assertEqual(yval.get(k, 1).as_int, val[k])
self.assertTrue(k in ydefval.keys())
self.assertEqual(ydefval[k].as_int, val[k])
self.assertEqual(ydefval.get(k, 1).as_int, val[k])

for k, v in yval.items():
self.assertEqual(v.as_int, val[k])

for k, v in ydefval.items():
self.assertEqual(v.as_int, val[k])

self.assertEqual(yval.get("WRONGKEY", 99).as_int, 99)
self.assertEqual(ydefval.get("WRONGKEY", 99).as_int, 99)

yval["A"] = 4
ydefval["A"] = 4
yval["B"] = "test"
ydefval["B"] = "test"
yval["C"] = 42.0
ydefval["C"] = 42.0

self.assertEqual(yval["A"].as_int, 4)
self.assertEqual(ydefval["A"].as_int, 4)
self.assertEqual(yval["B"].as_str, "test")
self.assertEqual(ydefval["B"].as_str, "test")
self.assertEqual(yval["C"].as_float, 42.0)
self.assertEqual(ydefval["C"].as_float, 42.0)

with self.assertRaisesRegex(KeyError, "Key should have either type 'str' or 'int"):
yval[42.0] = 1

with self.assertRaisesRegex(KeyError, "Key should have either type 'str' or 'int"):
ydefval[42.0] = 1

def test_dict_false(self):
val = 1

yval, ydefval = gen_wrappers(val)

self.assertFalse(yval.is_list)
self.assertFalse(ydefval.is_list)

with self.assertRaisesRegex(YAMLProcessingError, "Mapping node is expected"):
yval["A"]

with self.assertRaisesRegex(TypeError, "Expected dict value"):
ydefval["A"]
Loading