diff --git a/databooks/affirm.py b/databooks/affirm.py
index d317d433..726c0712 100644
--- a/databooks/affirm.py
+++ b/databooks/affirm.py
@@ -1,7 +1,7 @@
 """Functions to safely evaluate strings and inspect notebook."""
 import ast
+from collections import abc
 from copy import deepcopy
-from functools import reduce
 from itertools import compress
 from pathlib import Path
 from typing import Any, Callable, Dict, Iterable, List, Tuple
@@ -99,6 +99,26 @@ def _prioritize(field: Tuple[str, Any]) -> bool:
             return True
         return not any(isinstance(f, ast.comprehension) for f in value)
 
+    @staticmethod
+    def _allowed_attr(obj: Any, attr: str, is_dynamic: bool = False) -> None:
+        """
+        Raise `ValueError` unless `attr` is a key of `obj`, a `DatabooksBase`.
+
+        A dynamic iterable (bound by a comprehension rather than the initial scope) is
+        checked element by element, each element as a non-dynamic object.
+        """
+        allowed_attrs = list(dict(obj)) if isinstance(obj, DatabooksBase) else ()
+        if is_dynamic and isinstance(obj, abc.Iterable):
+            for item in obj:
+                DatabooksParser._allowed_attr(obj=item, attr=attr)
+            return
+        if attr in allowed_attrs:
+            return
+        raise ValueError(
+            "Expected attribute to be one of"
+            f" `{allowed_attrs}`, got `{attr}` for {obj}."
+        )
+
     def _get_iter(self, node: ast.AST) -> Iterable:
         """Use `DatabooksParser.safe_eval_ast` to get the iterable object."""
         tree = ast.Expression(body=node)
@@ -131,13 +151,7 @@ def visit_comprehension(self, node: ast.comprehension) -> None:
                 "Expected `ast.comprehension`'s target to be `ast.Name`, got"
                 f" `ast.{type(node.target).__name__}`."
) - # If any elements in the comprehension are a `DatabooksBase` instance, then - # pass down the attributes as valid - iterable = self._get_iter(node.iter) - databooks_el = [el for el in iterable if isinstance(el, DatabooksBase)] - if databooks_el: - d_attrs = reduce(lambda a, b: {**a, **b}, [dict(el) for el in databooks_el]) - self.names[node.target.id] = DatabooksBase(**d_attrs) if databooks_el else ... + self.names[node.target.id] = self._get_iter(node.iter) self.generic_visit(node) def visit_Attribute(self, node: ast.Attribute) -> None: @@ -148,13 +162,11 @@ def visit_Attribute(self, node: ast.Attribute) -> None: f" `ast.Subscript`, got `ast.{type(node.value).__name__}`." ) if isinstance(node.value, ast.Name): - obj = self.names[node.value.id] - allowed_attrs = dict(obj).keys() if isinstance(obj, DatabooksBase) else () - if node.attr not in allowed_attrs: - raise ValueError( - "Expected attribute to be one of" - f" `{allowed_attrs}`, got `{node.attr}`" - ) + self._allowed_attr( + obj=self.names[node.value.id], + attr=node.attr, + is_dynamic=node.value.id in (self.names.keys() - self.scope.keys()), + ) self.generic_visit(node) def visit_Name(self, node: ast.Name) -> None: diff --git a/tests/files/clean.ipynb b/tests/files/clean.ipynb new file mode 100644 index 00000000..1d8c57dc --- /dev/null +++ b/tests/files/clean.ipynb @@ -0,0 +1,113 @@ +{ + "nbformat": 4, + "nbformat_minor": 5, + "metadata": {}, + "cells": [ + { + "metadata": {}, + "source": [ + "# `databooks` demo!" 
+ ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "random()" + ], + "cell_type": "code", + "outputs": [ + { + "data": { + "text/plain": [ + "0.8025984025011855" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "from random import random" + ], + "cell_type": "code", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "print(\"It helps with resolving git conflicts but also avoiding them in the first place! \ud83c\udf89\")" + ], + "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "It helps with resolving git conflicts but also avoiding them in the first place! \ud83c\udf89\n" + ] + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "vals = range(1,4)" + ], + "cell_type": "code", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "for v in vals:\n", + " print(v)" + ], + "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n", + "2\n", + "3\n" + ] + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "print(\"As easy as running `databooks fix .`\")" + ], + "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "As easy as running `databooks fix .`\n" + ] + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [], + "cell_type": "code", + "outputs": [], + "execution_count": null + } + ] +} diff --git a/tests/test_affirm.py b/tests/test_affirm.py index 0b4eabe7..621707a2 100644 --- a/tests/test_affirm.py +++ b/tests/test_affirm.py @@ -41,6 +41,12 @@ def test_comprehension(self) -> None: parser = DatabooksParser(n=[1, 2, 3]) assert parser.safe_eval("[i+1 for i in n]") == [2, 3, 4] + def test_nested_comprehension(self) -> None: + """Variables in nested 
iterables should be valid."""
+        parser = DatabooksParser(m=[1, 2], n=[3, 4], o={1: 10, -1: -10})
+        res_eval = parser.safe_eval("[(i+j)*k for j in m for i in n for k in o]")
+        assert res_eval == [4, -4, 5, -5, 5, -5, 6, -6]
+
     def test_multiply(self) -> None:
         """Multiplications are valid."""
         parser = DatabooksParser()
@@ -79,6 +85,16 @@ def test_nested_attributes(self) -> None:
         parser = DatabooksParser(model=DatabooksBase(a=DatabooksBase(b=2)))
         assert parser.safe_eval("model.a.b") == 2
 
+    def test_nested_attributes_comprehensions(self) -> None:
+        """Nested model attributes are valid inside nested comprehensions."""
+
+        def model(b: int) -> DatabooksBase:
+            return DatabooksBase(a=DatabooksBase(b=b))
+
+        parser = DatabooksParser(l1=[model(1), model(2)], l2=[model(3), model(4)])
+        expr = "[m1.a.b+m2.a.b for m1 in l1 for m2 in l2]"
+        assert parser.safe_eval(expr) == [4, 5, 5, 6]
+
     def test_eval(self) -> None:
         """Trying accessing built-in `eval` raises error."""
         parser = DatabooksParser()
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 6f9a0ac2..41b0282b 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -214,10 +214,10 @@ def test_assert__config(caplog: LogCaptureFixture) -> None:
     )
     logs = list(caplog.records)
     assert result.exit_code == 1
-    assert len(logs) == 3
+    assert len(logs) == 4
     assert (
         logs[-1].message
-        == "Found issues in notebook metadata for 1 out of 2 notebooks."
+        == "Found issues in notebook metadata for 2 out of 3 notebooks."
) diff --git a/tests/test_recipes.py b/tests/test_recipes.py index cf2740eb..e0765a29 100644 --- a/tests/test_recipes.py +++ b/tests/test_recipes.py @@ -49,6 +49,12 @@ def test_no_empty_code(self) -> None: recipe = CookBook.no_empty_code.src assert affirm(nb_path=self.nb, exprs=[recipe]) is True + def test_seq_exec__clean(self) -> None: + """If no cells are executed then no cells are executed out of order.""" + recipe = CookBook.seq_exec.src + with resources.path("tests.files", "clean.ipynb") as nb: + assert affirm(nb_path=nb, exprs=[recipe]) is True + class TestCookBookBad: """Ensure desired effect for recipes.""" @@ -60,36 +66,36 @@ def nb(self) -> Path: return nb def test_has_tags(self) -> None: - """Check fail when notebook cells have no flags.""" + """Check failure when notebook cells have no flags.""" recipe = CookBook.has_tags.src assert affirm(nb_path=self.nb, exprs=[recipe]) is False def test_has_tags_code(self) -> None: - """Check fail when code cells have no flags.""" + """Check failure when code cells have no flags.""" recipe = CookBook.has_tags_code.src assert affirm(nb_path=self.nb, exprs=[recipe]) is False def test_max_cells(self) -> None: - """Check fail when notebook has more than 128 cells.""" + """Check failure when notebook has more than 128 cells.""" recipe = CookBook.max_cells.src assert affirm(nb_path=self.nb, exprs=[recipe]) is False def test_seq_exec(self) -> None: - """Check fail when notebook code cells are executed out of order.""" + """Check failure when notebook code cells are executed out of order.""" recipe = CookBook.seq_exec.src assert affirm(nb_path=self.nb, exprs=[recipe]) is False def test_seq_increase(self) -> None: - """Check fail when notebook code cells are not executed monotonically.""" + """Check failure when notebook code cells are not executed monotonically.""" recipe = CookBook.seq_increase.src assert affirm(nb_path=self.nb, exprs=[recipe]) is False def test_startswith_md(self) -> None: - """Check fail when notebook's 
first cell is not a markdown cell.""" + """Check failure when notebook's first cell is not a markdown cell.""" recipe = CookBook.startswith_md.src assert affirm(nb_path=self.nb, exprs=[recipe]) is False def test_no_empty_code(self) -> None: - """Check fail when notebook contains empty code cells.""" + """Check failure when notebook contains empty code cells.""" recipe = CookBook.no_empty_code.src assert affirm(nb_path=self.nb, exprs=[recipe]) is False