Skip to content

Commit

Permalink
Merge branch 'main' into draw-trivial
Browse files Browse the repository at this point in the history
  • Loading branch information
emileferreira committed Feb 5, 2024
2 parents 19a9a92 + ffe807f commit 870d7a2
Show file tree
Hide file tree
Showing 11 changed files with 95 additions and 68 deletions.
16 changes: 14 additions & 2 deletions .github/workflows/test-python.yml → .github/workflows/python.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: Python tests
name: Python
on: [push, workflow_dispatch, pull_request]
jobs:
build:
test:
runs-on: ubuntu-latest
strategy:
matrix:
Expand All @@ -18,3 +18,15 @@ jobs:
pip install antlr4-python3-runtime==4.9.1
- name: Run all the Python tests
run: python3 tests/test_all.py
pep8:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@master
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: 'Run PEP8 check'
uses: quentinguidee/pep8-action@v1
with:
arguments: >-
--exclude=.svn,CVS,.bzr,.hg,.git,zzantlr
--ignore=E121,E123,E126,E133,E226,E241,E242,E704,W503,W504,W505,W191,E101,E128
9 changes: 4 additions & 5 deletions RASP_support/DrawCompFlow.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def makeQKStable(qvars, kvars, select, ref_in_g):
# select has qvars along the rows and kvars along the columns, so we'll do
# the same. i.e. top rows will just be the kvars and first columns will
# just be the qvars.
# if (not qvars) and (not kvars):
# if (not qvars) and (not kvars):
# # no qvars or kvars -> full select -> dont waste space drawing.
# num_rows, num_columns = 0, 0
# pass
Expand Down Expand Up @@ -562,12 +562,11 @@ def draw_comp_flow(self, w, filename=None,
keep_dot=False, show=True,
force_vertical_layers=True, add_tokens_on_ff=False):
if w is not None:
self(w) # execute seq (and all its ancestors) on the given input w.
# if w==None, assume seq has already been executed on some input.
self.call(w) # execute seq (and all its ancestors) on the given input
if not self.last_w == w:
print("evaluating input failed")
return
else:
else: # if w == None, assume seq has already been executed on some input.
w = self.last_w
if None is filename:
name = self.name
Expand All @@ -588,7 +587,7 @@ def draw_comp_flow(self, w, filename=None,
# (though it will not be able to draw computation flows without it)
from graphviz import Digraph
g = Digraph('g')
# with curved lines it fusses over separating score edges
# with curved lines it fusses over separating score edges
# and makes weirdly curved ones that start overlapping with the sequences
# :(
g.attr(splines='polyline')
Expand Down
2 changes: 1 addition & 1 deletion RASP_support/Environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def carefulcopy(val):
if isinstance(val, Unfinished) or isinstance(val, RASPFunction):
return val # non mutable, at least not through rasp commands
elif isinstance(val, float) or isinstance(val, int) \
or isinstance(val, str) or isinstance(val, bool):
or isinstance(val, str) or isinstance(val, bool):
return val # non mutable
elif isinstance(val, list):
return [carefulcopy(v) for v in val]
Expand Down
18 changes: 9 additions & 9 deletions RASP_support/Evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __str__(self):
return self.creator + " function: " + self.name \
+ "(" + ", ".join(self.argnames) + ")"

def __call__(self, *args):
def call(self, *args):
top_eval = args[-1]
args = args[:-1]
# nesting, because function shouldn't affect the enclosing environment
Expand Down Expand Up @@ -165,7 +165,7 @@ def draw(self, ast):
if not isinstance(unf, UnfinishedSequence):
raise RASPTypeError("draw expects unfinished sequence, got:", unf)
unf.draw_comp_flow(example)
res = unf(example)
res = unf.call(example)
res.created_from_input = example
self.backup_example = prev_backup
return JustVal(res)
Expand Down Expand Up @@ -202,7 +202,7 @@ def _set_iterator_and_vals(self, iterator_names, iterator_vals):
if len(iterator_names) == 1:
self.env.set_variable(iterator_names[0], iterator_vals)
elif isinstance(iterator_vals, Iterable) \
and (len(iterator_vals) == len(iterator_names)):
and (len(iterator_vals) == len(iterator_names)):
for n, v in zip(iterator_names, iterator_vals):
self.env.set_variable(n, v)
else:
Expand Down Expand Up @@ -248,7 +248,7 @@ def _evaluateListComp(self, ast):
for vals in ll:
orig_env = self.env
self.env = self.env.make_nested()
# sets inside the now-nested env -don't want to keep
# sets inside the now-nested env -don't want to keep
# the internal iterators after finishing this list comp
self._set_iterator_and_vals(iterator_names, vals)
res.append(self.evaluateExpr(ast.val))
Expand Down Expand Up @@ -557,15 +557,15 @@ def _evaluateApplication(self, ast, unf):
raise RASPTypeError(
"Applying unfinished expects iterable input, got:",
strdesc(input_val))
res = unf(input_val)
res = unf.call(input_val)
res.created_from_input = input_val
return res

def _evaluateRASPFunction(self, ast, raspfun):
args_trees = self._get_first_cont_list(ast.inputexprs)
args = tuple(self.evaluateExpr(t) for t in args_trees) + (self,)
real_args = args[:-1]
res = raspfun(*args)
res = raspfun.call(*args)
if isinstance(res, Unfinished):
res.setname(
raspfun.name+"("+" , ".join(strdesc(a, desc_cap=20)
Expand Down Expand Up @@ -629,8 +629,8 @@ def _test_res(self, res):
if isinstance(res, Unfinished):
def succeeds_with(exampe):
try:
res(example, just_pass_exception_up=True)
except:
res.call(example, just_pass_exception_up=True)
except Exception:
return False
else:
return True
Expand All @@ -643,7 +643,7 @@ def succeeds_with(exampe):
return
example = self.sequence_running_example if self.backup_example \
is None else self.backup_example
res(example, just_pass_exception_up=True)
res.call(example, just_pass_exception_up=True)

def evaluateExpr(self, ast, from_top=False):
def format_return(res, resname="out",
Expand Down
61 changes: 20 additions & 41 deletions RASP_support/FunctionalSupport.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_parents(self):
for p in other_parents:
# recursion: branch back through all the parents of the unf,
# always stopping wherever hit something 'real' ie a select or
# a sequence
# a sequence
res += p.get_parents()
# nothing is made from more than one select...
assert len(
Expand Down Expand Up @@ -147,7 +147,7 @@ def get_full_parents(self, recurse=False, just_compute=False,
for p in self.get_sorted_full_parents():
p.get_full_parents(recurse=True, just_compute=True)
# have them all compute their full parents so they are
# ready for the future, but only do this in sorted order,
# ready for the future, but only do this in sorted order,
# so recursion is always shallow. (always gets shorted with
# self._full_parents, which is being computed here for each
# unfinished starting from the top of the computation
Expand Down Expand Up @@ -175,9 +175,7 @@ def get_sorted_full_parents(self):
self._sort_full_parents()
return copy(self._sorted_full_parents)

def __call__(self, w, print_all_named_sequences=False, print_input=False,
print_all_sequences=False, print_all=False, topcall=True,
just_pass_exception_up=False):
def call(self, w, topcall=True, just_pass_exception_up=False):
if (not isinstance(w, Iterable)) or (not w):
raise RASPTypeError(
"RASP sequences/selectors expect non-empty iterables, got: "
Expand All @@ -203,16 +201,11 @@ def __call__(self, w, print_all_named_sequences=False, print_input=False,
# further back as they use memoization
for unf in self.get_sorted_full_parents():
# evaluate
unf(w, topcall=False,
unf.call(w, topcall=False,
just_pass_exception_up=just_pass_exception_up)

p_a_n_s = print_all_named_sequences
j_p_e_u = just_pass_exception_up
args = tuple(p(w,
print_all_named_sequences=p_a_n_s,
print_input=print_input,
print_all_sequences=print_all_sequences,
print_all=print_all,
args = tuple(p.call(w,
topcall=False,
just_pass_exception_up=j_p_e_u)
for p in self.parents_tuple)
Expand All @@ -239,7 +232,7 @@ def __call__(self, w, print_all_named_sequences=False, print_input=False,
a, b, tb = sys.exc_info()
tt = traceback.extract_tb(tb)
last_call = max([i for i, t in enumerate(tt)
if "__call__" in str(t)])
if "in call" in str(t)])
print(''.join(traceback.format_list(tt[last_call+1:])))

# traceback.print_exception(a,b,tb)
Expand All @@ -253,20 +246,6 @@ def __call__(self, w, print_all_named_sequences=False, print_input=False,

self.last_w, self.last_res = w, res

def should_print():
if isinstance(res, Sequence):
if print_all_named_sequences and self.name not in plain_names:
return True
if print_all_sequences:
return True
if self.is_toplevel_input and print_input:
return True
return print_all
if should_print():
print("resolved \""+self.name +
(("\" from:\" "+str(self.get_own_root_input(w))+" \"")
if print_root_inputs_too else ""),
":\n\t", res)
return res


Expand All @@ -277,12 +256,12 @@ def __init__(self, parents_tuple, parents2self,
from_zipmap=False, output_index=-1,
definitely_uses_identity_function=False):
# min_poss_depth=0 starts all of the base sequences (eg indices) off
# right.
# right.

# might have got none from some default value, fix it before continuing
# because later things eg DrawCompFlow will expect name to be str
if name is None:
name = plain_unfinished_sequence_name
name = plain_unfinished_sequence_name
super(UnfinishedSequence, self).__init__(parents_tuple,
parents2self, name=name,
min_poss_depth=min_poss_depth)
Expand Down Expand Up @@ -441,13 +420,13 @@ def select(q_vars, k_vars, selector, name=None, compare_string=None):
# helpful for the user so consider maybe adding a tiny bit of mess here
# (including markings inside sequences and selectors so they know which
# index they're gathering to and from) to allow it

# we're ok with getting a single q or k var, not in a tuple,
# but important to fix it before '+' on two UnfinishedSequences
# (as opposed to two tuples) sends everything sideways
q_vars = tupleise(q_vars)
k_vars = tupleise(k_vars)

# attn layer is one after values it needs to be calculated
new_depth = _min_poss_depth(q_vars+k_vars)+1
res = UnfinishedSelect((_input, # need input seq length to create select
Expand Down Expand Up @@ -548,19 +527,19 @@ def parents2res(w, vt): return _zipmap(len(w), vt, elementwise_function)
# you can do it in the embedding
# if len(sequences_tuple)>0:
# min_poss_depth = max(min_poss_depth,1) # except for the very specific
# # case where it is the very first thing to be done, in which case we do
# # have to go through one layer to get to the first feedforward.
# # the 'if' is there to rule out increasing when doing a feedforward on
# # nothing, ie, when making a constant. constants are allowed to be
# # created on layer 0, they're part of the embedding or the weights that
# # will use them later or whatever, it's fine
# # case where it is the very first thing to be done, in which case we do
# # have to go through one layer to get to the first feedforward.
# # the 'if' is there to rule out increasing when doing a feedforward on
# # nothing, ie, when making a constant. constants are allowed to be
# # created on layer 0, they're part of the embedding or the weights that
# # will use them later or whatever, it's fine

# at least as deep as needed MVs, but no deeper cause FF
# (which happens at end of layer)
return format_output(parents_tuple, parents2res, name,
min_poss_depth=min_poss_depth,
elementwise_function=elementwise_function,
from_zipmap=True)
from_zipmap=True)


def aggregate(select, sequences_tuple, elementwise_function=None,
Expand All @@ -574,7 +553,7 @@ def aggregate(select, sequences_tuple, elementwise_function=None,
def parents2res(s, vt): return _aggregate(
s, vt, elementwise_function, default=default)
def_uses = definitely_uses_identity_function

# at least as deep as needed attention and at least one deeper than needed
# MVs
return format_output(parents_tuple, parents2res, name,
Expand All @@ -583,7 +562,7 @@ def parents2res(s, vt): return _aggregate(
min_poss_depth=max(_min_poss_depth(
sequences_tuple)+1, select.min_poss_depth),
definitely_uses_identity_function=def_uses)


# up to here was just plain transformer 'assembly'. any addition is a lie
# now begin the bells and whistles
Expand Down
8 changes: 4 additions & 4 deletions RASP_support/REPL.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,19 @@ def print_named_val(self, name, val, ntabs=0, extra_first_pref=""):
optional_exampledesc = name + \
"("+formatstr(self.sequence_running_example)+") ="
print_seq(self.selector_running_example,
val(self.sequence_running_example),
val.call(self.sequence_running_example),
still_on_prev_line=True,
extra_pref=pref,
lastpref_if_shortprint=optional_exampledesc)
else:
print(pref, "\t Example:", name + "(" +
formatstr(self.sequence_running_example) + ") =",
val(self.sequence_running_example))
val.call(self.sequence_running_example))
elif isinstance(val, UnfinishedSelect):
print(pref, extra_first_pref, " selector:", name)
if self.show_selector_examples:
print(pref, "\t Example:")
print_select(self.selector_running_example, val(
print_select(self.selector_running_example, val.call(
self.selector_running_example), extra_pref=pref)
elif isinstance(val, RASPFunction):
print(pref, extra_first_pref, " "+str(val))
Expand Down Expand Up @@ -491,7 +491,7 @@ def get_input_tree(self):
if isinstance(newinput, Stop): # input stream ended
return Stop()
if is_comment(newinput):
# don't let comments get in and ruin things somehow
# don't let comments get in and ruin things somehow
newinput = ""
# don't replace newlines here! this is how in-function comments get
# broken
Expand Down
6 changes: 4 additions & 2 deletions RASP_support/Sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# top-level rasp file we import, and nice to have draw_comp_flow added into
# the sequences already on load


def _apply_unary_op(self, f):
return zipmap(self, f)

Expand Down Expand Up @@ -69,8 +70,9 @@ def asbool(seq):

def tplnot(seq, name=None):
# this one does correct conversion using asbool and then we really can just
# do ==False
res = asbool(seq) == False
# do == False
pep8hack = False # this avoids violating E712 of PEP8
res = asbool(seq) == pep8hack
return _addname(res, name, "( not " + str(seq.name) + " )")


Expand Down
4 changes: 2 additions & 2 deletions RASP_support/Support.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,10 @@ def prep_default(default, num_output_vars):
verify_default_size(default, num_output_vars)
if not isinstance(default, tuple):
# specifically with how we're going to do things here in the
# average aggregate, will help to actually have the outputs get
# average aggregate, will help to actually have the outputs get
# passed around as tuples, even if they're scalars really.
# but do this after the size check for the scalar one so it doesn't
# get filled with weird ifs... this tupled scalar thing is only a
# get filled with weird ifs... this tupled scalar thing is only a
# convenience in this implementation in this here function
default = (default,)
return default
Expand Down
2 changes: 1 addition & 1 deletion RASP_support/analyse.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def note_if_seeker(self):
return

if (not self.get_parent_sequences()) \
and (self.get_parent_select() is not None):
and (self.get_parent_select() is not None):
# no parent sequences, but yes parent select: this value is a function
# of only its parent select, i.e., a seeker (marks whether select found
# something or not)
Expand Down
2 changes: 1 addition & 1 deletion RASP_support/make_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __rpow__(self, other):
return apply_binary_op(self, other, lambda a, b: pow(b, a))

# skipping and, or, xor, which are bitwise and dont implement 'and' and
# 'or' but rather & and |.
# 'or' but rather & and |.
# similarly skipping lshift, rshift cause who wants them.
# wish i had not, and, or primitives, but can accept that dont.
# if people really want to do 'not' they can do '==False' instead, can do a
Expand Down
Loading

0 comments on commit 870d7a2

Please sign in to comment.