Skip to content

Commit

Permalink
Improve test collection and output (#160)
Browse files Browse the repository at this point in the history
  • Loading branch information
edoardob90 authored Dec 8, 2023
1 parent 973c476 commit 4dc3a7e
Show file tree
Hide file tree
Showing 10 changed files with 850 additions and 565 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,4 @@ dmypy.json
*_files/
*.html
.idea/
drafts/
24 changes: 22 additions & 2 deletions magic_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"# or %%ipytest test_module_name\n",
"\n",
"def solution_power2(x: int) -> int:\n",
" print(\"hellooo!\")\n",
" return x * 2"
]
},
Expand Down Expand Up @@ -50,7 +51,26 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"%%ipytest async magic_example \n",
"\n",
"async def solution_async() -> int:\n",
" print(\"running\")\n",
" return 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%ipytest debug magic_example \n",
"\n",
"def solution_debug() -> int:\n",
" print(\"running\")\n",
" return 3"
]
}
],
"metadata": {
Expand All @@ -69,7 +89,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.10"
},
"vscode": {
"interpreter": {
Expand Down
32 changes: 32 additions & 0 deletions tutorial/tests/test_magic_example.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

import pytest


Expand Down Expand Up @@ -48,6 +50,12 @@ def reference_power4(num: int) -> int:
return num**4


def test_power2_raise(function_to_test):
"""The test case(s)"""
with pytest.raises(TypeError):
function_to_test("a")


input_args = [1, 2, 3, 4, 32]


Expand All @@ -67,3 +75,27 @@ def test_power3(input_arg, function_to_test):
def test_power4(input_arg, function_to_test):
"""The test case(s)"""
assert function_to_test(input_arg) == reference_power4(input_arg)


async def reference_async() -> int:
await asyncio.sleep(1)
return 1


def test_async(function_to_test):
async def async_test():
return 1

result = asyncio.run(async_test())
user_result = asyncio.run(function_to_test())
assert result == user_result


def reference_debug() -> int:
print("I print here")
return 1


def test_debug(function_to_test):
print("I print here")
assert function_to_test() == 1
183 changes: 0 additions & 183 deletions tutorial/tests/testsuite.py

This file was deleted.

1 change: 1 addition & 0 deletions tutorial/tests/testsuite/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .testsuite import load_ipython_extension # noqa
92 changes: 92 additions & 0 deletions tutorial/tests/testsuite/ast_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import ast
import pathlib
from typing import Dict, Set


class AstParser:
"""
Helper class for extraction of function definitions and imports.
To find all reference solutions:
Parse the module file using the AST module and retrieve all function definitions and imports.
For each reference solution store the names of all other functions used inside of it.
"""

def __init__(self, module_file: pathlib.Path) -> None:
self.module_file = module_file
self.function_defs = {}
self.function_imports = {}
self.called_function_names = {}

tree = ast.parse(self.module_file.read_text(encoding="utf-8"))

for node in tree.body:
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
self.function_defs[node.name] = node
elif isinstance(node, (ast.Import, ast.ImportFrom)) and hasattr(
node, "module"
):
for n in node.names:
self.function_imports[n.name] = node.module

for node in tree.body:
if (
node in self.function_defs.values()
and hasattr(node, "name")
and node.name.startswith("reference_")
):
self.called_function_names[node.name] = self.retrieve_functions(
{**self.function_defs, **self.function_imports}, node, {node.name}
)

def retrieve_functions(
self, all_functions: Dict, node: object, called_functions: Set[object]
) -> Set[object]:
"""
Recursively walk the AST tree to retrieve all function definitions in a file
"""

if isinstance(node, ast.AST):
for n in ast.walk(node):
match n:
case ast.Call(ast.Name(id=name)):
called_functions.add(name)
if name in all_functions:
called_functions = self.retrieve_functions(
all_functions, all_functions[name], called_functions
)
for child in ast.iter_child_nodes(n):
called_functions = self.retrieve_functions(
all_functions, child, called_functions
)

return called_functions

def get_solution_code(self, name: str) -> str:
"""
Find the respective reference solution for the executed function.
Create a str containing its code and the code of all other functions used,
whether coming from the same file or an imported one.
"""

solution_functions = self.called_function_names[f"reference_{name}"]
solution_code = ""

for f in solution_functions:
if f in self.function_defs:
solution_code += ast.unparse(self.function_defs[f]) + "\n\n"
elif f in self.function_imports:
function_file = pathlib.Path(
f"{self.function_imports[f].replace('.', '/')}.py"
)
if function_file.exists():
function_file_tree = ast.parse(
function_file.read_text(encoding="utf-8")
)
for node in function_file_tree.body:
if (
isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
and node.name == f
):
solution_code += ast.unparse(node) + "\n\n"

return solution_code
19 changes: 19 additions & 0 deletions tutorial/tests/testsuite/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
class FunctionNotFoundError(Exception):
"""Custom exception raised when the solution code cannot be parsed"""

def __init__(self) -> None:
super().__init__("No functions to test defined in the cell")


class InstanceNotFoundError(Exception):
"""Custom exception raised when an instance cannot be found"""

def __init__(self, name: str) -> None:
super().__init__(f"Could not get {name} instance")


class TestModuleNotFoundError(Exception):
"""Custom exception raised when the test module cannot be found"""

def __init__(self) -> None:
super().__init__("Test module is not defined")
Loading

0 comments on commit 4dc3a7e

Please sign in to comment.