Skip to content

Commit

Permalink
Add some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel Tom committed Jun 16, 2024
1 parent 4e284b6 commit a01370b
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 14 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ __pycache__
*.iml

.coverage
.python-version
.python-version
coverage.xml
6 changes: 0 additions & 6 deletions src/pydantic_avro/from_avro/class_registery.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ def add_class(self, name: str, class_def: str):
"""Add a class to the registry."""
self._classes[name] = class_def

def get_class(self, name: str) -> str:
"""Get a class from the registry."""
if name not in self._classes:
raise KeyError(f"Class {name} not found in registry")
return self._classes[name]

@property
def classes(self) -> dict:
"""Get all classes in the registry."""
Expand Down
Empty file.
10 changes: 3 additions & 7 deletions src/pydantic_avro/from_avro/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,8 @@ def list_type_handler(t: dict) -> str:

def map_type_handler(t: dict) -> str:
"""Get the Python type of a given Avro map type"""
if isinstance(t["type"], dict):
avro_value_type = t["type"].get("values")
else:
avro_value_type = t.get("values")

avro_value_type = t["type"].get("values")

if avro_value_type is None:
raise AttributeError("Values are required for map type")
Expand All @@ -53,9 +51,7 @@ def map_type_handler(t: dict) -> str:

def logical_type_handler(t: dict) -> str:
"""Get the Python type of a given Avro logical type"""
if isinstance(t["type"], dict):
return LOGICAL_TYPES[t["type"]["logicalType"]]
return LOGICAL_TYPES[t["logicalType"]]
return LOGICAL_TYPES[t["type"]["logicalType"]]


def enum_type_handler(t: dict) -> str:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_from_avro.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,19 @@ def test_avsc_to_pydantic_map():
assert "class Test(BaseModel):\n" " col1: Dict[str, str]" in pydantic_code


def test_avsc_to_pydantic_map_missing_values():
with pytest.raises(AttributeError, match="Values are required for map type"):
avsc_to_pydantic(
{
"name": "Test",
"type": "record",
"fields": [
{"name": "col1", "type": {"type": "map", "values": None, "default": {}}},
],
}
)


def test_avsc_to_pydantic_map_nested_object():
pydantic_code = avsc_to_pydantic(
{
Expand Down
26 changes: 26 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import runpy
from unittest.mock import MagicMock, patch

from pydantic_avro import __main__ as main_module


@patch("pydantic_avro.__main__.convert_file")
def test_main_avro_to_pydantic(mock_convert_file):
# Call the main function with test arguments
test_args = ["avro_to_pydantic", "--asvc", "test.avsc", "--output", "output.py"]
main_module.main(test_args)

# Assert that convert_file was called with the correct arguments
mock_convert_file.assert_called_once_with("test.avsc", "output.py")


@patch.object(main_module, "main")
@patch.object(
main_module.sys, "argv", ["__main__.py", "avro_to_pydantic", "--asvc", "test.avsc", "--output", "output.py"]
)
def test_root_main(mock_main):
# Call the root_main function
main_module.root_main()

# Assert that main was called with the correct arguments
mock_main.assert_called_once_with(["avro_to_pydantic", "--asvc", "test.avsc", "--output", "output.py"])

0 comments on commit a01370b

Please sign in to comment.