Skip to content

Commit

Permalink
Merge pull request #36 from datarootsio/fix-recipes-for-clean-notebooks
Browse files Browse the repository at this point in the history
Fix recipes for clean notebooks
  • Loading branch information
murilo-cunha authored Oct 10, 2022
2 parents 5c64295 + efdbf02 commit 994c7d2
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 24 deletions.
42 changes: 27 additions & 15 deletions databooks/affirm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Functions to safely evaluate strings and inspect notebook."""
import ast
from collections import abc
from copy import deepcopy
from functools import reduce
from itertools import compress
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Tuple
Expand Down Expand Up @@ -99,6 +99,26 @@ def _prioritize(field: Tuple[str, Any]) -> bool:
return True
return not any(isinstance(f, ast.comprehension) for f in value)

@staticmethod
def _allowed_attr(obj: Any, attr: str, is_dynamic: bool = False) -> None:
    """
    Validate that `attr` is a declared field of the inspected object.

    For a `databooks.data_models.base.DatabooksBase` instance the valid
    attributes are its (Pydantic) field keys. When `obj` is an iterable that
    was produced dynamically (i.e., bound by a comprehension rather than
    present in the original scope), validate the attribute against every
    element of the iterable instead.

    :param obj: Object whose attribute access is being checked
    :param attr: Attribute name requested in the expression
    :param is_dynamic: Whether `obj` was computed from a comprehension
    :raises ValueError: If `attr` is not an allowed attribute of `obj`
    """
    if is_dynamic and isinstance(obj, abc.Iterable):
        # Dynamically-built iterable: every element must expose the attribute.
        for element in obj:
            DatabooksParser._allowed_attr(obj=element, attr=attr)
        return
    allowed_attrs = list(dict(obj).keys()) if isinstance(obj, DatabooksBase) else ()
    if attr not in allowed_attrs:
        raise ValueError(
            "Expected attribute to be one of"
            f" `{allowed_attrs}`, got `{attr}` for {obj}."
        )

def _get_iter(self, node: ast.AST) -> Iterable:
"""Use `DatabooksParser.safe_eval_ast` to get the iterable object."""
tree = ast.Expression(body=node)
Expand Down Expand Up @@ -131,13 +151,7 @@ def visit_comprehension(self, node: ast.comprehension) -> None:
"Expected `ast.comprehension`'s target to be `ast.Name`, got"
f" `ast.{type(node.target).__name__}`."
)
# If any elements in the comprehension are a `DatabooksBase` instance, then
# pass down the attributes as valid
iterable = self._get_iter(node.iter)
databooks_el = [el for el in iterable if isinstance(el, DatabooksBase)]
if databooks_el:
d_attrs = reduce(lambda a, b: {**a, **b}, [dict(el) for el in databooks_el])
self.names[node.target.id] = DatabooksBase(**d_attrs) if databooks_el else ...
self.names[node.target.id] = self._get_iter(node.iter)
self.generic_visit(node)

def visit_Attribute(self, node: ast.Attribute) -> None:
Expand All @@ -148,13 +162,11 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
f" `ast.Subscript`, got `ast.{type(node.value).__name__}`."
)
if isinstance(node.value, ast.Name):
obj = self.names[node.value.id]
allowed_attrs = dict(obj).keys() if isinstance(obj, DatabooksBase) else ()
if node.attr not in allowed_attrs:
raise ValueError(
"Expected attribute to be one of"
f" `{allowed_attrs}`, got `{node.attr}`"
)
self._allowed_attr(
obj=self.names[node.value.id],
attr=node.attr,
is_dynamic=node.value.id in (self.names.keys() - self.scope.keys()),
)
self.generic_visit(node)

def visit_Name(self, node: ast.Name) -> None:
Expand Down
113 changes: 113 additions & 0 deletions tests/files/clean.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
{
"nbformat": 4,
"nbformat_minor": 5,
"metadata": {},
"cells": [
{
"metadata": {},
"source": [
"# `databooks` demo!"
],
"cell_type": "markdown"
},
{
"metadata": {},
"source": [
"random()"
],
"cell_type": "code",
"outputs": [
{
"data": {
"text/plain": [
"0.8025984025011855"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": null
},
{
"metadata": {},
"source": [
"from random import random"
],
"cell_type": "code",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"source": [
"print(\"It helps with resolving git conflicts but also avoiding them in the first place! \ud83c\udf89\")"
],
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"It helps with resolving git conflicts but also avoiding them in the first place! \ud83c\udf89\n"
]
}
],
"execution_count": null
},
{
"metadata": {},
"source": [
"vals = range(1,4)"
],
"cell_type": "code",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"source": [
"for v in vals:\n",
" print(v)"
],
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"2\n",
"3\n"
]
}
],
"execution_count": null
},
{
"metadata": {},
"source": [
"print(\"As easy as running `databooks fix .`\")"
],
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"As easy as running `databooks fix .`\n"
]
}
],
"execution_count": null
},
{
"metadata": {},
"source": [],
"cell_type": "code",
"outputs": [],
"execution_count": null
}
]
}
21 changes: 21 additions & 0 deletions tests/test_affirm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ def test_comprehension(self) -> None:
parser = DatabooksParser(n=[1, 2, 3])
assert parser.safe_eval("[i+1 for i in n]") == [2, 3, 4]

def test_nested_comprehension(self) -> None:
    """Variables bound across nested comprehension loops resolve correctly."""
    parser = DatabooksParser(m=[1, 2], n=[3, 4], o={1: 10, -1: -10})
    expected = [4, -4, 5, -5, 5, -5, 6, -6]
    assert parser.safe_eval("[(i+j)*k for j in m for i in n for k in o]") == expected

def test_multiply(self) -> None:
"""Multiplications are valid."""
parser = DatabooksParser()
Expand Down Expand Up @@ -79,6 +85,21 @@ def test_nested_attributes(self) -> None:
parser = DatabooksParser(model=DatabooksBase(a=DatabooksBase(b=2)))
assert parser.safe_eval("model.a.b") == 2

def test_nested_attributes_comprehensions(self) -> None:
    """Nested attributes from Pydantic fields in nested comprehensions are valid."""
    first = [
        DatabooksBase(a=DatabooksBase(b=1)),
        DatabooksBase(a=DatabooksBase(b=2)),
    ]
    second = [
        DatabooksBase(a=DatabooksBase(b=3)),
        DatabooksBase(a=DatabooksBase(b=4)),
    ]
    parser = DatabooksParser(l1=first, l2=second)
    assert parser.safe_eval("[m1.a.b+m2.a.b for m1 in l1 for m2 in l2]") == [
        4,
        5,
        5,
        6,
    ]

def test_eval(self) -> None:
"""Trying accessing built-in `eval` raises error."""
parser = DatabooksParser()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,10 @@ def test_assert__config(caplog: LogCaptureFixture) -> None:
)
logs = list(caplog.records)
assert result.exit_code == 1
assert len(logs) == 3
assert len(logs) == 4
assert (
logs[-1].message
== "Found issues in notebook metadata for 1 out of 2 notebooks."
== "Found issues in notebook metadata for 2 out of 3 notebooks."
)


Expand Down
20 changes: 13 additions & 7 deletions tests/test_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def test_no_empty_code(self) -> None:
recipe = CookBook.no_empty_code.src
assert affirm(nb_path=self.nb, exprs=[recipe]) is True

def test_seq_exec__clean(self) -> None:
    """If no cells are executed then no cells are executed out of order."""
    with resources.path("tests.files", "clean.ipynb") as nb:
        assert affirm(nb_path=nb, exprs=[CookBook.seq_exec.src]) is True


class TestCookBookBad:
"""Ensure desired effect for recipes."""
Expand All @@ -60,36 +66,36 @@ def nb(self) -> Path:
return nb

def test_has_tags(self) -> None:
"""Check fail when notebook cells have no flags."""
"""Check failure when notebook cells have no flags."""
recipe = CookBook.has_tags.src
assert affirm(nb_path=self.nb, exprs=[recipe]) is False

def test_has_tags_code(self) -> None:
"""Check fail when code cells have no flags."""
"""Check failure when code cells have no flags."""
recipe = CookBook.has_tags_code.src
assert affirm(nb_path=self.nb, exprs=[recipe]) is False

def test_max_cells(self) -> None:
"""Check fail when notebook has more than 128 cells."""
"""Check failure when notebook has more than 128 cells."""
recipe = CookBook.max_cells.src
assert affirm(nb_path=self.nb, exprs=[recipe]) is False

def test_seq_exec(self) -> None:
"""Check fail when notebook code cells are executed out of order."""
"""Check failure when notebook code cells are executed out of order."""
recipe = CookBook.seq_exec.src
assert affirm(nb_path=self.nb, exprs=[recipe]) is False

def test_seq_increase(self) -> None:
"""Check fail when notebook code cells are not executed monotonically."""
"""Check failure when notebook code cells are not executed monotonically."""
recipe = CookBook.seq_increase.src
assert affirm(nb_path=self.nb, exprs=[recipe]) is False

def test_startswith_md(self) -> None:
"""Check fail when notebook's first cell is not a markdown cell."""
"""Check failure when notebook's first cell is not a markdown cell."""
recipe = CookBook.startswith_md.src
assert affirm(nb_path=self.nb, exprs=[recipe]) is False

def test_no_empty_code(self) -> None:
"""Check fail when notebook contains empty code cells."""
"""Check failure when notebook contains empty code cells."""
recipe = CookBook.no_empty_code.src
assert affirm(nb_path=self.nb, exprs=[recipe]) is False

0 comments on commit 994c7d2

Please sign in to comment.