diff --git a/.gitignore b/.gitignore index e4bffd81..bbdff5c2 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,4 @@ dmypy.json *_files/ *.html .idea/ +drafts/ diff --git a/magic_example.ipynb b/magic_example.ipynb index d2603bc9..07b759cc 100644 --- a/magic_example.ipynb +++ b/magic_example.ipynb @@ -23,6 +23,7 @@ "# or %%ipytest test_module_name\n", "\n", "def solution_power2(x: int) -> int:\n", + " print(\"hellooo!\")\n", " return x * 2" ] }, @@ -50,7 +51,26 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "%%ipytest async magic_example \n", + "\n", + "async def solution_async() -> int:\n", + " print(\"running\")\n", + " return 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%ipytest debug magic_example \n", + "\n", + "def solution_debug() -> int:\n", + " print(\"running\")\n", + " return 3" + ] } ], "metadata": { @@ -69,7 +89,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.10" }, "vscode": { "interpreter": { diff --git a/tutorial/tests/test_magic_example.py b/tutorial/tests/test_magic_example.py index d99cfeb5..45f2abab 100644 --- a/tutorial/tests/test_magic_example.py +++ b/tutorial/tests/test_magic_example.py @@ -1,3 +1,5 @@ +import asyncio + import pytest @@ -48,6 +50,12 @@ def reference_power4(num: int) -> int: return num**4 +def test_power2_raise(function_to_test): + """The test case(s)""" + with pytest.raises(TypeError): + function_to_test("a") + + input_args = [1, 2, 3, 4, 32] @@ -67,3 +75,27 @@ def test_power3(input_arg, function_to_test): def test_power4(input_arg, function_to_test): """The test case(s)""" assert function_to_test(input_arg) == reference_power4(input_arg) + + +async def reference_async() -> int: + await asyncio.sleep(1) + return 1 + + +def test_async(function_to_test): + async def async_test(): + return 1 + + result = asyncio.run(async_test()) + user_result = asyncio.run(function_to_test()) + assert result == user_result + + +def reference_debug() -> int: + print("I print here") + return 1 + + +def test_debug(function_to_test): + print("I print here") + assert function_to_test() == 1 diff --git a/tutorial/tests/testsuite.py b/tutorial/tests/testsuite.py deleted file mode 100644 index e1c5751c..00000000 --- a/tutorial/tests/testsuite.py +++ /dev/null @@ -1,183 +0,0 @@ -"""A module to define the `%%ipytest` cell magic""" -import io -import pathlib -import re -from contextlib import redirect_stderr, redirect_stdout -from typing import Dict, Optional - -import ipynbname -import pytest -from IPython.core.display import Javascript -from IPython.core.getipython import get_ipython -from IPython.core.interactiveshell import InteractiveShell -from IPython.core.magic import Magics, cell_magic, magics_class -from IPython.display import display - -from .testsuite_helpers import ( - AstParser, - FunctionInjectionPlugin, - FunctionNotFoundError, - InstanceNotFoundError, - ResultCollector, - TestResultOutput, -) - - -def _name_from_line(line: str = ""): - return line.strip().removesuffix(".py") if line else None - - -def _name_from_ipynbname() -> str | None: - try: - return ipynbname.name() - except FileNotFoundError: - return None - - -def _name_from_globals(globals_dict: Dict) -> str | None: - """Find the name of the test module from the globals dictionary if working in VSCode""" - - module_path = globals_dict.get("__vsc_ipynb_file__") if globals_dict else None - return pathlib.Path(module_path).stem if module_path else None - - -def get_module_name(line: str, globals_dict: Dict) -> str: - """Fetch the test module name""" - - module_name = ( - _name_from_line(line) - or _name_from_ipynbname() - or _name_from_globals(globals_dict) - ) - - if not module_name: - raise ModuleNotFoundError(module_name) - - return module_name - - -@magics_class -class TestMagic(Magics): - """Class to add the test cell magic""" - - shell: Optional[InteractiveShell] # type: ignore - cells: Dict[str, int] = {} - - @cell_magic - def ipytest(self, line: str, cell: str): - """The `%%ipytest` cell magic""" - # Check that the magic is called from a notebook - if not self.shell: - raise InstanceNotFoundError("InteractiveShell") - - # Get the module containing the test(s) - module_name = get_module_name(line, self.shell.user_global_ns) - - # Check that the test module file exists - module_file = pathlib.Path(f"tutorial/tests/test_{module_name}.py") - if not module_file.exists(): - raise FileNotFoundError(f"Module file '{module_file}' does not exist") - - # Run the cell through IPython - result = self.shell.run_cell(cell) - - try: - result.raise_error() - - # Retrieve the functions names defined in the current cell - # Only functions with names starting with `solution_` will be candidates for tests - functions_names = re.findall(r"^def\s+(solution_.*?)\s*\(", cell, re.M) - - # Get the functions objects from user namespace - functions_to_run = {} - for name, function in self.shell.user_ns.items(): - if name in functions_names and callable(function): - functions_to_run[name.removeprefix("solution_")] = function - - if not functions_to_run: - raise FunctionNotFoundError - - # Store execution count information for each cell - if (ipython := get_ipython()) is None: - raise InstanceNotFoundError("IPython") - - cell_id = ipython.parent_header["metadata"]["cellId"] - if cell_id in self.cells: - self.cells[cell_id] += 1 - else: - self.cells[cell_id] = 1 - - # Parse the AST tree of the file containing the test functions, - # to extract and store all information of function definitions and imports - ast_parser = AstParser(module_file) - - outputs = [] - for name, function in functions_to_run.items(): - # Create the test collector - result_collector = ResultCollector() - # Run the tests - with redirect_stderr(io.StringIO()) as pytest_stderr, redirect_stdout( - io.StringIO() - ) as pytest_stdout: - result = pytest.main( - [ - "-q", - f"{module_file}::test_{name}", - ], - plugins=[ - FunctionInjectionPlugin(function), - result_collector, - ], - ) - # Read pytest output to prevent it from being displayed - pytest_stdout.getvalue() - pytest_stderr.getvalue() - - # reset execution count on success - success = result == pytest.ExitCode.OK - if success: - self.cells[cell_id] = 0 - - outputs.append( - TestResultOutput( - list(result_collector.tests.values()), - name, - False, - success, - self.cells[cell_id], - ast_parser.get_solution_code(name), - ) - ) - - display(*outputs) - - # hide cell outputs that were not generated by a function - display( - Javascript( - """ - var output_divs = document.querySelectorAll(".jp-OutputArea-executeResult"); - for (let div of output_divs) { - div.setAttribute("style", "display: none;"); - } - """ - ) - ) - - except SyntaxError: - # Catches syntax errors - display( - TestResultOutput( - syntax_error=True, - success=False, - ) - ) - - -def load_ipython_extension(ipython): - """ - Any module file that define a function named `load_ipython_extension` - can be loaded via `%load_ext module.path` or be configured to be - autoloaded by IPython at startup time. - """ - - ipython.register_magics(TestMagic) diff --git a/tutorial/tests/testsuite/__init__.py b/tutorial/tests/testsuite/__init__.py new file mode 100644 index 00000000..69479145 --- /dev/null +++ b/tutorial/tests/testsuite/__init__.py @@ -0,0 +1 @@ +from .testsuite import load_ipython_extension # noqa diff --git a/tutorial/tests/testsuite/ast_parser.py b/tutorial/tests/testsuite/ast_parser.py new file mode 100644 index 00000000..8e45c3f3 --- /dev/null +++ b/tutorial/tests/testsuite/ast_parser.py @@ -0,0 +1,92 @@ +import ast +import pathlib +from typing import Dict, Set + + +class AstParser: + """ + Helper class for extraction of function definitions and imports. + To find all reference solutions: + Parse the module file using the AST module and retrieve all function definitions and imports. + For each reference solution store the names of all other functions used inside of it. + """ + + def __init__(self, module_file: pathlib.Path) -> None: + self.module_file = module_file + self.function_defs = {} + self.function_imports = {} + self.called_function_names = {} + + tree = ast.parse(self.module_file.read_text(encoding="utf-8")) + + for node in tree.body: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + self.function_defs[node.name] = node + elif isinstance(node, (ast.Import, ast.ImportFrom)) and hasattr( + node, "module" + ): + for n in node.names: + self.function_imports[n.name] = node.module + + for node in tree.body: + if ( + node in self.function_defs.values() + and hasattr(node, "name") + and node.name.startswith("reference_") + ): + self.called_function_names[node.name] = self.retrieve_functions( + {**self.function_defs, **self.function_imports}, node, {node.name} + ) + + def retrieve_functions( + self, all_functions: Dict, node: object, called_functions: Set[object] + ) -> Set[object]: + """ + Recursively walk the AST tree to retrieve all function definitions in a file + """ + + if isinstance(node, ast.AST): + for n in ast.walk(node): + match n: + case ast.Call(ast.Name(id=name)): + called_functions.add(name) + if name in all_functions: + called_functions = self.retrieve_functions( + all_functions, all_functions[name], called_functions + ) + for child in ast.iter_child_nodes(n): + called_functions = self.retrieve_functions( + all_functions, child, called_functions + ) + + return called_functions + + def get_solution_code(self, name: str) -> str: + """ + Find the respective reference solution for the executed function. + Create a str containing its code and the code of all other functions used, + whether coming from the same file or an imported one. + """ + + solution_functions = self.called_function_names[f"reference_{name}"] + solution_code = "" + + for f in solution_functions: + if f in self.function_defs: + solution_code += ast.unparse(self.function_defs[f]) + "\n\n" + elif f in self.function_imports: + function_file = pathlib.Path( + f"{self.function_imports[f].replace('.', '/')}.py" + ) + if function_file.exists(): + function_file_tree = ast.parse( + function_file.read_text(encoding="utf-8") + ) + for node in function_file_tree.body: + if ( + isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + and node.name == f + ): + solution_code += ast.unparse(node) + "\n\n" + + return solution_code diff --git a/tutorial/tests/testsuite/exceptions.py b/tutorial/tests/testsuite/exceptions.py new file mode 100644 index 00000000..11e2bfc1 --- /dev/null +++ b/tutorial/tests/testsuite/exceptions.py @@ -0,0 +1,19 @@ +class FunctionNotFoundError(Exception): + """Custom exception raised when the solution code cannot be parsed""" + + def __init__(self) -> None: + super().__init__("No functions to test defined in the cell") + + +class InstanceNotFoundError(Exception): + """Custom exception raised when an instance cannot be found""" + + def __init__(self, name: str) -> None: + super().__init__(f"Could not get {name} instance") + + +class TestModuleNotFoundError(Exception): + """Custom exception raised when the test module cannot be found""" + + def __init__(self) -> None: + super().__init__("Test module is not defined") diff --git a/tutorial/tests/testsuite/helpers.py b/tutorial/tests/testsuite/helpers.py new file mode 100644 index 00000000..b4eec49f --- /dev/null +++ b/tutorial/tests/testsuite/helpers.py @@ -0,0 +1,381 @@ +import html +import re +import traceback +from dataclasses import dataclass +from enum import Enum +from types import TracebackType +from typing import Callable, ClassVar, Dict, List, Optional + +import ipywidgets +import pytest +from IPython.display import Code +from IPython.display import display as ipython_display +from ipywidgets import HTML + + +class TestOutcome(Enum): + PASS = 1 + FAIL = 2 + TEST_ERROR = 3 + + +class IPytestOutcome(Enum): + FINISHED = 0 + COMPILE_ERROR = 1 + SOLUTION_FUNCTION_MISSING = 2 + NO_TEST_FOUND = 3 + PYTEST_ERROR = 4 + UNKNOWN_ERROR = 5 + + +@dataclass +class TestCaseResult: + """Container class to store the test results when we collect them""" + + test_name: str + outcome: TestOutcome + exception: BaseException | None + traceback: TracebackType | None + stdout: str = "" + stderr: str = "" + + +@dataclass +class IPytestResult: + function_name: Optional[str] = None + status: Optional[IPytestOutcome] = None + test_results: Optional[List[TestCaseResult]] = None + exceptions: Optional[List[BaseException]] = None + test_attempts: int = 0 + + +def format_error(exception: BaseException) -> str: + """ + Takes the output of traceback.format_exception_only() for an AssertionError + and returns a formatted string with clear, structured information. + """ + formatted_message = None + + # Get a string representation of the exception, without the traceback + exception_str = "".join(traceback.format_exception_only(exception)) + + # Handle the case where we were expecting an exception but none was raised + if "DID NOT RAISE" in exception_str: + pattern = r"" + match = re.search(pattern, exception_str) + + if match: + formatted_message = ( + "

Expected exception:

" + f"

Exception {html.escape(match.group(1))} was not raised.

" + ) + else: + # Regex pattern to extract relevant parts of the assertion message + pattern = ( + r"(\w+): assert (.*?) == (.*?)\n \+ where .*? = (.*?)\n \+ and .*? = (.*)" + ) + match = re.search(pattern, exception_str) + + if match: + ( + assertion_type, + actual_value, + expected_value, + actual_expression, + expected_expression, + ) = (html.escape(m) for m in match.groups()) + + # Formatting the output as HTML + formatted_message = ( + f"

{assertion_type}:

" + "" + ) + + # If we couldn't parse the exception message, just display it as is + formatted_message = formatted_message or f"

{html.escape(exception_str)}

" + + return formatted_message + + +@dataclass +class TestResultOutput: + """Class to prepare and display test results in a Jupyter notebook""" + + ipytest_result: IPytestResult + solution: Optional[str] = None + MAX_ATTEMPTS: ClassVar[int] = 3 + + def display_results(self) -> None: + """Display the test results in an output widget as a VBox""" + cells = [] + + output_cell = self.prepare_output_cell() + solution_cell = self.prepare_solution_cell() + + cells.append(output_cell) + + tests_finished = self.ipytest_result.status == IPytestOutcome.FINISHED + success = ( + all( + test.outcome == TestOutcome.PASS + for test in self.ipytest_result.test_results + ) + if self.ipytest_result.test_results + else False + ) + + if success or self.ipytest_result.test_attempts > 2: + cells.append(solution_cell) + else: + if tests_finished: + cells.append( + HTML( + "

📝 A proposed solution will appear after " + f"{TestResultOutput.MAX_ATTEMPTS - self.ipytest_result.test_attempts} " + f"more failed attempt{'s' if self.ipytest_result.test_attempts < 2 else ''}.

", + ) + ) + else: + cells.append( + HTML( + "

⚠️ Your code could not run because of an error. Please, double-check it.

" + ) + ) + + ipython_display( + ipywidgets.VBox( + children=cells, + # CSS: "border: 1px solid; border-color: lightgray; background-color: #FAFAFA; margin: 5px; padding: 10px;" + layout={ + "border": "1px solid lightgray", + "background-color": "#FAFAFA", + "margin": "5px", + "padding": "10px", + }, + ) + ) + + def prepare_solution_cell(self) -> ipywidgets.Widget: + """Prepare the cell to display the solution code""" + solution_code = ipywidgets.Output() + solution_cell = ipywidgets.Output() + + solution_cell.append_display_data(HTML("

👉 Proposed solution:

")) + + solution_code.append_display_data( + Code(language="python", data=f"{self.solution}") + ) + + solution_accordion = ipywidgets.Accordion( + titles=("Click here to reveal",), children=[solution_code] + ) + + solution_cell.append_display_data(ipywidgets.Box(children=[solution_accordion])) + + return solution_cell + + def prepare_output_cell(self) -> ipywidgets.Output: + """Prepare the cell to display the test results""" + output_cell = ipywidgets.Output() + output_cell.append_display_data( + HTML( + f'

Test Results for solution_{self.ipytest_result.function_name}

' + ) + ) + + match self.ipytest_result.status: + case IPytestOutcome.COMPILE_ERROR | IPytestOutcome.PYTEST_ERROR | IPytestOutcome.UNKNOWN_ERROR: + # We know that there is exactly one exception + assert self.ipytest_result.exceptions is not None + exception = self.ipytest_result.exceptions[0] + exceptions_str = ( + format_error(exception) if self.ipytest_result.exceptions else "" + ) + output_cell.append_display_data( + ipywidgets.VBox( + children=[ + HTML(f"

{type(exception).__name__}

"), + HTML(exceptions_str), + ] + ) + ) + + case IPytestOutcome.SOLUTION_FUNCTION_MISSING: + output_cell.append_display_data( + HTML("

Solution Function Missing

") + ) + + case IPytestOutcome.FINISHED if self.ipytest_result.test_results: + captures: Dict[str, Dict[str, str]] = {} + + for test in self.ipytest_result.test_results: + captures[test.test_name.split("::")[-1]] = { + "stdout": test.stdout, + "stderr": test.stderr, + } + + # Create lists of HTML outs and errs + outs = [ + f"

{test_name}


{captures[test_name]['stdout']}" + for test_name in captures + if captures[test_name]["stdout"] + ] + errs = [ + f"

{test_name}


{captures[test_name]['stderr']}" + for test_name in captures + if captures[test_name]["stderr"] + ] + + output_cell.append_display_data( + ipywidgets.VBox( + children=( + ipywidgets.Accordion( + children=( + ipywidgets.VBox( + children=[ + HTML(o, style={"background": "#FAFAFA"}) + for o in outs + ] + ), + ), + titles=("Captured output",), + ), + ipywidgets.Accordion( + children=( + ipywidgets.VBox( + children=[ + HTML(e, style={"background": "#FAFAFA"}) + for e in errs + ] + ), + ), + titles=("Captured error",), + ), + ) + ) + ) + + success = all( + test.outcome == TestOutcome.PASS + for test in self.ipytest_result.test_results + ) + + num_results = len(self.ipytest_result.test_results) + + output_cell.append_display_data( + HTML( + f"

👉 We ran {num_results} test{'s' if num_results > 1 else ''}. " + f"""{"All tests passed!

" if success else "Below you find the details for each test run:"}""" + ) + ) + + if not success: + for result in self.ipytest_result.test_results: + test_succeded = result.outcome == TestOutcome.PASS + test_name = result.test_name.split("::")[-1] + + output_box_children: List[ipywidgets.Widget] = [ + HTML( + f'

{"✔" if test_succeded else "❌"} Test {test_name}

', + style={ + "background": "rgba(251, 59, 59, 0.25)" + if not test_succeded + else "rgba(207, 249, 179, 0.60)" + }, + ) + ] + + if not test_succeded: + assert result.exception is not None + + output_box_children.append( + ipywidgets.Accordion( + children=[HTML(format_error(result.exception))], + titles=("Test results",), + ) + ) + + output_cell.append_display_data( + ipywidgets.VBox(children=output_box_children) + ) + + case IPytestOutcome.NO_TEST_FOUND: + output_cell.append_display_data(HTML("

No Test Found

")) + + return output_cell + + +@pytest.fixture +def function_to_test(): + """Function to test, overridden at runtime by the cell magic""" + + +class FunctionInjectionPlugin: + """A class to inject a function to test""" + + def __init__(self, function_to_test: Callable) -> None: + self.function_to_test = function_to_test + + def pytest_generate_tests(self, metafunc: pytest.Metafunc) -> None: + # Override the abstract `function_to_test` fixture function + if "function_to_test" in metafunc.fixturenames: + metafunc.parametrize("function_to_test", [self.function_to_test]) + + +class ResultCollector: + """A class that will collect the result of a test. If behaves a bit like a visitor pattern""" + + def __init__(self) -> None: + self.tests: Dict[str, TestCaseResult] = {} + + def pytest_runtest_makereport(self, item: pytest.Item, call: pytest.CallInfo): + """Called when an individual test item has finished execution.""" + if call.when == "call": + if call.excinfo is None: + # Test passes + self.tests[item.nodeid] = TestCaseResult( + test_name=item.nodeid, + outcome=TestOutcome.PASS, + stdout=call.result, + stderr=call.result, + exception=None, + traceback=None, + ) + else: + # Test fails + self.tests[item.nodeid] = TestCaseResult( + test_name=item.nodeid, + outcome=TestOutcome.FAIL, + exception=call.excinfo.value, + traceback=call.excinfo.tb, + ) + + def pytest_exception_interact( + self, call: pytest.CallInfo, report: pytest.TestReport + ): + """Called when an exception was raised which can potentially be interactively handled.""" + if (exc := call.excinfo) is not None: + # TODO: extract a stack summary from the traceback to inspect if the function to test raise an exception + # print([frame.name for frame in traceback.extract_tb(exc.tb)]) + # If something else than the test_* name is in that list, then we have a solution function that raised an exception + outcome = ( + TestOutcome.FAIL + if exc.errisinstance(AssertionError) + else TestOutcome.TEST_ERROR + ) + self.tests[report.nodeid] = TestCaseResult( + test_name=report.nodeid, + outcome=outcome, + exception=exc.value, + traceback=exc.tb, + ) + + def pytest_runtest_logreport(self, report: pytest.TestReport): + """Called to log the report of a test item.""" + if test_result := self.tests.get(report.nodeid): + test_result.stdout = report.capstdout + test_result.stderr = report.capstderr diff --git a/tutorial/tests/testsuite/testsuite.py b/tutorial/tests/testsuite/testsuite.py new file mode 100644 index 00000000..eedbdc46 --- /dev/null +++ b/tutorial/tests/testsuite/testsuite.py @@ -0,0 +1,302 @@ +"""A module to define the `%%ipytest` cell magic""" +import dataclasses +import inspect +import io +import pathlib +import re +from collections import defaultdict +from contextlib import redirect_stderr, redirect_stdout +from queue import Queue +from threading import Thread +from typing import Callable, Dict, List, Optional + +import ipynbname +import pytest +from IPython.core.interactiveshell import InteractiveShell +from IPython.core.magic import Magics, cell_magic, magics_class + +from .ast_parser import AstParser +from .exceptions import ( + FunctionNotFoundError, + InstanceNotFoundError, + TestModuleNotFoundError, +) +from .helpers import ( + FunctionInjectionPlugin, + IPytestOutcome, + IPytestResult, + ResultCollector, + TestOutcome, + TestResultOutput, +) + + +def run_test( + module_file: pathlib.Path, function_name: str, function_object: Callable +) -> IPytestResult: + """ + Run the tests for a single function + """ + with redirect_stdout(io.StringIO()) as _, redirect_stderr(io.StringIO()) as _: + # Create the test collector + result_collector = ResultCollector() + + # Run the tests + result = pytest.main( + ["-k", f"test_{function_name}", f"{module_file}"], + plugins=[ + FunctionInjectionPlugin(function_object), + result_collector, + ], + ) + + match result: + case pytest.ExitCode.OK: + return IPytestResult( + function_name=function_name, + status=IPytestOutcome.FINISHED, + test_results=list(result_collector.tests.values()), + ) + case pytest.ExitCode.TESTS_FAILED: + if any( + test.outcome == TestOutcome.TEST_ERROR + for test in result_collector.tests.values() + ): + return IPytestResult( + function_name=function_name, + status=IPytestOutcome.PYTEST_ERROR, + exceptions=[ + test.exception + for test in result_collector.tests.values() + if test.exception + ], + ) + + return IPytestResult( + function_name=function_name, + status=IPytestOutcome.FINISHED, + test_results=list(result_collector.tests.values()), + ) + case pytest.ExitCode.INTERNAL_ERROR: + return IPytestResult( + function_name=function_name, + status=IPytestOutcome.PYTEST_ERROR, + exceptions=[Exception("Internal error")], + ) + case pytest.ExitCode.NO_TESTS_COLLECTED: + return IPytestResult( + function_name=function_name, + status=IPytestOutcome.NO_TEST_FOUND, + exceptions=[FunctionNotFoundError()], + ) + + return IPytestResult( + status=IPytestOutcome.UNKNOWN_ERROR, exceptions=[Exception("Unknown error")] + ) + + +def run_test_in_thread( + module_file: pathlib.Path, + function_name: str, + function_object: Callable, + test_queue: Queue, +): + """Run the tests for a single function and put the result in the queue""" + test_queue.put(run_test(module_file, function_name, function_object)) + + +def _name_from_line(line: str = ""): + return line.strip().removesuffix(".py") if line else None + + +def _name_from_ipynbname() -> str | None: + try: + return ipynbname.name() + except FileNotFoundError: + return None + + +def _name_from_globals(globals_dict: Dict) -> str | None: + """Find the name of the test module from the globals dictionary if working in VSCode""" + + module_path = globals_dict.get("__vsc_ipynb_file__") if globals_dict else None + return pathlib.Path(module_path).stem if module_path else None + + +def get_module_name(line: str, globals_dict: Dict) -> str | None: + """Fetch the test module name""" + + module_name = ( + _name_from_line(line) + or _name_from_ipynbname() + or _name_from_globals(globals_dict) + ) + + return module_name + + +@magics_class +class TestMagic(Magics): + """Class to add the test cell magic""" + + def __init__(self, shell): + super().__init__(shell) + self.max_execution_count = 3 + self.shell: InteractiveShell = shell + self.cell: str = "" + self.module_file: Optional[pathlib.Path] = None + self.module_name: Optional[str] = None + self.threaded: Optional[bool] = None + self.test_queue: Optional[Queue[IPytestResult]] = None + self.cell_execution_count: Dict[str, Dict[str, int]] = defaultdict( + lambda: defaultdict(int) + ) + self._orig_traceback = self.shell._showtraceback # type: ignore + # This is monkey-patching suppress printing any exception or traceback + + def extract_functions_to_test(self) -> Dict[str, Callable]: + """""" + # Retrieve the functions names defined in the current cell + # Only functions with names starting with `solution_` will be candidates for tests + functions_names: List[str] = re.findall( + r"^(?:async\s+?)?def\s+(solution_.*?)\s*\(", self.cell, re.M + ) + + return { + name.removeprefix("solution_"): function + for name, function in self.shell.user_ns.items() + if name in functions_names + and (callable(function) or inspect.iscoroutinefunction(function)) + } + + def run_test(self, function_name: str, function_object: Callable) -> IPytestResult: + """Run the tests for a single function""" + assert isinstance(self.module_file, pathlib.Path) + + # Store execution count information for each cell + cell_id = str(self.shell.parent_header["metadata"]["cellId"]) # type: ignore + self.cell_execution_count[cell_id][function_name] += 1 + + # Run the tests on a separate thread + if self.threaded: + assert isinstance(self.test_queue, Queue) + thread = Thread( + target=run_test_in_thread, + args=( + self.module_file, + function_name, + function_object, + self.test_queue, + ), + ) + thread.start() + thread.join() + result = self.test_queue.get() + else: + result = run_test(self.module_file, function_name, function_object) + + match result.status: + case IPytestOutcome.FINISHED: + return dataclasses.replace( + result, + test_attempts=self.cell_execution_count[cell_id][function_name], + ) + case _: + return result + + def run_cell(self) -> List[IPytestResult]: + # Run the cell through IPython + try: + result = self.shell.run_cell(self.cell, silent=True) # type: ignore + result.raise_error() + except Exception as err: + return [ + IPytestResult( + status=IPytestOutcome.COMPILE_ERROR, + exceptions=[err], + ) + ] + + functions_to_run = self.extract_functions_to_test() + + if not functions_to_run: + return [ + IPytestResult( + status=IPytestOutcome.SOLUTION_FUNCTION_MISSING, + exceptions=[FunctionNotFoundError()], + ) + ] + + # Run the tests for each function + test_results = [ + self.run_test(name, function) for name, function in functions_to_run.items() + ] + + return test_results + + @cell_magic + def ipytest(self, line: str, cell: str): + """The `%%ipytest` cell magic""" + # Check that the magic is called from a notebook + if not self.shell: + raise InstanceNotFoundError("InteractiveShell") + + # Store the cell content + self.cell = cell + line_contents = set(line.split()) + + # Check if we need to run the tests on a separate thread + if "async" in line_contents: + line_contents.remove("async") + self.threaded = True + self.test_queue = Queue() + + # If debug is in the line, then we want to show the traceback + if "debug" in line_contents: + line_contents.remove("debug") + self.shell._showtraceback = self._orig_traceback + else: + self.shell._showtraceback = lambda *args, **kwargs: None + + # Get the module containing the test(s) + if ( + module_name := get_module_name( + " ".join(line_contents), self.shell.user_global_ns + ) + ) is None: + raise TestModuleNotFoundError + + self.module_name = module_name + + # Check that the test module file exists + if not ( + module_file := pathlib.Path(f"tutorial/tests/test_{self.module_name}.py") + ).exists(): + raise FileNotFoundError(f"Module file '{module_file}' does not exist") + + self.module_file = module_file + + # Run the cell + results = self.run_cell() + + # Parse the AST of the test module to retrieve the solution code + ast_parser = AstParser(self.module_file) + + # Display the test results and the solution code + for result in results: + solution = ( + ast_parser.get_solution_code(result.function_name) + if result.function_name + else None + ) + TestResultOutput(result, solution).display_results() + + +def load_ipython_extension(ipython): + """ + Any module file that define a function named `load_ipython_extension` + can be loaded via `%load_ext module.path` or be configured to be + autoloaded by IPython at startup time. + """ + + ipython.register_magics(TestMagic) diff --git a/tutorial/tests/testsuite_helpers.py b/tutorial/tests/testsuite_helpers.py deleted file mode 100644 index 7e52cd36..00000000 --- a/tutorial/tests/testsuite_helpers.py +++ /dev/null @@ -1,380 +0,0 @@ -import ast -import pathlib -import re -from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Set - -import ipywidgets -import pytest -from IPython.core.display import HTML, Javascript -from IPython.display import Code, display -from nbconvert import filters - - -@dataclass -class TestResult: - """Container class to store the test results when we collect them""" - - stdout: str - stderr: str - test_name: str - success: bool - - -@dataclass -class OutputConfig: - """Container class to store the information to display in the test output""" - - style: str - name: str - result: str - - -def format_success_failure( - syntax_error: bool, success: bool, name: str -) -> OutputConfig: - """ - Depending on the test results, returns a fragment that represents - either an error message, a success message, or a syntax error warning - """ - - if syntax_error: - return OutputConfig( - "alert-warning", - "Tests COULD NOT RUN for this cell.", - "🤔 Careful, looks like you have a syntax error.", - ) - - if not success: - return OutputConfig( - "alert-danger", - f"Tests FAILED for the function {name}", - "😱 Your solution was not correct!", - ) - - return OutputConfig( - "alert-success", - f"Tests PASSED for the function {name}", - "🙌 Congratulations, your solution was correct!", - ) - - -def format_long_stdout(text: str) -> str: - """ - Format the error message lines of a long test stdout - as an HTML that expands, by using the
element - """ - - stdout_body = re.split(r"_\s{3,}", text)[-1] - stdout_filtered = list( - filter(re.compile(r".*>E\s").match, stdout_body.splitlines()) - ) - stdout_str = "".join(f"

{line}

" for line in stdout_filtered) - stdout_edited = re.sub(r"E\s+[\+\s]*", "", stdout_str) - stdout_edited = re.sub( - r"\bfunction\ssolution_[\w\s\d]*", "your_solution", stdout_edited - ) - stdout_edited = re.sub(r"\breference_\w+\(", "reference_solution(", stdout_edited) - - test_runs = f""" -
- Click here to expand -
{stdout_edited}
-
- """ - return test_runs - - -class TestResultOutput(ipywidgets.VBox): - """Class to display the test results in a structured way""" - - def __init__( - self, - test_outputs: Optional[List[TestResult]] = None, - name: str = "", - syntax_error: bool = False, - success: bool = False, - cell_exec_count: int = 0, - solution_body: str = "", - ): - reveal_solution = cell_exec_count > 2 or success - output_config = format_success_failure(syntax_error, success, name) - output_cell = ipywidgets.Output() - - # For each test, create an alert box with the appropriate message, - # print the code output and display code errors in case of failure - with output_cell: - custom_div_style = '"border: 1px solid; border-color: lightgray; background-color: #FAFAFA; margin: 5px; padding: 10px;"' - display(HTML("

Test results

")) - display( - HTML( - f"""

{output_config.name}

{output_config.result}
""" - ) - ) - - if not syntax_error and isinstance(test_outputs, List): - if len(test_outputs) > 0 and test_outputs[0].stdout: - display( - HTML( - f""" -

👉 Code output:

-
{test_outputs[0].stdout}
- """ - ) - ) - - display( - HTML( - f""" -

👉 We tested your solution solution_{name} with {'1 input' if len(test_outputs) == 1 else str(len(test_outputs)) + ' different inputs'}. - {"All tests passed!

" if success else "Below you find the details for each test run:"} - """ - ) - ) - - if not success: - for test in test_outputs: - test_name = test.test_name - if match := re.search(r"\[.*?\]", test_name): - test_name = re.sub(r"\[|\]", "", match.group()) - - display( - HTML( - f""" -
-
{"✔" if test.success else "❌"} Test {test_name}
- {format_long_stdout(filters.ansi.ansi2html(test.stderr)) if not test.success else ""} -
- """ - ) - ) - - if not reveal_solution: - display( - HTML( - f"

📝 A proposed solution will appear after {3 - cell_exec_count} more failed attempt{'s' if cell_exec_count < 2 else ''}.

" - ) - ) - else: - # display syntax error custom alert - display( - HTML( - "

👉 Your code cannot run because of the following error:

" - ) - ) - - # fix syntax error styling - display( - Javascript( - """ - var syntax_error_containers = document.querySelectorAll('div[data-mime-type="application/vnd.jupyter.stderr"]'); - for (let container of syntax_error_containers) { - var syntax_error_div = container.parentNode; - var container_div = syntax_error_div.parentNode; - const container_style = "position: relative; padding-bottom: " + syntax_error_div.clientHeight + "px;"; - container_div.setAttribute("style", container_style); - syntax_error_div.setAttribute("style", "position: absolute; bottom: 0;"); - } - """ - ) - ) - - # fix css styling - display( - Javascript( - """ - var divs = document.querySelectorAll(".jupyter-widget-Collapse-contents"); - for (let div of divs) { - div.setAttribute("style", "padding: 0"); - } - divs = document.querySelectorAll(".widget-vbox"); - for (let div of divs) { - div.setAttribute("style", "background: #EAF0FB"); - } - """ - ) - ) - - display( - Javascript( - """ - var output_divs = document.querySelectorAll(".jp-Cell-outputArea"); - for (let div of output_divs) { - var div_str = String(div.innerHTML); - if (div_str.includes("alert-success") | div_str.includes("alert-danger")) { - div.setAttribute("style", "padding-bottom: 0;"); - } - } - """ - ) - ) - - # After 3 failed attempts or on success, reveal the proposed solution - # using a Code box inside an Accordion to display the str containing all code - solution_output = ipywidgets.Output() - with solution_output: - display(HTML("

👉 Proposed solution:

")) - - solution_code = ipywidgets.Output() - with solution_code: - display(Code(language="python", data=f"{solution_body}")) - - solution_accordion = ipywidgets.Accordion( - titles=("Click here to reveal",), children=[solution_code] - ) - - solution_box = ipywidgets.Box( - children=[solution_output, solution_accordion], - layout={ - "display": "block" if reveal_solution else "none", - "padding": "0 20px 0 0", - }, - ) - - super().__init__(children=[output_cell, solution_box]) - - -@pytest.fixture -def function_to_test(): - """Function to test, overridden at runtime by the cell magic""" - - -class FunctionInjectionPlugin: - """A class to inject a function to test""" - - def __init__(self, function_to_test: Callable) -> None: - self.function_to_test = function_to_test - - def pytest_generate_tests(self, metafunc: pytest.Metafunc) -> None: - # Override the abstract `function_to_test` fixture function - if "function_to_test" in metafunc.fixturenames: - metafunc.parametrize("function_to_test", [self.function_to_test]) - - -class ResultCollector: - """A class that will collect the result of a test. If behaves a bit like a visitor pattern""" - - def __init__(self) -> None: - self.tests: Dict[str, TestResult] = {} - - def pytest_runtest_logreport(self, report: pytest.TestReport): - # Only collect the results if it did not fail - if report.when == "teardown" and report.nodeid not in self.tests: - self.tests[report.nodeid] = TestResult( - report.capstdout, report.capstderr, report.nodeid, not report.failed - ) - - def pytest_exception_interact( - self, node: pytest.Item, call: pytest.CallInfo, report: pytest.TestReport - ): - # We need to collect the results and the stderr if the test failed - if report.failed: - self.tests[node.nodeid] = TestResult( - report.capstdout, - str(call.excinfo.getrepr() if call.excinfo else ""), - report.nodeid, - False, - ) - - -class AstParser: - """ - Helper class for extraction of function definitions and imports. - To find all reference solutions: - Parse the module file using the AST module and retrieve all function definitions and imports. - For each reference solution store the names of all other functions used inside of it. - """ - - def __init__(self, module_file: pathlib.Path) -> None: - self.module_file = module_file - self.function_defs = {} - self.function_imports = {} - self.called_function_names = {} - - tree = ast.parse(self.module_file.read_text(encoding="utf-8")) - - for node in tree.body: - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - self.function_defs[node.name] = node - elif isinstance(node, (ast.Import, ast.ImportFrom)) and hasattr( - node, "module" - ): - for n in node.names: - self.function_imports[n.name] = node.module - - for node in tree.body: - if ( - node in self.function_defs.values() - and hasattr(node, "name") - and node.name.startswith("reference_") - ): - self.called_function_names[node.name] = self.retrieve_functions( - {**self.function_defs, **self.function_imports}, node, {node.name} - ) - - def retrieve_functions( - self, all_functions: Dict, node: object, called_functions: Set[object] - ) -> Set[object]: - """ - Recursively walk the AST tree to retrieve all function definitions in a file - """ - - if isinstance(node, ast.AST): - for n in ast.walk(node): - match n: - case ast.Call(ast.Name(id=name)): - called_functions.add(name) - if name in all_functions: - called_functions = self.retrieve_functions( - all_functions, all_functions[name], called_functions - ) - for child in ast.iter_child_nodes(n): - called_functions = self.retrieve_functions( - all_functions, child, called_functions - ) - - return called_functions - - def get_solution_code(self, name): - """ - Find the respective reference solution for the executed function. - Create a str containing its code and the code of all other functions used, - whether coming from the same file or an imported one. - """ - - solution_functions = self.called_function_names[f"reference_{name}"] - solution_code = "" - - for f in solution_functions: - if f in self.function_defs: - solution_code += ast.unparse(self.function_defs[f]) + "\n\n" - elif f in self.function_imports: - function_file = pathlib.Path( - f"{self.function_imports[f].replace('.', '/')}.py" - ) - if function_file.exists(): - function_file_tree = ast.parse( - function_file.read_text(encoding="utf-8") - ) - for node in function_file_tree.body: - if ( - isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) - and node.name == f - ): - solution_code += ast.unparse(node) + "\n\n" - - return solution_code - - -class FunctionNotFoundError(Exception): - """Custom exception raised when the solution code cannot be parsed""" - - def __init__(self) -> None: - super().__init__("No functions to test defined in the cell") - - -class InstanceNotFoundError(Exception): - """Custom exception raised when an instance cannot be found""" - - def __init__(self, name: str) -> None: - super().__init__(f"Could not get {name} instance")