diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 3e4dfd3..3c512d7 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -16,6 +16,7 @@ jobs: run: | python -m pip install --upgrade pip pip install antlr4-python3-runtime==4.9.1 + pip3 install termcolor - name: Run all the Python tests run: python3 tests/test_all.py pep8: @@ -29,4 +30,3 @@ jobs: with: arguments: >- --exclude=.svn,CVS,.bzr,.hg,.git,zzantlr - --ignore=E121,E123,E126,E133,E226,E241,E242,E704,W503,W504,W505,W191,E101,E128 diff --git a/RASP_support/DrawCompFlow.py b/RASP_support/DrawCompFlow.py index 6679f2a..fc5b0f2 100644 --- a/RASP_support/DrawCompFlow.py +++ b/RASP_support/DrawCompFlow.py @@ -1,5 +1,5 @@ from .FunctionalSupport import Unfinished, guarded_contains, base_tokens, \ - tokens_asis + tokens_asis from .Support import clean_val import os import string @@ -24,589 +24,589 @@ def windows_path_cleaner(s): - if os.name == "nt": # is windows - validchars = "-_.() "+string.ascii_letters+string.digits + if os.name == "nt": # is windows + validchars = "-_.() "+string.ascii_letters+string.digits - def fix(c): - return c if c in validchars else "." - return "".join([fix(c) for c in s]) - else: - return s + def fix(c): + return c if c in validchars else "." + return "".join([fix(c) for c in s]) + else: + return s def colour_scheme(row_type): - if row_type == INPUT: - return 'gray', 'gray', 'gray' - if row_type == QVAR: - return 'palegreen4', 'mediumseagreen', 'palegreen1' - elif row_type == KVAR: - return 'deepskyblue3', 'darkturquoise', 'darkslategray1' - elif row_type == VVAR: - return 'palevioletred3', 'palevioletred2', 'lightpink' - elif row_type == VREAL: - return 'plum4', 'plum3', 'thistle2' - elif row_type == RES: - return 'lightsalmon3', 'burlywood', 'burlywood1' - else: - raise Exception("unknown row type: "+str(row_type)) + if row_type == INPUT: + return 'gray', 'gray', 'gray' + if row_type == QVAR: + return 'palegreen4', 'mediumseagreen', 'palegreen1' + elif row_type == KVAR: + return 'deepskyblue3', 'darkturquoise', 'darkslategray1' + elif row_type == VVAR: + return 'palevioletred3', 'palevioletred2', 'lightpink' + elif row_type == VREAL: + return 'plum4', 'plum3', 'thistle2' + elif row_type == RES: + return 'lightsalmon3', 'burlywood', 'burlywood1' + else: + raise Exception("unknown row type: "+str(row_type)) QVAR, KVAR, VVAR, VREAL, RES, INPUT = [ - "QVAR", "KVAR", "VVAR", "VREAL", "RES", "INPUT"] + "QVAR", "KVAR", "VVAR", "VREAL", "RES", "INPUT"] POSS_ROWS = [QVAR, KVAR, VVAR, VREAL, RES, INPUT] ROW_NAMES = {QVAR: "Me", KVAR: "Other", VVAR: "X", - VREAL: "f(X)", RES: "FF", INPUT: ""} + VREAL: "f(X)", RES: "FF", INPUT: ""} def UnfinishedFunc(f): - setattr(Unfinished, f.__name__, f) + setattr(Unfinished, f.__name__, f) @UnfinishedFunc def last_val(self): - return self.last_res.get_vals() + return self.last_res.get_vals() def makeQKStable(qvars, kvars, select, ref_in_g): - qvars = [q.last_val() for q in qvars] - kvars = [k.last_val() for k in kvars] - select = select.last_val() - q_val_len, k_val_len = len(select), len(select[0]) - - qvars_skip = len(kvars) - kvars_skip = len(qvars) - _, _, qvars_colour = colour_scheme(QVAR) - _, _, kvars_colour = colour_scheme(KVAR) - # select has qvars along the rows and kvars along the columns, so we'll do - # the same. i.e. top rows will just be the kvars and first columns will - # just be the qvars. - # if (not qvars) and (not kvars): - # # no qvars or kvars -> full select -> dont waste space drawing. 
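The --ignore list removed from the pep8 job above included W191 (indentation contains tabs) and E101 (mixed tabs and spaces), which is why the rest of this patch re-indents the RASP_support modules with spaces. A rough local equivalent of that check, assuming pycodestyle is installed and reusing the workflow's exclude list (the script name is illustrative):

# check_pep8.py -- run roughly the same style check the CI job now enforces.
import subprocess
import sys

result = subprocess.run(
    ["pycodestyle", "--exclude=.svn,CVS,.bzr,.hg,.git,zzantlr", "."],
    capture_output=True, text=True)
print(result.stdout, end="")
sys.exit(result.returncode)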
- # num_rows, num_columns = 0, 0 - # pass - # else: - # num_rows = qvars_skip+(len(qvars[0]) if qvars else 1) - # num_columns = kvars_skip+(len(kvars[0]) if kvars else 1) - num_rows = qvars_skip+q_val_len - num_columns = kvars_skip+k_val_len - - select_cells = {i: [CellVals('', head_color, j, i) - for j in range(num_columns)] - for i in range(num_rows)} - - for i, seq in enumerate(kvars): - for j, v in enumerate(seq): - vals = CellVals(v, kvars_colour, i, j+kvars_skip) - select_cells[i][j + kvars_skip] = vals - for j, seq in enumerate(qvars): - for i, v in enumerate(seq): - vals = CellVals(v, qvars_colour, i+qvars_skip, j) - select_cells[i + qvars_skip][j] = vals - - for i in range(num_rows-qvars_skip): # i goes over the q_var values - for j in range(num_columns-kvars_skip): # j goes over the k_var values - v = select[i][j] - colour = select_on_colour if v else select_off_colour - select_cells[i+qvars_skip][j+kvars_skip] = CellVals( - v, colour, i+qvars_skip, j+kvars_skip, select_internal=True) - - # TODO: make an ugly little q\k triangle thingy in the top corner - return GridTable(select_cells, ref_in_g) + qvars = [q.last_val() for q in qvars] + kvars = [k.last_val() for k in kvars] + select = select.last_val() + q_val_len, k_val_len = len(select), len(select[0]) + + qvars_skip = len(kvars) + kvars_skip = len(qvars) + _, _, qvars_colour = colour_scheme(QVAR) + _, _, kvars_colour = colour_scheme(KVAR) + # select has qvars along the rows and kvars along the columns, so we'll do + # the same. i.e. top rows will just be the kvars and first columns will + # just be the qvars. + # if (not qvars) and (not kvars): + # # no qvars or kvars -> full select -> dont waste space drawing. + # num_rows, num_columns = 0, 0 + # pass + # else: + # num_rows = qvars_skip+(len(qvars[0]) if qvars else 1) + # num_columns = kvars_skip+(len(kvars[0]) if kvars else 1) + num_rows = qvars_skip+q_val_len + num_columns = kvars_skip+k_val_len + + select_cells = {i: [CellVals('', head_color, j, i) + for j in range(num_columns)] + for i in range(num_rows)} + + for i, seq in enumerate(kvars): + for j, v in enumerate(seq): + vals = CellVals(v, kvars_colour, i, j+kvars_skip) + select_cells[i][j + kvars_skip] = vals + for j, seq in enumerate(qvars): + for i, v in enumerate(seq): + vals = CellVals(v, qvars_colour, i+qvars_skip, j) + select_cells[i + qvars_skip][j] = vals + + for i in range(num_rows-qvars_skip): # i goes over the q_var values + for j in range(num_columns-kvars_skip): # j goes over the k_var values + v = select[i][j] + colour = select_on_colour if v else select_off_colour + select_cells[i+qvars_skip][j+kvars_skip] = CellVals( + v, colour, i+qvars_skip, j+kvars_skip, select_internal=True) + + # TODO: make an ugly little q\k triangle thingy in the top corner + return GridTable(select_cells, ref_in_g) class CellVals: - def __init__(self, val, colour, i_row, i_col, select_internal=False, - known_portstr=None): - def mystr(v): - if isinstance(v, bool): - if select_internal: - return ' ' if v else ' ' # color gives it all! 
- else: - return 'T' if v else 'F' - if isinstance(v, float): - v = clean_val(v, 3) - if isinstance(v, int) and len(str(v)) == 1: - v = " "+str(v) # for pretty square selectors - return str(v).replace("<", "<").replace(">", ">") - self.val = mystr(val) - self.colour = colour - if None is known_portstr: - self.portstr = "_col"+str(i_col)+"_row"+str(i_row) - else: - self.portstr = known_portstr - - def __str__(self): - return '' + self.val+'' + def __init__(self, val, colour, i_row, i_col, select_internal=False, + known_portstr=None): + def mystr(v): + if isinstance(v, bool): + if select_internal: + return ' ' if v else ' ' # color gives it all! + else: + return 'T' if v else 'F' + if isinstance(v, float): + v = clean_val(v, 3) + if isinstance(v, int) and len(str(v)) == 1: + v = " "+str(v) # for pretty square selectors + return str(v).replace("<", "<").replace(">", ">") + self.val = mystr(val) + self.colour = colour + if None is known_portstr: + self.portstr = "_col"+str(i_col)+"_row"+str(i_row) + else: + self.portstr = known_portstr + + def __str__(self): + return '' + self.val+'' class GridTable: - def __init__(self, cellvals, ref_in_g): - self.ref_in_g = ref_in_g - self.cellvals = cellvals - self.numcols = len(cellvals.get(0, [])) - self.numrows = len(cellvals) - self.empty = 0 in [self.numcols, self.numrows] + def __init__(self, cellvals, ref_in_g): + self.ref_in_g = ref_in_g + self.cellvals = cellvals + self.numcols = len(cellvals.get(0, [])) + self.numrows = len(cellvals) + self.empty = 0 in [self.numcols, self.numrows] - def to_str(self, transposed=False): - ii = sorted(list(self.cellvals.keys())) - rows = [self.cellvals[i] for i in ii] + def to_str(self, transposed=False): + ii = sorted(list(self.cellvals.keys())) + rows = [self.cellvals[i] for i in ii] - def cells2row(cells): - return ''+''.join(map(str, cells))+'' - return '<' + ''.join(map(cells2row, rows)) \ - + '
>' + def cells2row(cells): + return ''+''.join(map(str, cells))+'' + return '<' + ''.join(map(cells2row, rows)) \ + + '
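GridTable.to_str and CellVals.__str__ build Graphviz HTML-like labels: each cell becomes a td element (carrying its port name and fill colour), cells2row wraps a row of cells in tr tags, and the whole string is wrapped in '<...>' so Graphviz parses it as markup rather than plain text. A minimal sketch of that pattern, with illustrative attributes rather than the exact ones used in DrawCompFlow.py:

# Sketch: an HTML-like table label for a Graphviz node, as GridTable builds.
from graphviz import Digraph

def html_table_label(rows):
    def cell(val):
        return '<td>' + str(val) + '</td>'

    def cells2row(cells):
        return '<tr>' + ''.join(cell(c) for c in cells) + '</tr>'
    # Wrapping the label in '<...>' marks it as HTML-like for Graphviz.
    return '<<table border="0" cellborder="1" cellspacing="0">' \
        + ''.join(cells2row(r) for r in rows) + '</table>>'

g = Digraph('demo')
g.node('grid', shape='none', margin='0',
       label=html_table_label([['a', 'b'], [1, 2]]))
print(g.source)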
>' - def bottom_left_portstr(self): - return self.access_portstr(0, -1) + def bottom_left_portstr(self): + return self.access_portstr(0, -1) - def bottom_right_portstr(self): - return self.access_portstr(-1, -1) + def bottom_right_portstr(self): + return self.access_portstr(-1, -1) - def top_left_portstr(self): - return self.access_portstr(0, 0) + def top_left_portstr(self): + return self.access_portstr(0, 0) - def top_right_portstr(self): - return self.access_portstr(-1, 0) + def top_right_portstr(self): + return self.access_portstr(-1, 0) - def top_access_portstr(self, i_col): - return self.access_portstr(i_col, 0) + def top_access_portstr(self, i_col): + return self.access_portstr(i_col, 0) - def bottom_access_portstr(self, i_col): - return self.access_portstr(i_col, -1) + def bottom_access_portstr(self, i_col): + return self.access_portstr(i_col, -1) - def access_portstr(self, i_col, i_row): - return self.ref_in_g + ":" + self.internal_portstr(i_col, i_row) + def access_portstr(self, i_col, i_row): + return self.ref_in_g + ":" + self.internal_portstr(i_col, i_row) - def internal_portstr(self, i_col, i_row): - if i_col < 0: - i_col = self.numcols + i_col - if i_row < 0: - i_row = self.numrows + i_row - return "_col"+str(i_col)+"_row"+str(i_row) + def internal_portstr(self, i_col, i_row): + if i_col < 0: + i_col = self.numcols + i_col + if i_row < 0: + i_row = self.numrows + i_row + return "_col"+str(i_col)+"_row"+str(i_row) - def add_to_graph(self, g): - if self.empty: - pass - else: - g.node(name=self.ref_in_g, shape='none', - margin='0', label=self.to_str()) + def add_to_graph(self, g): + if self.empty: + pass + else: + g.node(name=self.ref_in_g, shape='none', + margin='0', label=self.to_str()) class Table: - def __init__(self, seqs_by_rowtype, ref_in_g, rowtype_order=[]): - self.ref_in_g = ref_in_g - # consistent presentation, and v useful for feedforward clarity - self.rows = [] - self.seq_index = {} - if len(rowtype_order) > 1: - self.add_rowtype_cell = True - else: - errnote = "table got multiple row types but no order for them" - assert len(seqs_by_rowtype.keys()) == 1, errnote - rowtype_order = list(seqs_by_rowtype.keys()) - self.add_rowtype_cell = not (rowtype_order[0] == RES) - self.note_res_dependencies = len(seqs_by_rowtype.get(RES, [])) > 1 - self.leading_metadata_offset = 1 + self.add_rowtype_cell - for rt in rowtype_order: - seqs = sorted(seqs_by_rowtype[rt], - key=lambda seq: seq.creation_order_id) - for i, seq in enumerate(seqs): - # each one appends to self.rows. - self.n = self.add_row(seq, rt) - # self.n stores length of a single row, they will all be the - # same, just easiest to get like this - # add_row has to happen one at a time b/c they care about - # length of self.rows at time of addition (to get ports right) - self.empty = len(self.rows) == 0 - if self.empty: - self.n = 0 - # (len(rowtype_order)==1 and rowtype_order[0]==QVAR) - self.transpose = False - # no need to twist Q, just making the table under anyway - # transpose affects the port accesses, but think about that later - - def to_str(self): - rows = self.rows if not self.transpose else list(zip(*self.rows)) - - def cells2row(cells): - return ''+''.join(cells)+'' - return '<' + ''.join(map(cells2row, rows)) \ - + '
>' - - def bottom_left_portstr(self): - return self.access_portstr(0, -1) - - def bottom_right_portstr(self): - return self.access_portstr(-1, -1) - - def top_left_portstr(self): - return self.access_portstr(0, 0) - - def top_right_portstr(self): - return self.access_portstr(-1, 0) - - def top_access_portstr(self, i_col, skip_meta=False): - return self.access_portstr(i_col, 0, skip_meta=skip_meta) - - def bottom_access_portstr(self, i_col, skip_meta=False): - return self.access_portstr(i_col, -1, skip_meta=skip_meta) - - def access_portstr(self, i_col, i_row, skip_meta=False): - return self.ref_in_g + ":" + self.internal_portstr(i_col, i_row, - skip_meta=skip_meta) - - def internal_portstr(self, i_col, i_row, skip_meta=False): - if skip_meta and (i_col >= 0): # before flip things for reverse column - # access - i_col += self.leading_metadata_offset - if i_col < 0: - i_col = (self.n) + i_col - if i_row < 0: - i_row = len(self.rows) + i_row - return "_col"+str(i_col)+"_row"+str(i_row) - - def add_row(self, seq, row_type): - def add_cell(val, colour): - res = CellVals(val, colour, -1, -1, - known_portstr=self.internal_portstr(len(cells), - len(self.rows))) - cells.append(str(res)) - - def add_strong_line(): - # after failing to inject css styles in graphviz, - # seeing that their suggestion only creates lines - # (if at all? unclear) of width 1 - # (same as the border already there) and it wont make multiple VRs, - # and realising their suggestion also does nothing, - # refer to hack at the top of this priceless page: - # http://jkorpela.fi/html/cellborder.html - cells.append('') - - qkvr_colour, name_colour, data_colour = colour_scheme(row_type) - cells = [] # has to be created in advance, and not just be all the - # results of add_cell, because add_cell cares about current length of - # 'cells' - if self.add_rowtype_cell: - add_cell(ROW_NAMES[row_type], qkvr_colour) - add_cell(seq.name, name_colour) - for v in seq.last_val(): - add_cell(v, data_colour) - if self.note_res_dependencies: - self.seq_index[seq] = len(self.rows) - add_strong_line() - add_cell("("+str(self.seq_index[seq])+")", indices_colour) - add_cell(self.dependencies_str(seq, row_type), comment_colour) - self.rows.append(cells) - return len(cells) - - def dependencies_str(self, seq, row_type): - if not row_type == RES: - return "" - return "from ("+", ".join(str(self.seq_index[m]) for m in - seq.get_nonminor_parent_sequences()) + ")" - - def add_to_graph(self, g): - if self.empty: - # g.node(name=self.ref_in_g,label="empty table") - pass - else: - g.node(name=self.ref_in_g, shape='none', - margin='0', label=self.to_str()) + def __init__(self, seqs_by_rowtype, ref_in_g, rowtype_order=[]): + self.ref_in_g = ref_in_g + # consistent presentation, and v useful for feedforward clarity + self.rows = [] + self.seq_index = {} + if len(rowtype_order) > 1: + self.add_rowtype_cell = True + else: + errnote = "table got multiple row types but no order for them" + assert len(seqs_by_rowtype.keys()) == 1, errnote + rowtype_order = list(seqs_by_rowtype.keys()) + self.add_rowtype_cell = not (rowtype_order[0] == RES) + self.note_res_dependencies = len(seqs_by_rowtype.get(RES, [])) > 1 + self.leading_metadata_offset = 1 + self.add_rowtype_cell + for rt in rowtype_order: + seqs = sorted(seqs_by_rowtype[rt], + key=lambda seq: seq.creation_order_id) + for i, seq in enumerate(seqs): + # each one appends to self.rows. 
+ self.n = self.add_row(seq, rt) + # self.n stores length of a single row, they will all be the + # same, just easiest to get like this + # add_row has to happen one at a time b/c they care about + # length of self.rows at time of addition (to get ports right) + self.empty = len(self.rows) == 0 + if self.empty: + self.n = 0 + # (len(rowtype_order)==1 and rowtype_order[0]==QVAR) + self.transpose = False + # no need to twist Q, just making the table under anyway + # transpose affects the port accesses, but think about that later + + def to_str(self): + rows = self.rows if not self.transpose else list(zip(*self.rows)) + + def cells2row(cells): + return ''+''.join(cells)+'' + return '<' + ''.join(map(cells2row, rows)) \ + + '
>' + + def bottom_left_portstr(self): + return self.access_portstr(0, -1) + + def bottom_right_portstr(self): + return self.access_portstr(-1, -1) + + def top_left_portstr(self): + return self.access_portstr(0, 0) + + def top_right_portstr(self): + return self.access_portstr(-1, 0) + + def top_access_portstr(self, i_col, skip_meta=False): + return self.access_portstr(i_col, 0, skip_meta=skip_meta) + + def bottom_access_portstr(self, i_col, skip_meta=False): + return self.access_portstr(i_col, -1, skip_meta=skip_meta) + + def access_portstr(self, i_col, i_row, skip_meta=False): + return self.ref_in_g + ":" + self.internal_portstr(i_col, i_row, + skip_meta=skip_meta) + + def internal_portstr(self, i_col, i_row, skip_meta=False): + if skip_meta and (i_col >= 0): # before flip things for reverse column + # access + i_col += self.leading_metadata_offset + if i_col < 0: + i_col = (self.n) + i_col + if i_row < 0: + i_row = len(self.rows) + i_row + return "_col"+str(i_col)+"_row"+str(i_row) + + def add_row(self, seq, row_type): + def add_cell(val, colour): + res = CellVals(val, colour, -1, -1, + known_portstr=self.internal_portstr(len(cells), + len(self.rows))) + cells.append(str(res)) + + def add_strong_line(): + # after failing to inject css styles in graphviz, + # seeing that their suggestion only creates lines + # (if at all? unclear) of width 1 + # (same as the border already there) and it wont make multiple VRs, + # and realising their suggestion also does nothing, + # refer to hack at the top of this priceless page: + # http://jkorpela.fi/html/cellborder.html + cells.append('') + + qkvr_colour, name_colour, data_colour = colour_scheme(row_type) + cells = [] # has to be created in advance, and not just be all the + # results of add_cell, because add_cell cares about current length of + # 'cells' + if self.add_rowtype_cell: + add_cell(ROW_NAMES[row_type], qkvr_colour) + add_cell(seq.name, name_colour) + for v in seq.last_val(): + add_cell(v, data_colour) + if self.note_res_dependencies: + self.seq_index[seq] = len(self.rows) + add_strong_line() + add_cell("("+str(self.seq_index[seq])+")", indices_colour) + add_cell(self.dependencies_str(seq, row_type), comment_colour) + self.rows.append(cells) + return len(cells) + + def dependencies_str(self, seq, row_type): + if not row_type == RES: + return "" + return "from ("+", ".join(str(self.seq_index[m]) for m in + seq.get_nonminor_parent_sequences()) + ")" + + def add_to_graph(self, g): + if self.empty: + # g.node(name=self.ref_in_g,label="empty table") + pass + else: + g.node(name=self.ref_in_g, shape='none', + margin='0', label=self.to_str()) def place_above(g, node1, node2): - g.edge(node1.bottom_left_portstr(), node2.top_left_portstr(), - style="invis") - g.edge(node1.bottom_right_portstr(), - node2.top_right_portstr(), style="invis") + g.edge(node1.bottom_left_portstr(), node2.top_left_portstr(), + style="invis") + g.edge(node1.bottom_right_portstr(), + node2.top_right_portstr(), style="invis") def connect(g, top_table, bottom_table, select_vals): - # connects top_table as k and bottom_table as q - if top_table.empty or bottom_table.empty: - return # not doing this for now - place_above(g, top_table, bottom_table) - # just to position them one on top of the other, even if select is empty - for q_i in select_vals: - for k_i, b in enumerate(select_vals[q_i]): - if b: - # have to add 2 cause first 2 are data type and row name - g.edge(top_table.bottom_access_portstr(k_i, skip_meta=True), - bottom_table.top_access_portstr(q_i, skip_meta=True), - 
arrowhead='none') + # connects top_table as k and bottom_table as q + if top_table.empty or bottom_table.empty: + return # not doing this for now + place_above(g, top_table, bottom_table) + # just to position them one on top of the other, even if select is empty + for q_i in select_vals: + for k_i, b in enumerate(select_vals[q_i]): + if b: + # have to add 2 cause first 2 are data type and row name + g.edge(top_table.bottom_access_portstr(k_i, skip_meta=True), + bottom_table.top_access_portstr(q_i, skip_meta=True), + arrowhead='none') class SubHead: - def __init__(self, name, seq): - vvars = seq.get_immediate_parent_sequences() - if not seq.definitely_uses_identity_function: - vreal = seq.pre_aggregate_comp() - vreal(seq.last_w) # run it on same w to fill with right results - vreals = [vreal] - else: - vreals = [] - - self.name = name - self.vvars_table = Table( - {VVAR: vvars, VREAL: vreals}, self.name+"_vvars", - rowtype_order=[VVAR, VREAL]) - self.res_table = Table({RES: [seq]}, self.name+"_res") - self.default = "default: " + \ - str(seq.default) if seq.default is not None else "" - # self.vreals_table = ## ? add partly processed vals, useful for eg - # conditioned_contains? - - def add_to_graph(self, g): - self.vvars_table.add_to_graph(g) - self.res_table.add_to_graph(g) - if self.default: - g.node(self.name+"_default", shape='rectangle', label=self.default) - g.edge(self.name+"_default", self.res_table.top_left_portstr(), - arrowhead='none') - - def add_edges(self, g, select_vals): - connect(g, self.vvars_table, self.res_table, select_vals) - - def bottom_left_portstr(self): - return self.res_table.bottom_left_portstr() - - def bottom_right_portstr(self): - return self.res_table.bottom_right_portstr() - - def top_left_portstr(self): - return self.vvars_table.top_left_portstr() - - def top_right_portstr(self): - return self.vvars_table.top_right_portstr() + def __init__(self, name, seq): + vvars = seq.get_immediate_parent_sequences() + if not seq.definitely_uses_identity_function: + vreal = seq.pre_aggregate_comp() + vreal(seq.last_w) # run it on same w to fill with right results + vreals = [vreal] + else: + vreals = [] + + self.name = name + self.vvars_table = Table( + {VVAR: vvars, VREAL: vreals}, self.name+"_vvars", + rowtype_order=[VVAR, VREAL]) + self.res_table = Table({RES: [seq]}, self.name+"_res") + self.default = "default: " + \ + str(seq.default) if seq.default is not None else "" + # self.vreals_table = ## ? add partly processed vals, useful for eg + # conditioned_contains? 
+ + def add_to_graph(self, g): + self.vvars_table.add_to_graph(g) + self.res_table.add_to_graph(g) + if self.default: + g.node(self.name+"_default", shape='rectangle', label=self.default) + g.edge(self.name+"_default", self.res_table.top_left_portstr(), + arrowhead='none') + + def add_edges(self, g, select_vals): + connect(g, self.vvars_table, self.res_table, select_vals) + + def bottom_left_portstr(self): + return self.res_table.bottom_left_portstr() + + def bottom_right_portstr(self): + return self.res_table.bottom_right_portstr() + + def top_left_portstr(self): + return self.vvars_table.top_left_portstr() + + def top_right_portstr(self): + return self.vvars_table.top_right_portstr() class Head: - def __init__(self, name, head_primitives, i): - self.name = name - self.i = i - self.head_primitives = head_primitives - select = self.head_primitives.select - q_vars, k_vars = select.q_vars, select.k_vars - q_vars = sorted(list(set(q_vars)), key=lambda a: a.creation_order_id) - k_vars = sorted(list(set(k_vars)), key=lambda a: a.creation_order_id) - self.kq_table = Table({QVAR: q_vars, KVAR: k_vars}, - self.name+"_qvars", rowtype_order=[KVAR, QVAR]) - # self.k_table = Table({KVAR:k_vars},self.name+"_kvars") - self.select_result_table = makeQKStable( - q_vars, k_vars, select, self.name+"_select") - # self.select_table = SelectTable(self.head_primitives.select, - # self.name+"_select") - self.subheads = [SubHead(self.name+"_subcomp_"+str(i), seq) - for i, seq in - enumerate(self.head_primitives.sequences)] - - def add_to_graph(self, g): - with g.subgraph(name=self.name) as head: - def headlabel(): - # return self.head_primitives.select.name - return 'head '+str(self.i) +\ - "\n("+self.head_primitives.select.name+")" - head.attr(fillcolor=head_color, label=headlabel(), - fontcolor='black', style='filled') - with head.subgraph(name=self.name+"_select_parts") as sel: - sel.attr(rankdir="LR", label="", style="invis", rank="same") - if True: # not (self.kq_table.empty): - self.select_result_table.add_to_graph(sel) - self.kq_table.add_to_graph(sel) - # sel.edge(self.kq_table.bottom_right_portstr(), - # self.select_result_table.bottom_left_portstr(),style="invis") - - [s.add_to_graph(head) for s in self.subheads] - - def add_organising_edges(self, g): - if self.kq_table.empty: - return - for s in self.subheads: - place_above(g, self.select_result_table, s) - - def bottom_left_portstr(self): - return self.subheads[0].bottom_left_portstr() - - def bottom_right_portstr(self): - return self.subheads[-1].bottom_right_portstr() - - def top_left_portstr(self): - if not (self.kq_table.empty): - return self.kq_table.top_left_portstr() - else: # no kq (and so no select either) table. 
go into subheads - return self.subheads[0].top_left_portstr() - - def top_right_portstr(self): - if not (self.kq_table.empty): - return self.kq_table.top_right_portstr() - else: - return self.subheads[-1].top_right_portstr() - - def add_edges(self, g): - select_vals = self.head_primitives.select.last_val() - # connect(g,self.k_table,self.q_table,select_vals) - for s in self.subheads: - s.add_edges(g, select_vals) - self.add_organising_edges(g) + def __init__(self, name, head_primitives, i): + self.name = name + self.i = i + self.head_primitives = head_primitives + select = self.head_primitives.select + q_vars, k_vars = select.q_vars, select.k_vars + q_vars = sorted(list(set(q_vars)), key=lambda a: a.creation_order_id) + k_vars = sorted(list(set(k_vars)), key=lambda a: a.creation_order_id) + self.kq_table = Table({QVAR: q_vars, KVAR: k_vars}, + self.name+"_qvars", rowtype_order=[KVAR, QVAR]) + # self.k_table = Table({KVAR:k_vars},self.name+"_kvars") + self.select_result_table = makeQKStable( + q_vars, k_vars, select, self.name+"_select") + # self.select_table = SelectTable(self.head_primitives.select, + # self.name+"_select") + self.subheads = [SubHead(self.name+"_subcomp_"+str(i), seq) + for i, seq in + enumerate(self.head_primitives.sequences)] + + def add_to_graph(self, g): + with g.subgraph(name=self.name) as head: + def headlabel(): + # return self.head_primitives.select.name + return 'head '+str(self.i) +\ + "\n("+self.head_primitives.select.name+")" + head.attr(fillcolor=head_color, label=headlabel(), + fontcolor='black', style='filled') + with head.subgraph(name=self.name+"_select_parts") as sel: + sel.attr(rankdir="LR", label="", style="invis", rank="same") + if True: # not (self.kq_table.empty): + self.select_result_table.add_to_graph(sel) + self.kq_table.add_to_graph(sel) + # sel.edge(self.kq_table.bottom_right_portstr(), + # self.select_result_table.bottom_left_portstr(),style="invis") + + [s.add_to_graph(head) for s in self.subheads] + + def add_organising_edges(self, g): + if self.kq_table.empty: + return + for s in self.subheads: + place_above(g, self.select_result_table, s) + + def bottom_left_portstr(self): + return self.subheads[0].bottom_left_portstr() + + def bottom_right_portstr(self): + return self.subheads[-1].bottom_right_portstr() + + def top_left_portstr(self): + if not (self.kq_table.empty): + return self.kq_table.top_left_portstr() + else: # no kq (and so no select either) table. go into subheads + return self.subheads[0].top_left_portstr() + + def top_right_portstr(self): + if not (self.kq_table.empty): + return self.kq_table.top_right_portstr() + else: + return self.subheads[-1].top_right_portstr() + + def add_edges(self, g): + select_vals = self.head_primitives.select.last_val() + # connect(g,self.k_table,self.q_table,select_vals) + for s in self.subheads: + s.add_edges(g, select_vals) + self.add_organising_edges(g) def contains_tokens(mvs): - return next((True for mv in mvs if guarded_contains(base_tokens, mv)), - False) + return next((True for mv in mvs if guarded_contains(base_tokens, mv)), + False) def just_base_sequence_fix(d_ffs, ff_parents): - # when there are no parents and only one ff, then we are actually just - # looking at the indices/tokens by themselves. 
in this case, putting that - # ff in as a parent (with no child) makes the layer draw it properly - if not ff_parents and len(d_ffs) == 1: - return ff_parents, d_ffs - return d_ffs, ff_parents + # when there are no parents and only one ff, then we are actually just + # looking at the indices/tokens by themselves. in this case, putting that + # ff in as a parent (with no child) makes the layer draw it properly + if not ff_parents and len(d_ffs) == 1: + return ff_parents, d_ffs + return d_ffs, ff_parents class Layer: - def __init__(self, depth, d_heads, d_ffs, add_tokens_on_ff=False): - self.heads = [] - self.depth = depth - self.name = self.layer_cluster_name(depth) - for i, h in enumerate(d_heads): - self.heads.append(Head(self.name+"_head"+str(i), h, i)) - ff_parents = [] - for ff in d_ffs: - ff_parents += ff.get_nonminor_parent_sequences() - ff_parents = list(set(ff_parents)) - ff_parents = [p for p in ff_parents if not guarded_contains(d_ffs, p)] - d_ffs, ff_parents = just_base_sequence_fix(d_ffs, ff_parents) - rows_by_type = {RES: d_ffs, VVAR: ff_parents} - rowtype_order = [VVAR, RES] - if add_tokens_on_ff and not contains_tokens(ff_parents): - rows_by_type[INPUT] = [tokens_asis] - rowtype_order = [INPUT] + rowtype_order - self.ff_table = Table(rows_by_type, self.name+"_ffs", rowtype_order) - - def bottom_object(self): - if not self.ff_table.empty: - return self.ff_table - else: - return self.heads[-1] - - def top_object(self): - if self.heads: - return self.heads[0] - else: - return self.ff_table - - def bottom_left_portstr(self): - return self.bottom_object().bottom_left_portstr() - - def bottom_right_portstr(self): - return self.bottom_object().bottom_right_portstr() - - def top_left_portstr(self): - return self.top_object().top_left_portstr() - - def top_right_portstr(self): - return self.top_object().top_right_portstr() - - def add_to_graph(self, g): - with g.subgraph(name=self.name) as lg: - lg.attr(fillcolor=layer_color, label='layer '+str(self.depth), - fontcolor='black', style='filled') - for h in self.heads: - h.add_to_graph(lg) - self.ff_table.add_to_graph(lg) - - def add_organising_edges(self, g): - if self.ff_table.empty: - return - for h in self.heads: - place_above(g, h, self.ff_table) - - def add_edges(self, g): - for h in self.heads: - h.add_edges(g) - self.add_organising_edges(g) - - def layer_cluster_name(self, depth): - return 'cluster_l'+str(depth) # graphviz needs - # cluster names to start with 'cluster' + def __init__(self, depth, d_heads, d_ffs, add_tokens_on_ff=False): + self.heads = [] + self.depth = depth + self.name = self.layer_cluster_name(depth) + for i, h in enumerate(d_heads): + self.heads.append(Head(self.name+"_head"+str(i), h, i)) + ff_parents = [] + for ff in d_ffs: + ff_parents += ff.get_nonminor_parent_sequences() + ff_parents = list(set(ff_parents)) + ff_parents = [p for p in ff_parents if not guarded_contains(d_ffs, p)] + d_ffs, ff_parents = just_base_sequence_fix(d_ffs, ff_parents) + rows_by_type = {RES: d_ffs, VVAR: ff_parents} + rowtype_order = [VVAR, RES] + if add_tokens_on_ff and not contains_tokens(ff_parents): + rows_by_type[INPUT] = [tokens_asis] + rowtype_order = [INPUT] + rowtype_order + self.ff_table = Table(rows_by_type, self.name+"_ffs", rowtype_order) + + def bottom_object(self): + if not self.ff_table.empty: + return self.ff_table + else: + return self.heads[-1] + + def top_object(self): + if self.heads: + return self.heads[0] + else: + return self.ff_table + + def bottom_left_portstr(self): + return 
self.bottom_object().bottom_left_portstr() + + def bottom_right_portstr(self): + return self.bottom_object().bottom_right_portstr() + + def top_left_portstr(self): + return self.top_object().top_left_portstr() + + def top_right_portstr(self): + return self.top_object().top_right_portstr() + + def add_to_graph(self, g): + with g.subgraph(name=self.name) as lg: + lg.attr(fillcolor=layer_color, label='layer ' + str(self.depth), + fontcolor='black', style='filled') + for h in self.heads: + h.add_to_graph(lg) + self.ff_table.add_to_graph(lg) + + def add_organising_edges(self, g): + if self.ff_table.empty: + return + for h in self.heads: + place_above(g, h, self.ff_table) + + def add_edges(self, g): + for h in self.heads: + h.add_edges(g) + self.add_organising_edges(g) + + def layer_cluster_name(self, depth): + return 'cluster_l'+str(depth) # graphviz needs + # cluster names to start with 'cluster' class CompFlow: - def __init__(self, all_heads, all_ffs, force_vertical_layers, - add_tokens_on_ff=False): - self.force_vertical_layers = force_vertical_layers - self.add_tokens_on_ff = add_tokens_on_ff - self.make_all_layers(all_heads, all_ffs) - - def make_all_layers(self, all_heads, all_ffs): - self.layers = [] - ff_depths = [seq.scheduled_comp_depth for seq in all_ffs] - head_depths = [h.comp_depth for h in all_heads] - depths = sorted(list(set(ff_depths+head_depths))) - for d in depths: - d_heads = [h for h in all_heads if h.comp_depth == d] - d_heads = sorted(d_heads, key=lambda h: h.select.creation_order_id) - # only important for determinism to help debug - d_ffs = [f for f in all_ffs if f.scheduled_comp_depth == d] - self.layers.append(Layer(d, d_heads, d_ffs, self.add_tokens_on_ff)) - - def add_all_layers(self, g): - [layer.add_to_graph(g) for layer in self.layers] - - def add_organising_edges(self, g): - if self.force_vertical_layers: - for l1, l2 in zip(self.layers, self.layers[1:]): - place_above(g, l1, l2) - - def add_edges(self, g): - self.add_organising_edges(g) - [layer.add_edges(g) for layer in self.layers] + def __init__(self, all_heads, all_ffs, force_vertical_layers, + add_tokens_on_ff=False): + self.force_vertical_layers = force_vertical_layers + self.add_tokens_on_ff = add_tokens_on_ff + self.make_all_layers(all_heads, all_ffs) + + def make_all_layers(self, all_heads, all_ffs): + self.layers = [] + ff_depths = [seq.scheduled_comp_depth for seq in all_ffs] + head_depths = [h.comp_depth for h in all_heads] + depths = sorted(list(set(ff_depths+head_depths))) + for d in depths: + d_heads = [h for h in all_heads if h.comp_depth == d] + d_heads = sorted(d_heads, key=lambda h: h.select.creation_order_id) + # only important for determinism to help debug + d_ffs = [f for f in all_ffs if f.scheduled_comp_depth == d] + self.layers.append(Layer(d, d_heads, d_ffs, self.add_tokens_on_ff)) + + def add_all_layers(self, g): + [layer.add_to_graph(g) for layer in self.layers] + + def add_organising_edges(self, g): + if self.force_vertical_layers: + for l1, l2 in zip(self.layers, self.layers[1:]): + place_above(g, l1, l2) + + def add_edges(self, g): + self.add_organising_edges(g) + [layer.add_edges(g) for layer in self.layers] @UnfinishedFunc def draw_comp_flow(self, w, filename=None, - keep_dot=False, show=True, - force_vertical_layers=True, add_tokens_on_ff=False): - if w is not None: - self.call(w) # execute seq (and all its ancestors) on the given input - if not self.last_w == w: - print("evaluating input failed") - return - else: # if w == None, assume seq has already been executed on some 
input. - w = self.last_w - if None is filename: - name = self.name - filename = os.path.join("comp_flows", windows_path_cleaner( - name+"("+(str(w) if not isinstance(w, str) else "\""+w+"\"")+")")) - self.mark_all_minor_ancestors() - self.make_display_names_for_all_parents(skip_minors=True) - - all_heads, all_ffs = self.get_all_ancestor_heads_and_ffs( - remove_minors=True) - # this scheduling also marks the analysis parent selects - compflow = CompFlow(all_heads, all_ffs, - force_vertical_layers=force_vertical_layers, - add_tokens_on_ff=add_tokens_on_ff) - - # only import graphviz *inside* this function - - # that way RASP can run even if graphviz setup fails - # (though it will not be able to draw computation flows without it) - from graphviz import Digraph - g = Digraph('g') - # with curved lines it fusses over separating score edges - # and makes weirdly curved ones that start overlapping with the sequences - # :( - g.attr(splines='polyline') - compflow.add_all_layers(g) - compflow.add_edges(g) - g.render(filename=filename) - if show: - g.view() - if not keep_dot: - os.remove(filename) + keep_dot=False, show=True, + force_vertical_layers=True, add_tokens_on_ff=False): + if w is not None: + self.call(w) # execute seq (and all its ancestors) on the given input + if not self.last_w == w: + print("evaluating input failed") + return + else: # if w == None, assume seq has already been executed on some input. + w = self.last_w + if None is filename: + name = self.name + filename = os.path.join("comp_flows", windows_path_cleaner( + name+"("+(str(w) if not isinstance(w, str) else "\""+w+"\"")+")")) + self.mark_all_minor_ancestors() + self.make_display_names_for_all_parents(skip_minors=True) + + all_heads, all_ffs = self.get_all_ancestor_heads_and_ffs( + remove_minors=True) + # this scheduling also marks the analysis parent selects + compflow = CompFlow(all_heads, all_ffs, + force_vertical_layers=force_vertical_layers, + add_tokens_on_ff=add_tokens_on_ff) + + # only import graphviz *inside* this function - + # that way RASP can run even if graphviz setup fails + # (though it will not be able to draw computation flows without it) + from graphviz import Digraph + g = Digraph('g') + # with curved lines it fusses over separating score edges + # and makes weirdly curved ones that start overlapping with the sequences + # :( + g.attr(splines='polyline') + compflow.add_all_layers(g) + compflow.add_edges(g) + g.render(filename=filename) + if show: + g.view() + if not keep_dot: + os.remove(filename) dummyimport = None diff --git a/RASP_support/Environment.py b/RASP_support/Environment.py index 0dddc1e..e8896a8 100644 --- a/RASP_support/Environment.py +++ b/RASP_support/Environment.py @@ -1,97 +1,97 @@ from .FunctionalSupport import Unfinished, RASPTypeError, tokens_asis, \ - tokens_str, tokens_int, tokens_bool, tokens_float, indices + tokens_str, tokens_int, tokens_bool, tokens_float, indices from .Evaluator import RASPFunction class UndefinedVariable(Exception): - def __init__(self, varname): - super().__init__("Error: Undefined variable: "+varname) + def __init__(self, varname): + super().__init__("Error: Undefined variable: "+varname) class ReservedName(Exception): - def __init__(self, varname): - super().__init__("Error: Cannot set reserved name: "+varname) + def __init__(self, varname): + super().__init__("Error: Cannot set reserved name: "+varname) class Environment: - def __init__(self, parent_env=None, name=None, stealing_env=None): - self.variables = {} - self.name = name - self.parent_env = 
parent_env - self.stealing_env = stealing_env - self.base_setup() # nested envs can have them too. makes life simpler, - # instead of checking if they have the constant_variables etc in get. - # bit heavier on memory but no one's going to use this language for big - # nested stuff anyway - self.storing_in_constants = False + def __init__(self, parent_env=None, name=None, stealing_env=None): + self.variables = {} + self.name = name + self.parent_env = parent_env + self.stealing_env = stealing_env + self.base_setup() # nested envs can have them too. makes life simpler, + # instead of checking if they have the constant_variables etc in get. + # bit heavier on memory but no one's going to use this language for big + # nested stuff anyway + self.storing_in_constants = False - def base_setup(self): - self.constant_variables = {"tokens_asis": tokens_asis, - "tokens_str": tokens_str, - "tokens_int": tokens_int, - "tokens_bool": tokens_bool, - "tokens_float": tokens_float, - "indices": indices, - "True": True, - "False": False} - self.reserved_words = ["if", "else", "not", "and", "or", "out", "def", - "return", "range", "for", "in", "zip", "len", - "get"] + list(self.constant_variables.keys()) + def base_setup(self): + self.constant_variables = {"tokens_asis": tokens_asis, + "tokens_str": tokens_str, + "tokens_int": tokens_int, + "tokens_bool": tokens_bool, + "tokens_float": tokens_float, + "indices": indices, + "True": True, + "False": False} + self.reserved_words = ["if", "else", "not", "and", "or", "out", "def", + "return", "range", "for", "in", "zip", "len", + "get"] + list(self.constant_variables.keys()) - def snapshot(self): - res = Environment(parent_env=self.parent_env, - name=self.name, stealing_env=self.stealing_env) + def snapshot(self): + res = Environment(parent_env=self.parent_env, + name=self.name, stealing_env=self.stealing_env) - def carefulcopy(val): - if isinstance(val, Unfinished) or isinstance(val, RASPFunction): - return val # non mutable, at least not through rasp commands - elif isinstance(val, float) or isinstance(val, int) \ - or isinstance(val, str) or isinstance(val, bool): - return val # non mutable - elif isinstance(val, list): - return [carefulcopy(v) for v in val] - else: - raise RASPTypeError("environment contains element that is not " - + "unfinished, rasp function, float, int," - + "string, bool, or list? :", val) - res.constant_variables = {d: carefulcopy( - self.constant_variables[d]) for d in self.constant_variables} - res.variables = {d: carefulcopy( - self.variables[d]) for d in self.variables} - return res + def carefulcopy(val): + if isinstance(val, Unfinished) or isinstance(val, RASPFunction): + return val # non mutable, at least not through rasp commands + elif (isinstance(val, float) or isinstance(val, int) or + isinstance(val, str) or isinstance(val, bool)): + return val # non mutable + elif isinstance(val, list): + return [carefulcopy(v) for v in val] + else: + raise RASPTypeError("environment contains element that is " + + "not unfinished, rasp function, float, " + + "int, string, bool, or list? 
:", val) + res.constant_variables = {d: carefulcopy( + self.constant_variables[d]) for d in self.constant_variables} + res.variables = {d: carefulcopy( + self.variables[d]) for d in self.variables} + return res - def make_nested(self, names_vars=[]): - res = Environment(self, name=str(self.name)+"'") - for n, v in names_vars: - res.set_variable(n, v) - return res + def make_nested(self, names_vars=[]): + res = Environment(self, name=str(self.name) + "'") + for n, v in names_vars: + res.set_variable(n, v) + return res - def get_variable(self, name): - if name in self.constant_variables: - return self.constant_variables[name] - if name in self.variables: - return self.variables[name] - if self.parent_env is not None: - return self.parent_env.get_variable(name) - raise UndefinedVariable(name) + def get_variable(self, name): + if name in self.constant_variables: + return self.constant_variables[name] + if name in self.variables: + return self.variables[name] + if self.parent_env is not None: + return self.parent_env.get_variable(name) + raise UndefinedVariable(name) - def _set_checked_variable(self, name, val): - if self.storing_in_constants: - self.constant_variables[name] = val - self.reserved_words.append(name) - else: - self.variables[name] = val + def _set_checked_variable(self, name, val): + if self.storing_in_constants: + self.constant_variables[name] = val + self.reserved_words.append(name) + else: + self.variables[name] = val - def set_variable(self, name, val): - if name in self.reserved_words: - raise ReservedName(name) + def set_variable(self, name, val): + if name in self.reserved_words: + raise ReservedName(name) - self._set_checked_variable(name, val) - if self.stealing_env is not None: - if name.startswith("_") or name == "out": # things we don't want - # to steal - return - self.stealing_env.set_variable(name, val) + self._set_checked_variable(name, val) + if self.stealing_env is not None: + if name.startswith("_") or name == "out": # things we don't want + # to steal + return + self.stealing_env.set_variable(name, val) - def set_out(self, val): - self.variables["out"] = val + def set_out(self, val): + self.variables["out"] = val diff --git a/RASP_support/Evaluator.py b/RASP_support/Evaluator.py index e92f0e2..d30f6d2 100644 --- a/RASP_support/Evaluator.py +++ b/RASP_support/Evaluator.py @@ -1,6 +1,6 @@ from .FunctionalSupport import select, zipmap, aggregate, \ - or_selects, and_selects, not_select, indices, \ - Unfinished, UnfinishedSequence, UnfinishedSelect + or_selects, and_selects, not_select, indices, \ + Unfinished, UnfinishedSequence, UnfinishedSelect from .Sugar import tplor, tpland, tplnot, toseq, full_s from .Support import RASPTypeError, RASPError from collections.abc import Iterable @@ -11,761 +11,766 @@ def strdesc(o, desc_cap=None): - if isinstance(o, Unfinished): - return o.name - if isinstance(o, list): - res = "["+", ".join([strdesc(v) for v in o])+"]" - if desc_cap is not None and len(res) > desc_cap: - return "(list)" - else: - return res - if isinstance(o, dict): - res = "{"+", ".join((strdesc(k)+": "+strdesc(o[k])) for k in o)+"}" - if desc_cap is not None and len(res) > desc_cap: - return "(dict)" - else: - return res - else: - if isinstance(o, str): - return "\""+o+"\"" - else: - return str(o) + if isinstance(o, Unfinished): + return o.name + if isinstance(o, list): + res = "[" + ", ".join([strdesc(v) for v in o]) + "]" + if desc_cap is not None and len(res) > desc_cap: + return "(list)" + else: + return res + if isinstance(o, dict): + res = "{" + \ + 
", ".join((strdesc(k) + ": " + strdesc(o[k])) for k in o) + "}" + if desc_cap is not None and len(res) > desc_cap: + return "(dict)" + else: + return res + else: + if isinstance(o, str): + return "\"" + o + "\"" + else: + return str(o) class RASPValueError(RASPError): - def __init__(self, *a): - super().__init__(*a) + def __init__(self, *a): + super().__init__(*a) DEBUG = False def debprint(*a, **kw): - if DEBUG: - print(*a, **kw) + if DEBUG: + print(*a, **kw) def ast_text(ast): # just so don't have to go remembering this somewhere - # consider seeing if can make it add spaces between the tokens when doing - # this tho - return ast.getText() + # consider seeing if can make it add spaces between the tokens when doing + # this tho + return ast.getText() def isatom(v): - # the legal atoms - return True in [isinstance(v, t) for t in [int, float, str, bool]] + # the legal atoms + return True in [isinstance(v, t) for t in [int, float, str, bool]] def name_general_type(v): - if isinstance(v, list): - return "list" - if isinstance(v, dict): - return "dict" - if isinstance(v, UnfinishedSequence): - return ENCODER_NAME - if isinstance(v, UnfinishedSelect): - return "selector" - if isinstance(v, RASPFunction): - return "function" - if isatom(v): - return "atom" - return "??" + if isinstance(v, list): + return "list" + if isinstance(v, dict): + return "dict" + if isinstance(v, UnfinishedSequence): + return ENCODER_NAME + if isinstance(v, UnfinishedSelect): + return "selector" + if isinstance(v, RASPFunction): + return "function" + if isatom(v): + return "atom" + return "??" class ArgsError(Exception): - def __init__(self, name, expected, got): - super().__init__("wrong number of args for "+name + - "- expected: "+str(expected)+", got: "+str(got)+".") + def __init__(self, name, expected, got): + super().__init__("wrong number of args for " + name + + "- expected: " + str(expected) + ", got: " + + str(got) + ".") class NamedVal: - def __init__(self, name, val): - self.name = name - self.val = val + def __init__(self, name, val): + self.name = name + self.val = val class NamedValList: - def __init__(self, namedvals): - self.nvs = namedvals + def __init__(self, namedvals): + self.nvs = namedvals class JustVal: - def __init__(self, val): - self.val = val + def __init__(self, val): + self.val = val class RASPFunction: - def __init__(self, name, enclosing_env, argnames, statement_trees, - returnexpr, creator_name): - self.name = name # just for debug purposes - self.enclosing_env = enclosing_env - self.argnames = argnames - self.statement_trees = statement_trees - self.returnexpr = returnexpr - self.creator = creator_name - - def __str__(self): - return self.creator + " function: " + self.name \ - + "(" + ", ".join(self.argnames) + ")" - - def call(self, *args): - top_eval = args[-1] - args = args[:-1] - # nesting, because function shouldn't affect the enclosing environment - env = self.enclosing_env.make_nested([]) - if not len(args) == len(self.argnames): - raise ArgsError(self.name, len(self.argnames), len(args)) - for n, v in zip(self.argnames, args): - env.set_variable(n, v) - evaluator = Evaluator(env, top_eval.repl) - for at in self.statement_trees: - evaluator.evaluate(at) - res = evaluator.evaluateExprsList(self.returnexpr) - return res[0] if len(res) == 1 else res + def __init__(self, name, enclosing_env, argnames, statement_trees, + returnexpr, creator_name): + self.name = name # just for debug purposes + self.enclosing_env = enclosing_env + self.argnames = argnames + self.statement_trees = 
statement_trees + self.returnexpr = returnexpr + self.creator = creator_name + + def __str__(self): + return self.creator + " function: " + self.name \ + + "(" + ", ".join(self.argnames) + ")" + + def call(self, *args): + top_eval = args[-1] + args = args[:-1] + # nesting, because function shouldn't affect the enclosing environment + env = self.enclosing_env.make_nested([]) + if not len(args) == len(self.argnames): + raise ArgsError(self.name, len(self.argnames), len(args)) + for n, v in zip(self.argnames, args): + env.set_variable(n, v) + evaluator = Evaluator(env, top_eval.repl) + for at in self.statement_trees: + evaluator.evaluate(at) + res = evaluator.evaluateExprsList(self.returnexpr) + return res[0] if len(res) == 1 else res class Evaluator: - def __init__(self, env, repl): - self.env = env - self.sequence_running_example = repl.sequence_running_example - self.backup_example = None - # allows evaluating something that maybe doesn't necessarily work with - # the main running example, but we just want to see what happens on - # it - e.g. so we can do draw(tokens_int+1,[1,2]) without error even - # while the main example is still "hello" - self.repl = repl - - def evaluate(self, ast): - if ast.expr(): - return self.evaluateExpr(ast.expr(), from_top=True) - if ast.assign(): - return self.assign(ast.assign()) - if ast.funcDef(): - return self.funcDef(ast.funcDef()) - if ast.draw(): - return self.draw(ast.draw()) - if ast.forLoop(): - return self.forLoop(ast.forLoop()) - if ast.loadFile(): - return self.repl.loadFile(ast.loadFile(), self.env) - - # more to come - raise NotImplementedError - - def draw(self, ast): - # TODO: make at least some rudimentary comparisons of selectors somehow - # to merge heads idk?????? maybe keep trace of operations used to - # create them and those with exact same parent s-ops and operations - # can get in? 
would still find eg select(0,0,==) and select(1,1,==) - # different, but its better than nothing at all - example = self.evaluateExpr( - ast.inputseq) if ast.inputseq else self.sequence_running_example - prev_backup = self.backup_example - self.backup_example = example - unf = self.evaluateExpr(ast.unf) - if not isinstance(unf, UnfinishedSequence): - raise RASPTypeError("draw expects unfinished sequence, got:", unf) - unf.draw_comp_flow(example) - res = unf.call(example) - res.created_from_input = example - self.backup_example = prev_backup - return JustVal(res) - - def assign(self, ast): - def set_val_and_name(val, name): - self.env.set_variable(name, val) - if isinstance(val, Unfinished): - val.setname(name) # completely irrelevant really for the REPL, - # but will help maintain sanity when printing computation flows - return NamedVal(name, val) - - varnames = self._names_list(ast.var) - values = self.evaluateExprsList(ast.val) - if len(values) == 1: - values = values[0] - - if len(varnames) == 1: - return set_val_and_name(values, varnames[0]) - else: - if not len(varnames) == len(values): - raise RASPTypeError("expected", len( - varnames), "values, but got:", len(values)) - reslist = [] - for v, name in zip(values, varnames): - reslist.append(set_val_and_name(v, name)) - return NamedValList(reslist) - - def _names_list(self, ast): - idsList = self._get_first_cont_list(ast) - return [i.text for i in idsList] - - def _set_iterator_and_vals(self, iterator_names, iterator_vals): - if len(iterator_names) == 1: - self.env.set_variable(iterator_names[0], iterator_vals) - elif isinstance(iterator_vals, Iterable) \ - and (len(iterator_vals) == len(iterator_names)): - for n, v in zip(iterator_names, iterator_vals): - self.env.set_variable(n, v) - else: - if not isinstance(iterator_vals, Iterable): - raise RASPTypeError( - "iterating with multiple iterator names, but got single" - + " iterator value:", iterator_vals) - else: - # should work out by logic of last failed elif - errnote = "something wrong with Evaluator logic" - assert not (len(iterator_vals) == len(iterator_names)), errnote - raise RASPTypeError("iterating with", len(iterator_names), - "names but got", len(iterator_vals), - "values (", iterator_vals, ")") - - def _evaluateDictComp(self, ast): - ast = ast.dictcomp - d = self.evaluateExpr(ast.iterable) - if not (isinstance(d, list) or isinstance(d, dict)): - raise RASPTypeError( - "dict comprehension should have got a list or dict to loop " - + "over, but got:", d) - res = {} - iterator_names = self._names_list(ast.iterator) - for vals in d: - orig_env = self.env - self.env = self.env.make_nested() - self._set_iterator_and_vals(iterator_names, vals) - key = self.make_dict_key(ast.key) - res[key] = self.evaluateExpr(ast.val) - self.env = orig_env - return res - - def _evaluateListComp(self, ast): - ast = ast.listcomp - ll = self.evaluateExpr(ast.iterable) - if not (isinstance(ll, list) or isinstance(ll, dict)): - raise RASPTypeError( - "list comprehension should have got a list or dict to loop " - + "over, but got:", ll) - res = [] - iterator_names = self._names_list(ast.iterator) - for vals in ll: - orig_env = self.env - self.env = self.env.make_nested() - # sets inside the now-nested env -don't want to keep - # the internal iterators after finishing this list comp - self._set_iterator_and_vals(iterator_names, vals) - res.append(self.evaluateExpr(ast.val)) - self.env = orig_env - return res - - def forLoop(self, ast): - iterator_names = self._names_list(ast.iterator) - iterable = 
self.evaluateExpr(ast.iterable) - if not (isinstance(iterable, list) or isinstance(iterable, dict)): - raise RASPTypeError( - "for loop needs to iterate over a list or dict, but got:", - iterable) - statements = self._get_first_cont_list(ast.mainbody) - for vals in iterable: - self._set_iterator_and_vals(iterator_names, vals) - for s in statements: - self.evaluate(s) - return JustVal(None) - - def _get_first_cont_list(self, ast): - res = [] - while ast: - if ast.first: - res.append(ast.first) - # sometimes there's no first cause it's just eating a comment - ast = ast.cont - return res - - def funcDef(self, ast): - funcname = ast.name.text - argname_trees = self._get_first_cont_list(ast.arguments) - argnames = [a.text for a in argname_trees] - statement_trees = self._get_first_cont_list(ast.mainbody) - returnexpr = ast.retstatement.res - res = RASPFunction(funcname, self.env, argnames, - statement_trees, returnexpr, self.env.name) - self.env.set_variable(funcname, res) - return NamedVal(funcname, res) - - def _evaluateUnaryExpr(self, ast): - uexpr = self.evaluateExpr(ast.uexpr) - uop = ast.uop.text - if uop == "not": - if isinstance(uexpr, UnfinishedSequence): - return tplnot(uexpr) - elif isinstance(uexpr, UnfinishedSelect): - return not_select(uexpr) - else: - return not uexpr - if uop == "-": - return -uexpr - if uop == "+": - return +uexpr - if uop == "round": - return round(uexpr) - if uop == "indicator": - if isinstance(uexpr, UnfinishedSequence): - name = "I("+uexpr.name+")" - zipmapped = zipmap(uexpr, lambda a: 1 if a else 0, name=name) - return zipmapped.allow_suppressing_display() - # naming res makes interpreter think it is important, i.e., - # must always be displayed. but here it has only been named for - # clarity, so correct it using .allow_suppressing_display() - - raise RASPTypeError( - "indicator operator expects "+ENCODER_NAME+", got:", uexpr) - raise NotImplementedError - - def _evaluateRange(self, ast): - valsList = self.evaluateExprsList(ast.rangevals) - if not len(valsList) in [1, 2, 3]: - raise RASPTypeError( - "wrong number of inputs to range, expected: 1, 2, or 3, got:", - len(valsList)) - for v in valsList: - if not isinstance(v, int): - raise RASPTypeError( - "range expects all integer inputs, but got:", - strdesc(valsList)) - return list(range(*valsList)) - - def _index_into_dict(self, d, index): - def invalid_key_error(i): - return RASPTypeError( - f"index into dict has to be {ENCODER_NAME} or atom" - + " (i.e., string, int, float, bool), got:", strdesc(i)) - - def missing_key_error(i): - return RASPValueError("index [", strdesc(i), "] not in dict.") - - dname, indexname = d.name, index.name - d, index = d.val, index.val - - if isinstance(index, UnfinishedSequence): - d = deepcopy(d) - def apply_d(i): - if i not in d: - raise missing_key_error(i) - return d[i] - name = f"{dname}[{indexname}]" - return zipmap((index,), apply_d, name=name) - elif not isatom(index): - raise invalid_key_error(index) - if index not in d: - raise missing_key_error(index) - else: - return d[index] - - def _index_into_list_or_str(self, ll, index): - lname, indexname = ll.name, index.name - ll, index = ll.val, index.val - ltype = "list" if isinstance(ll, list) else "string" - - def invalid_key_error(i): - return RASPTypeError(f"index into {ltype} has to be", - f"{ENCODER_NAME} or integer, got:", - strdesc(index)) - - def check_and_raise_key_error(i): - if i >= len(ll) or (-i) > len(ll): - raise RASPValueError("index", index, "out of range for", ltype, - "of length", len(ll)) - - if 
isinstance(index, UnfinishedSequence): - ll = deepcopy(ll) - def apply_l(i): - check_and_raise_key_error(i) - return ll[i] - name = f"{lname}[{indexname}]" - return zipmap((index,), apply_l, name=name) - elif not isinstance(index, int): - raise invalid_key_error(index) - check_and_raise_key_error(index) - return ll[index] - - def _index_into_sequence(self, s, index): - sname, indexname = s.name, index.name - s, index = s.val, index.val - if isinstance(index, int): - if index >= 0: - sel = select(toseq(index), indices, lambda q, - k: q == k, name="load from "+str(index)) - else: - length = self.env.get_variable("length") - real_index = length + index - real_index.setname(length.name+str(index)) - sel = select(real_index, indices, lambda q, - k: q == k, name="load from "+str(index)) - agg = aggregate(sel, s, name=s.name+"["+str(index)+"]") - return agg.allow_suppressing_display() - else: - raise RASPValueError( - "index into sequence has to be integer, got:", strdesc(index)) - - def _evaluateIndexing(self, ast): - indexable = self.evaluateExpr(ast.indexable, get_name=True) - index = self.evaluateExpr(ast.index, get_name=True) - - - if isinstance(indexable.val, list) or isinstance(indexable.val, str): - return self._index_into_list_or_str(indexable, index) - elif isinstance(indexable.val, dict): - return self._index_into_dict(indexable, index) - elif isinstance(indexable.val, UnfinishedSequence): - return self._index_into_sequence(indexable, index) - else: - raise RASPTypeError("can only index into a list, dict, string, or" - + " sequence, but instead got:", - strdesc(indexable.val)) - - def _evaluateSelectExpr(self, ast): - key = self.evaluateExpr(ast.key) - query = self.evaluateExpr(ast.query) - sop = ast.selop.text - key = toseq(key) # in case got an atom in one of these, - query = toseq(query) # e.g. 
selecting 0th index: indices @= 0 - if sop == "<": - return select(query, key, lambda q, k: q > k) - if sop == ">": - return select(query, key, lambda q, k: q < k) - if sop == "==": - return select(query, key, lambda q, k: q == k) - if sop == "!=": - return select(query, key, lambda q, k: not (q == k)) - if sop == "<=": - return select(query, key, lambda q, k: q >= k) - if sop == ">=": - return select(query, key, lambda q, k: q <= k) - - def _evaluateBinaryExpr(self, ast): - def has_sequence(left, right): - return isinstance(left, UnfinishedSequence) \ - or isinstance(right, UnfinishedSequence) - - def has_selector(left, right): - return isinstance(left, UnfinishedSelect) \ - or isinstance(right, UnfinishedSelect) - - def both_selectors(left, right): - return isinstance(left, UnfinishedSelect) \ - and isinstance(right, UnfinishedSelect) - left = self.evaluateExpr(ast.left) - right = self.evaluateExpr(ast.right) - bop = ast.bop.text - bad_pair = RASPTypeError( - "Cannot apply and/or between selector and non-selector") - if bop == "and": - if has_sequence(left, right): - if has_selector(left, right): - raise bad_pair - return tpland(left, right) - elif has_selector(left, right): - if not both_selectors(left, right): - raise bad_pair - return and_selects(left, right) - else: - return (left and right) - elif bop == "or": - if has_sequence(left, right): - if has_selector(left, right): - raise bad_pair - return tplor(left, right) - elif has_selector(left, right): - if not both_selectors(left, right): - raise bad_pair - return or_selects(left, right) - else: - return (left or right) - if has_selector(left, right): - raise RASPTypeError("Cannot apply", bop, "to selector(s)") - elif bop == "+": - return left + right - elif bop == "-": - return left - right - elif bop == "*": - return left * right - elif bop == "/": - return left/right - elif bop == "^": - return pow(left, right) - elif bop == '%': - return left % right - elif bop == "==": - return left == right - elif bop == "<=": - return left <= right - elif bop == ">=": - return left >= right - elif bop == "<": - return left < right - elif bop == ">": - return left > right - # more, like modulo and power and all the other operators, to come - raise NotImplementedError - - def _evaluateStandalone(self, ast): - if ast.anint: - return int(ast.anint.text) - if ast.afloat: - return float(ast.afloat.text) - if ast.astring: - return ast.astring.text[1:-1] - raise NotImplementedError - - def _evaluateTernaryExpr(self, ast): - cond = self.evaluateExpr(ast.cond) - if isinstance(cond, Unfinished): - res1 = self.evaluateExpr(ast.res1) - res2 = self.evaluateExpr(ast.res2) - cond, res1, res2 = tuple(map(toseq, (cond, res1, res2))) - return zipmap((cond, res1, res2), lambda c, r1, r2: r1 - if c else r2, name=res1.name+" if "+cond.name - + " else " + res2.name).allow_suppressing_display() - else: - return self.evaluateExpr(ast.res1) if cond \ - else self.evaluateExpr(ast.res2) - # lazy eval when cond is non-unfinished allows legal loops over - # actual atoms - - def _evaluateAggregateExpr(self, ast): - sel = self.evaluateExpr(ast.sel) - seq = self.evaluateExpr(ast.seq) - seq = toseq(seq) # just in case its an atom - default = self.evaluateExpr(ast.default) if ast.default else None - - if not isinstance(sel, UnfinishedSelect): - raise RASPTypeError("Expected selector, got:", strdesc(sel)) - if not isinstance(seq, UnfinishedSequence): - raise RASPTypeError("Expected sequence, got:", strdesc(seq)) - if isinstance(default, Unfinished): - raise 
RASPTypeError("Expected atom, got:", strdesc(default)) - return aggregate(sel, seq, default=default) - - def _evaluateZip(self, ast): - list_exps = self._get_first_cont_list(ast.lists) - lists = [self.evaluateExpr(e) for e in list_exps] - if not lists: - raise RASPTypeError("zip needs at least one list") - for i, l in enumerate(lists): - if not isinstance(l, list): - raise RASPTypeError( - "attempting to zip lists, but", i+1, - "-th element is not list:", strdesc(l)) - n = len(lists[0]) - for i, l in enumerate(lists): - if not len(l) == n: - raise RASPTypeError("attempting to zip lists of length", - n, ", but", i+1, "-th list has length", - len(l)) - # keep everything lists, no tuples/lists mixing here, all the same to - # rasp (no stuff like append etc) - return [list(v) for v in zip(*lists)] - - def make_dict_key(self, ast): - res = self.evaluateExpr(ast) - if not isatom(res): - raise RASPTypeError( - "dictionary keys can only be atoms, but instead got:", - strdesc(res)) - return res - - def _evaluateDict(self, ast): - named_exprs_list = self._get_first_cont_list(ast.dictContents) - return {self.make_dict_key(e.key): self.evaluateExpr(e.val) - for e in named_exprs_list} - - def _evaluateList(self, ast): - exprs_list = self._get_first_cont_list(ast.listContents) - return [self.evaluateExpr(e) for e in exprs_list] - - def _evaluateApplication(self, ast, unf): - input_vals = self._get_first_cont_list(ast.inputexprs) - if not len(input_vals) == 1: - raise ArgsError("evaluate unfinished", 1, len(input_vals)) - input_val = self.evaluateExpr(input_vals[0]) - if not isinstance(unf, Unfinished): - raise RASPTypeError("Applying unfinished expects to apply", - ENCODER_NAME, "or selector, got:", - strdesc(unf)) - if not isinstance(input_val, Iterable): - raise RASPTypeError( - "Applying unfinished expects iterable input, got:", - strdesc(input_val)) - res = unf.call(input_val) - res.created_from_input = input_val - return res - - def _evaluateRASPFunction(self, ast, raspfun): - args_trees = self._get_first_cont_list(ast.inputexprs) - args = tuple(self.evaluateExpr(t) for t in args_trees) + (self,) - real_args = args[:-1] - res = raspfun.call(*args) - if isinstance(res, Unfinished): - res.setname( - raspfun.name+"("+" , ".join(strdesc(a, desc_cap=20) - for a in real_args)+")") - return res - - def _evaluateContains(self, ast): - contained = self.evaluateExpr(ast.contained) - container = self.evaluateExpr(ast.container) - container_name = ast.container.var.text if ast.container.var \ - else str(container) - if isinstance(contained, UnfinishedSequence): - if not isinstance(container, list): - raise RASPTypeError("\"["+ENCODER_NAME+"] in X\" expects X to" - + "be list of atoms, but got non-list:", - strdesc(container)) - for v in container: - if not isatom(v): - raise RASPTypeError("\"["+ENCODER_NAME+"] in X\" expects X" - + "to be list of atoms, but got list " - + "with values:", strdesc(container)) - return zipmap(contained, lambda c: c in container, - name=contained.name + " in " - + container_name).allow_suppressing_display() - elif isatom(contained): # contained is now an atom - if isinstance(container, list): - return contained in container - elif isinstance(container, UnfinishedSequence): - indicator = zipmap(container, lambda v: int(v == contained)) - return aggregate(full_s, indicator) > 0 - else: - raise RASPTypeError( - "\"[atom] in X\" expects X to be list or " + ENCODER_NAME - + ", but got:", strdesc(container)) - if isinstance(contained, UnfinishedSelect) or isinstance(contained, - 
RASPFunction): - obj_name = "select" if isinstance( - contained, UnfinishedSelect) else "function" - raise RASPTypeError("don't check if", obj_name, - "is contained in list/dict: unless exact same " - + "instance, unable to check equivalence of", - obj_name + "s") - else: - raise RASPTypeError("\"A in X\" expects A to be", - ENCODER_NAME, "or atom, but got A:", - strdesc(contained)) - - def _evaluateLen(self, ast): - singleList = self.evaluateExpr(ast.singleList) - if not isinstance(singleList, list) or isinstance(singleList, dict): - raise RASPTypeError( - "attempting to compute length of non-list:", - strdesc(singleList)) - return len(singleList) - - def evaluateExprsList(self, ast): - exprsList = self._get_first_cont_list(ast) - return [self.evaluateExpr(v) for v in exprsList] - - def _test_res(self, res): - if isinstance(res, Unfinished): - def succeeds_with(exampe): - try: - res.call(example, just_pass_exception_up=True) - except Exception: - return False - else: - return True - succeeds_with_backup = (self.backup_example is not None) and \ - succeeds_with(self.backup_example) - if succeeds_with_backup: - return - succeeds_with_main = succeeds_with(self.sequence_running_example) - if succeeds_with_main: - return - example = self.sequence_running_example if self.backup_example \ - is None else self.backup_example - res.call(example, just_pass_exception_up=True) - - def evaluateExpr(self, ast, from_top=False, get_name=False): - def format_return(res, resname="out", - is_application_of_unfinished=False): - ast.evaled_value = res - # run a quick test of the result (by attempting to evaluate it on - # an example) to make sure there hasn't been some weird type - # problem, so it shouts even before someone actively tries to - # evaluate it - self._test_res(res) - - if is_application_of_unfinished: - return JustVal(res) - else: - self.env.set_out(res) - if from_top or get_name: - # this is when an expression has been evaled - return NamedVal(resname, res) - else: - return res - if ast.bracketed: # in parentheses - get out of them - return self.evaluateExpr(ast.bracketed, from_top=from_top) - if ast.var: # calling single variable - varname = ast.var.text - return format_return(self.env.get_variable(varname), - resname=varname) - if ast.standalone: - return format_return(self._evaluateStandalone(ast.standalone)) - if ast.bop: - return format_return(self._evaluateBinaryExpr(ast)) - if ast.uop: - return format_return(self._evaluateUnaryExpr(ast)) - if ast.cond: - return format_return(self._evaluateTernaryExpr(ast)) - if ast.aggregate: - return format_return(self._evaluateAggregateExpr(ast.aggregate)) - if ast.unfORfun: - - # before evaluating the unfORfun expression, - # consider that it may be an unf that would not work - # with the current running example, and allow that it may have - # been sent in with an example for which it will work - prev_backup = self.backup_example - input_vals = self._get_first_cont_list(ast.inputexprs) - if len(input_vals) == 1: - self.backup_example = self.evaluateExpr(input_vals[0]) - - unfORfun = self.evaluateExpr(ast.unfORfun) - - self.backup_example = prev_backup - - if isinstance(unfORfun, Unfinished): - return format_return(self._evaluateApplication(ast, unfORfun), - is_application_of_unfinished=True) - elif isinstance(unfORfun, RASPFunction): - return format_return(self._evaluateRASPFunction(ast, unfORfun)) - if ast.selop: - return format_return(self._evaluateSelectExpr(ast)) - if ast.aList(): - return format_return(self._evaluateList(ast.aList())) - if 
ast.aDict(): - return format_return(self._evaluateDict(ast.aDict())) - if ast.indexable: # indexing into a list, dict, or s-op - return format_return(self._evaluateIndexing(ast)) - if ast.rangevals: - return format_return(self._evaluateRange(ast)) - if ast.listcomp: - return format_return(self._evaluateListComp(ast)) - if ast.dictcomp: - return format_return(self._evaluateDictComp(ast)) - if ast.container: - return format_return(self._evaluateContains(ast)) - if ast.lists: - return format_return(self._evaluateZip(ast)) - if ast.singleList: - return format_return(self._evaluateLen(ast)) - raise NotImplementedError + def __init__(self, env, repl): + self.env = env + self.sequence_running_example = repl.sequence_running_example + self.backup_example = None + # allows evaluating something that maybe doesn't necessarily work with + # the main running example, but we just want to see what happens on + # it - e.g. so we can do draw(tokens_int+1,[1,2]) without error even + # while the main example is still "hello" + self.repl = repl + + def evaluate(self, ast): + if ast.expr(): + return self.evaluateExpr(ast.expr(), from_top=True) + if ast.assign(): + return self.assign(ast.assign()) + if ast.funcDef(): + return self.funcDef(ast.funcDef()) + if ast.draw(): + return self.draw(ast.draw()) + if ast.forLoop(): + return self.forLoop(ast.forLoop()) + if ast.loadFile(): + return self.repl.loadFile(ast.loadFile(), self.env) + + # more to come + raise NotImplementedError + + def draw(self, ast): + # TODO: make at least some rudimentary comparisons of selectors somehow + # to merge heads idk?????? maybe keep trace of operations used to + # create them and those with exact same parent s-ops and operations + # can get in? would still find eg select(0,0,==) and select(1,1,==) + # different, but its better than nothing at all + example = self.evaluateExpr( + ast.inputseq) if ast.inputseq else self.sequence_running_example + prev_backup = self.backup_example + self.backup_example = example + unf = self.evaluateExpr(ast.unf) + if not isinstance(unf, UnfinishedSequence): + raise RASPTypeError("draw expects unfinished sequence, got:", unf) + unf.draw_comp_flow(example) + res = unf.call(example) + res.created_from_input = example + self.backup_example = prev_backup + return JustVal(res) + + def assign(self, ast): + def set_val_and_name(val, name): + self.env.set_variable(name, val) + if isinstance(val, Unfinished): + val.setname(name) # completely irrelevant really for the REPL, + # but will help maintain sanity when printing computation flows + return NamedVal(name, val) + + varnames = self._names_list(ast.var) + values = self.evaluateExprsList(ast.val) + if len(values) == 1: + values = values[0] + + if len(varnames) == 1: + return set_val_and_name(values, varnames[0]) + else: + if not len(varnames) == len(values): + raise RASPTypeError("expected", len( + varnames), "values, but got:", len(values)) + reslist = [] + for v, name in zip(values, varnames): + reslist.append(set_val_and_name(v, name)) + return NamedValList(reslist) + + def _names_list(self, ast): + idsList = self._get_first_cont_list(ast) + return [i.text for i in idsList] + + def _set_iterator_and_vals(self, iterator_names, iterator_vals): + if len(iterator_names) == 1: + self.env.set_variable(iterator_names[0], iterator_vals) + elif (isinstance(iterator_vals, Iterable) and + (len(iterator_vals) == len(iterator_names))): + for n, v in zip(iterator_names, iterator_vals): + self.env.set_variable(n, v) + else: + if not isinstance(iterator_vals, Iterable): 
+ raise RASPTypeError( + "iterating with multiple iterator names, but got single" + + " iterator value:", iterator_vals) + else: + # should work out by logic of last failed elif + errnote = "something wrong with Evaluator logic" + assert not (len(iterator_vals) == len(iterator_names)), errnote + raise RASPTypeError("iterating with", len(iterator_names), + "names but got", len(iterator_vals), + "values (", iterator_vals, ")") + + def _evaluateDictComp(self, ast): + ast = ast.dictcomp + d = self.evaluateExpr(ast.iterable) + if not (isinstance(d, list) or isinstance(d, dict)): + raise RASPTypeError( + "dict comprehension should have got a list or dict to loop " + + "over, but got:", d) + res = {} + iterator_names = self._names_list(ast.iterator) + for vals in d: + orig_env = self.env + self.env = self.env.make_nested() + self._set_iterator_and_vals(iterator_names, vals) + key = self.make_dict_key(ast.key) + res[key] = self.evaluateExpr(ast.val) + self.env = orig_env + return res + + def _evaluateListComp(self, ast): + ast = ast.listcomp + ll = self.evaluateExpr(ast.iterable) + if not (isinstance(ll, list) or isinstance(ll, dict)): + raise RASPTypeError( + "list comprehension should have got a list or dict to loop " + + "over, but got:", ll) + res = [] + iterator_names = self._names_list(ast.iterator) + for vals in ll: + orig_env = self.env + self.env = self.env.make_nested() + # sets inside the now-nested env -don't want to keep + # the internal iterators after finishing this list comp + self._set_iterator_and_vals(iterator_names, vals) + res.append(self.evaluateExpr(ast.val)) + self.env = orig_env + return res + + def forLoop(self, ast): + iterator_names = self._names_list(ast.iterator) + iterable = self.evaluateExpr(ast.iterable) + if not (isinstance(iterable, list) or isinstance(iterable, dict)): + raise RASPTypeError( + "for loop needs to iterate over a list or dict, but got:", + iterable) + statements = self._get_first_cont_list(ast.mainbody) + for vals in iterable: + self._set_iterator_and_vals(iterator_names, vals) + for s in statements: + self.evaluate(s) + return JustVal(None) + + def _get_first_cont_list(self, ast): + res = [] + while ast: + if ast.first: + res.append(ast.first) + # sometimes there's no first cause it's just eating a comment + ast = ast.cont + return res + + def funcDef(self, ast): + funcname = ast.name.text + argname_trees = self._get_first_cont_list(ast.arguments) + argnames = [a.text for a in argname_trees] + statement_trees = self._get_first_cont_list(ast.mainbody) + returnexpr = ast.retstatement.res + res = RASPFunction(funcname, self.env, argnames, + statement_trees, returnexpr, self.env.name) + self.env.set_variable(funcname, res) + return NamedVal(funcname, res) + + def _evaluateUnaryExpr(self, ast): + uexpr = self.evaluateExpr(ast.uexpr) + uop = ast.uop.text + if uop == "not": + if isinstance(uexpr, UnfinishedSequence): + return tplnot(uexpr) + elif isinstance(uexpr, UnfinishedSelect): + return not_select(uexpr) + else: + return not uexpr + if uop == "-": + return -uexpr + if uop == "+": + return +uexpr + if uop == "round": + return round(uexpr) + if uop == "indicator": + if isinstance(uexpr, UnfinishedSequence): + name = "I("+uexpr.name+")" + zipmapped = zipmap(uexpr, lambda a: 1 if a else 0, name=name) + return zipmapped.allow_suppressing_display() + # naming res makes interpreter think it is important, i.e., + # must always be displayed. 
but here it has only been named for + # clarity, so correct it using .allow_suppressing_display() + + raise RASPTypeError( + "indicator operator expects " + ENCODER_NAME + ", got:", uexpr) + raise NotImplementedError + + def _evaluateRange(self, ast): + valsList = self.evaluateExprsList(ast.rangevals) + if not len(valsList) in [1, 2, 3]: + raise RASPTypeError( + "wrong number of inputs to range, expected: 1, 2, or 3, got:", + len(valsList)) + for v in valsList: + if not isinstance(v, int): + raise RASPTypeError( + "range expects all integer inputs, but got:", + strdesc(valsList)) + return list(range(*valsList)) + + def _index_into_dict(self, d, index): + def invalid_key_error(i): + return RASPTypeError( + f"index into dict has to be {ENCODER_NAME} or atom" + + " (i.e., string, int, float, bool), got:", strdesc(i)) + + def missing_key_error(i): + return RASPValueError("index [", strdesc(i), "] not in dict.") + + dname, indexname = d.name, index.name + d, index = d.val, index.val + + if isinstance(index, UnfinishedSequence): + d = deepcopy(d) + + def apply_d(i): + if i not in d: + raise missing_key_error(i) + return d[i] + + name = f"{dname}[{indexname}]" + return zipmap((index,), apply_d, name=name) + elif not isatom(index): + raise invalid_key_error(index) + if index not in d: + raise missing_key_error(index) + else: + return d[index] + + def _index_into_list_or_str(self, ll, index): + lname, indexname = ll.name, index.name + ll, index = ll.val, index.val + ltype = "list" if isinstance(ll, list) else "string" + + def invalid_key_error(i): + return RASPTypeError(f"index into {ltype} has to be", + f"{ENCODER_NAME} or integer, got:", + strdesc(index)) + + def check_and_raise_key_error(i): + if i >= len(ll) or (-i) > len(ll): + raise RASPValueError("index", index, "out of range for", ltype, + "of length", len(ll)) + + if isinstance(index, UnfinishedSequence): + ll = deepcopy(ll) + + def apply_l(i): + check_and_raise_key_error(i) + return ll[i] + + name = f"{lname}[{indexname}]" + return zipmap((index,), apply_l, name=name) + elif not isinstance(index, int): + raise invalid_key_error(index) + check_and_raise_key_error(index) + return ll[index] + + def _index_into_sequence(self, s, index): + sname, indexname = s.name, index.name + s, index = s.val, index.val + if isinstance(index, int): + if index >= 0: + sel = select(toseq(index), indices, lambda q, + k: q == k, name="load from "+str(index)) + else: + length = self.env.get_variable("length") + real_index = length + index + real_index.setname(length.name+str(index)) + sel = select(real_index, indices, lambda q, + k: q == k, name="load from "+str(index)) + agg = aggregate(sel, s, name=s.name+"["+str(index)+"]") + return agg.allow_suppressing_display() + else: + raise RASPValueError( + "index into sequence has to be integer, got:", strdesc(index)) + + def _evaluateIndexing(self, ast): + indexable = self.evaluateExpr(ast.indexable, get_name=True) + index = self.evaluateExpr(ast.index, get_name=True) + + if isinstance(indexable.val, list) or isinstance(indexable.val, str): + return self._index_into_list_or_str(indexable, index) + elif isinstance(indexable.val, dict): + return self._index_into_dict(indexable, index) + elif isinstance(indexable.val, UnfinishedSequence): + return self._index_into_sequence(indexable, index) + else: + raise RASPTypeError("can only index into a list, dict, string," + + "or sequence, but instead got:", + strdesc(indexable.val)) + + def _evaluateSelectExpr(self, ast): + key = self.evaluateExpr(ast.key) + query = 
self.evaluateExpr(ast.query) + sop = ast.selop.text + key = toseq(key) # in case got an atom in one of these, + query = toseq(query) # e.g. selecting 0th index: indices @= 0 + if sop == "<": + return select(query, key, lambda q, k: q > k) + if sop == ">": + return select(query, key, lambda q, k: q < k) + if sop == "==": + return select(query, key, lambda q, k: q == k) + if sop == "!=": + return select(query, key, lambda q, k: not (q == k)) + if sop == "<=": + return select(query, key, lambda q, k: q >= k) + if sop == ">=": + return select(query, key, lambda q, k: q <= k) + + def _evaluateBinaryExpr(self, ast): + def has_sequence(left, right): + return isinstance(left, UnfinishedSequence) \ + or isinstance(right, UnfinishedSequence) + + def has_selector(left, right): + return isinstance(left, UnfinishedSelect) \ + or isinstance(right, UnfinishedSelect) + + def both_selectors(left, right): + return isinstance(left, UnfinishedSelect) \ + and isinstance(right, UnfinishedSelect) + left = self.evaluateExpr(ast.left) + right = self.evaluateExpr(ast.right) + bop = ast.bop.text + bad_pair = RASPTypeError( + "Cannot apply and/or between selector and non-selector") + if bop == "and": + if has_sequence(left, right): + if has_selector(left, right): + raise bad_pair + return tpland(left, right) + elif has_selector(left, right): + if not both_selectors(left, right): + raise bad_pair + return and_selects(left, right) + else: + return (left and right) + elif bop == "or": + if has_sequence(left, right): + if has_selector(left, right): + raise bad_pair + return tplor(left, right) + elif has_selector(left, right): + if not both_selectors(left, right): + raise bad_pair + return or_selects(left, right) + else: + return (left or right) + if has_selector(left, right): + raise RASPTypeError("Cannot apply", bop, "to selector(s)") + elif bop == "+": + return left + right + elif bop == "-": + return left - right + elif bop == "*": + return left * right + elif bop == "/": + return left/right + elif bop == "^": + return pow(left, right) + elif bop == '%': + return left % right + elif bop == "==": + return left == right + elif bop == "<=": + return left <= right + elif bop == ">=": + return left >= right + elif bop == "<": + return left < right + elif bop == ">": + return left > right + # more, like modulo and power and all the other operators, to come + raise NotImplementedError + + def _evaluateStandalone(self, ast): + if ast.anint: + return int(ast.anint.text) + if ast.afloat: + return float(ast.afloat.text) + if ast.astring: + return ast.astring.text[1: -1] + raise NotImplementedError + + def _evaluateTernaryExpr(self, ast): + cond = self.evaluateExpr(ast.cond) + if isinstance(cond, Unfinished): + res1 = self.evaluateExpr(ast.res1) + res2 = self.evaluateExpr(ast.res2) + cond, res1, res2 = tuple(map(toseq, (cond, res1, res2))) + return zipmap((cond, res1, res2), lambda c, r1, r2: r1 + if c else r2, name=res1.name + " if " + cond.name + + " else " + res2.name).allow_suppressing_display() + else: + return self.evaluateExpr(ast.res1) if cond \ + else self.evaluateExpr(ast.res2) + # lazy eval when cond is non-unfinished allows legal loops over + # actual atoms + + def _evaluateAggregateExpr(self, ast): + sel = self.evaluateExpr(ast.sel) + seq = self.evaluateExpr(ast.seq) + seq = toseq(seq) # just in case its an atom + default = self.evaluateExpr(ast.default) if ast.default else None + + if not isinstance(sel, UnfinishedSelect): + raise RASPTypeError("Expected selector, got:", strdesc(sel)) + if not isinstance(seq, 
UnfinishedSequence): + raise RASPTypeError("Expected sequence, got:", strdesc(seq)) + if isinstance(default, Unfinished): + raise RASPTypeError("Expected atom, got:", strdesc(default)) + return aggregate(sel, seq, default=default) + + def _evaluateZip(self, ast): + list_exps = self._get_first_cont_list(ast.lists) + lists = [self.evaluateExpr(e) for e in list_exps] + if not lists: + raise RASPTypeError("zip needs at least one list") + for i, l in enumerate(lists): + if not isinstance(l, list): + raise RASPTypeError( + "attempting to zip lists, but", i+1, + "-th element is not list:", strdesc(l)) + n = len(lists[0]) + for i, l in enumerate(lists): + if not len(l) == n: + raise RASPTypeError("attempting to zip lists of length", + n, ", but", i+1, "-th list has length", + len(l)) + # keep everything lists, no tuples/lists mixing here, all the same to + # rasp (no stuff like append etc) + return [list(v) for v in zip(*lists)] + + def make_dict_key(self, ast): + res = self.evaluateExpr(ast) + if not isatom(res): + raise RASPTypeError( + "dictionary keys can only be atoms, but instead got:", + strdesc(res)) + return res + + def _evaluateDict(self, ast): + named_exprs_list = self._get_first_cont_list(ast.dictContents) + return {self.make_dict_key(e.key): self.evaluateExpr(e.val) + for e in named_exprs_list} + + def _evaluateList(self, ast): + exprs_list = self._get_first_cont_list(ast.listContents) + return [self.evaluateExpr(e) for e in exprs_list] + + def _evaluateApplication(self, ast, unf): + input_vals = self._get_first_cont_list(ast.inputexprs) + if not len(input_vals) == 1: + raise ArgsError("evaluate unfinished", 1, len(input_vals)) + input_val = self.evaluateExpr(input_vals[0]) + if not isinstance(unf, Unfinished): + raise RASPTypeError("Applying unfinished expects to apply", + ENCODER_NAME, "or selector, got:", + strdesc(unf)) + if not isinstance(input_val, Iterable): + raise RASPTypeError( + "Applying unfinished expects iterable input, got:", + strdesc(input_val)) + res = unf.call(input_val) + res.created_from_input = input_val + return res + + def _evaluateRASPFunction(self, ast, raspfun): + args_trees = self._get_first_cont_list(ast.inputexprs) + args = tuple(self.evaluateExpr(t) for t in args_trees) + (self,) + real_args = args[:-1] + res = raspfun.call(*args) + if isinstance(res, Unfinished): + res.setname(raspfun.name + "(" + + " , ".join(strdesc(a, desc_cap=20) for + a in real_args) + ")") + return res + + def _evaluateContains(self, ast): + contained = self.evaluateExpr(ast.contained) + container = self.evaluateExpr(ast.container) + container_name = ast.container.var.text if ast.container.var \ + else str(container) + if isinstance(contained, UnfinishedSequence): + if not isinstance(container, list): + raise RASPTypeError(f"\"[{ENCODER_NAME}] in X\" expects X " + + "to be list of atoms, but got non-list:", + strdesc(container)) + for v in container: + if not isatom(v): + raise RASPTypeError("\"[" + ENCODER_NAME + "] in X\" " + + "expects X to be list of atoms, but " + + "got list with values:", + strdesc(container)) + return zipmap(contained, lambda c: c in container, + name=contained.name + " in " + + container_name).allow_suppressing_display() + elif isatom(contained): # contained is now an atom + if isinstance(container, list): + return contained in container + elif isinstance(container, UnfinishedSequence): + indicator = zipmap(container, lambda v: int(v == contained)) + return aggregate(full_s, indicator) > 0 + else: + raise RASPTypeError( + f"\"[atom] in X\" expects X to be 
list or {ENCODER_NAME}" + + ", but got:", strdesc(container)) + if isinstance(contained, UnfinishedSelect) or isinstance(contained, + RASPFunction): + obj_name = "select" if isinstance( + contained, UnfinishedSelect) else "function" + raise RASPTypeError(f"don't check if {obj_name} is contained" + + " in list/dict: unless exact same instance, " + + f"unable to check equivalence of {obj_name}s") + else: + raise RASPTypeError("\"A in X\" expects A to be", + ENCODER_NAME, "or atom, but got A:", + strdesc(contained)) + + def _evaluateLen(self, ast): + singleList = self.evaluateExpr(ast.singleList) + if not isinstance(singleList, list) or isinstance(singleList, dict): + raise RASPTypeError( + "attempting to compute length of non-list:", + strdesc(singleList)) + return len(singleList) + + def evaluateExprsList(self, ast): + exprsList = self._get_first_cont_list(ast) + return [self.evaluateExpr(v) for v in exprsList] + + def _test_res(self, res): + if isinstance(res, Unfinished): + def succeeds_with(exampe): + try: + res.call(example, just_pass_exception_up=True) + except Exception: + return False + else: + return True + succeeds_with_backup = (self.backup_example is not None) and \ + succeeds_with(self.backup_example) + if succeeds_with_backup: + return + succeeds_with_main = succeeds_with(self.sequence_running_example) + if succeeds_with_main: + return + example = self.sequence_running_example if self.backup_example \ + is None else self.backup_example + res.call(example, just_pass_exception_up=True) + + def evaluateExpr(self, ast, from_top=False, get_name=False): + def format_return(res, resname="out", + is_application_of_unfinished=False): + ast.evaled_value = res + # run a quick test of the result (by attempting to evaluate it on + # an example) to make sure there hasn't been some weird type + # problem, so it shouts even before someone actively tries to + # evaluate it + self._test_res(res) + + if is_application_of_unfinished: + return JustVal(res) + else: + self.env.set_out(res) + if from_top or get_name: + # this is when an expression has been evaled + return NamedVal(resname, res) + else: + return res + if ast.bracketed: # in parentheses - get out of them + return self.evaluateExpr(ast.bracketed, from_top=from_top) + if ast.var: # calling single variable + varname = ast.var.text + return format_return(self.env.get_variable(varname), + resname=varname) + if ast.standalone: + return format_return(self._evaluateStandalone(ast.standalone)) + if ast.bop: + return format_return(self._evaluateBinaryExpr(ast)) + if ast.uop: + return format_return(self._evaluateUnaryExpr(ast)) + if ast.cond: + return format_return(self._evaluateTernaryExpr(ast)) + if ast.aggregate: + return format_return(self._evaluateAggregateExpr(ast.aggregate)) + if ast.unfORfun: + + # before evaluating the unfORfun expression, + # consider that it may be an unf that would not work + # with the current running example, and allow that it may have + # been sent in with an example for which it will work + prev_backup = self.backup_example + input_vals = self._get_first_cont_list(ast.inputexprs) + if len(input_vals) == 1: + self.backup_example = self.evaluateExpr(input_vals[0]) + + unfORfun = self.evaluateExpr(ast.unfORfun) + + self.backup_example = prev_backup + + if isinstance(unfORfun, Unfinished): + return format_return(self._evaluateApplication(ast, unfORfun), + is_application_of_unfinished=True) + elif isinstance(unfORfun, RASPFunction): + return format_return(self._evaluateRASPFunction(ast, unfORfun)) + if ast.selop: + 
return format_return(self._evaluateSelectExpr(ast)) + if ast.aList(): + return format_return(self._evaluateList(ast.aList())) + if ast.aDict(): + return format_return(self._evaluateDict(ast.aDict())) + if ast.indexable: # indexing into a list, dict, or s-op + return format_return(self._evaluateIndexing(ast)) + if ast.rangevals: + return format_return(self._evaluateRange(ast)) + if ast.listcomp: + return format_return(self._evaluateListComp(ast)) + if ast.dictcomp: + return format_return(self._evaluateDictComp(ast)) + if ast.container: + return format_return(self._evaluateContains(ast)) + if ast.lists: + return format_return(self._evaluateZip(ast)) + if ast.singleList: + return format_return(self._evaluateLen(ast)) + raise NotImplementedError # new ast getText function for expressions def new_getText(self): # original getText function stored as self._getText - if hasattr(self, "evaled_value") and isatom(self.evaled_value): - return str(self.evaled_value) - else: - return self._getText() + if hasattr(self, "evaled_value") and isatom(self.evaled_value): + return str(self.evaled_value) + else: + return self._getText() RASPParser.ExprContext._getText = RASPParser.ExprContext.getText diff --git a/RASP_support/FunctionalSupport.py b/RASP_support/FunctionalSupport.py index 0bfef3b..98247fa 100644 --- a/RASP_support/FunctionalSupport.py +++ b/RASP_support/FunctionalSupport.py @@ -24,27 +24,27 @@ class NextId: - def __init__(self): - self.i = 0 + def __init__(self): + self.i = 0 - def get_next(self): - self.i += 1 - return self.i + def get_next(self): + self.i += 1 + return self.i unique_id_maker = NextId() def creation_order_id(): - return unique_id_maker.get_next() + return unique_id_maker.get_next() class AlreadyPrintedTheException: - def __init__(self): - self.b = False + def __init__(self): + self.b = False - def __bool__(self): - return self.b + def __bool__(self): + return self.b global_printed = AlreadyPrintedTheException() @@ -53,518 +53,525 @@ def __bool__(self): class Unfinished: - def __init__(self, parents_tuple, parents2self, name=plain_unfinished_name, - is_toplevel_input=False, min_poss_depth=-1): - self.parents_tuple = parents_tuple - self.parents2self = parents2self - self.last_w = None - self.last_res = None - self.is_toplevel_input = is_toplevel_input - self.setname(name if not self.is_toplevel_input else "input") - self.creation_order_id = creation_order_id() - self.min_poss_depth = min_poss_depth - self._real_parents = None - self._full_parents = None - self._sorted_full_parents = None - - def setname(self, name, always_display_when_named=True): - if name is not None: - if len(name) > name_maxlen: - if isinstance(self, UnfinishedSequence): - name = plain_unfinished_sequence_name - elif isinstance(self, UnfinishedSelect): - name = plain_unfinished_select_name - else: - name = plain_unfinished_name - self.name = name - # if you set something's name, you probably want to see it - self.always_display = always_display_when_named - # return self to allow chaining with other calls and throwing straight - # into a return statement etc - return self - - def get_parents(self): - if None is self._real_parents: - real_parents_part1 = [ - p for p in self.parents_tuple if is_real_unfinished(p)] - other_parents = [ - p for p in self.parents_tuple if not is_real_unfinished(p)] - res = real_parents_part1 - for p in other_parents: - # recursion: branch back through all the parents of the unf, - # always stopping wherever hit something 'real' ie a select or - # a sequence - res += p.get_parents() 
- # nothing is made from more than one select... - assert len( - [p for p in res if isinstance(p, UnfinishedSelect)]) <= 1 - self._real_parents = set(res) - # in case someone messes with the list eg popping through it - return copy(self._real_parents) - - def _flat_compute_full_parents(self): - # TODO: take advantage of anywhere full_parents have already been - # computed, tho, otherwise no point in doing the recursion ever - explored = set() - not_explored = set([self]) - while not_explored: - p = not_explored.pop() - if p in explored: - # this may happen due to also adding things directly to - # explored sometimes - continue - if None is not p._full_parents: - # note that _full_parents include self - explored.update(p._full_parents) - else: - new_parents = p.get_parents() - explored.add(p) - not_explored.update(new_parents) - return explored - - def _recursive_compute_full_parents(self): - res = self.get_parents() # get_parents returns a copy - res.update([self]) # full parents include self - for p in self.get_parents(): - res.update(p.get_full_parents(recurse=True, trusted=True)) - return res - - def _sort_full_parents(self): - if None is self._sorted_full_parents: - self._sorted_full_parents = sorted( - self._full_parents, key=lambda unf: unf.creation_order_id) - - def get_full_parents(self, recurse=False, just_compute=False, - trusted=False): - # Note: full_parents include self - if None is self._full_parents: - if recurse: - self._full_parents = self._recursive_compute_full_parents() - else: - self._full_parents = self._flat_compute_full_parents() - # avoids recursion, and so avoids passing the max recursion - # depth - - # but now having done that we would like to store the result - # for all parents so we can take advantage of it in the future - for p in self.get_sorted_full_parents(): - p.get_full_parents(recurse=True, just_compute=True) - # have them all compute their full parents so they are - # ready for the future, but only do this in sorted order, - # so recursion is always shallow. (always gets shorted with - # self._full_parents, which is being computed here for each - # unfinished starting from the top of the computation - # graph) - if not just_compute: - if trusted: - # functions where you have checked they don't modify the - # returned result can be marked as trusted and get the true - # _full_parents - return self._full_parents - else: - # otherwise they get a copy - return copy(self._full_parents) - - def get_sorted_full_parents(self): - # could have just made get_full_parents give a sorted result, but - # wanted a function where name is already clear that result will be - # sorted, to avoid weird bugs in future. 
(especially that being not - # sorted will only affect performance, and possibly break recursion - # depth) - - if None is self._sorted_full_parents: - if None is self._full_parents: - self.get_full_parents(just_compute=True) - self._sort_full_parents() - return copy(self._sorted_full_parents) - - def call(self, w, topcall=True, just_pass_exception_up=False): - if (not isinstance(w, Iterable)) or (not w): - raise RASPTypeError( - "RASP sequences/selectors expect non-empty iterables, got: " - + str(w)) - global_printed.b = False - if w == self.last_w: - return self.last_res # don't print same calculation multiple times - - else: - if self.is_toplevel_input: - res = w - self.last_w, self.last_res = w, w - else: - try: - if topcall: - # before doing the main call, evaluate all parents - # (in order of dependencies, attainable by using - # creation_order_id attribute), this avoids a deep - # recursion: every element that is evaluated only has - # to go back as far as its own 'real' (i.e., s-op or - # selector) parents to hit something that has already - # been evaluated, and then those will not recurse - # further back as they use memoization - for unf in self.get_sorted_full_parents(): - # evaluate - unf.call(w, topcall=False, - just_pass_exception_up=just_pass_exception_up) - - j_p_e_u = just_pass_exception_up - args = tuple(p.call(w, - topcall=False, - just_pass_exception_up=j_p_e_u) - for p in self.parents_tuple) - res = self.parents2self(*args) - except Exception as e: - if just_pass_exception_up: - raise e - if isinstance(e, RASPTypeError): - raise e - if not global_printed.b: - seperator = "=" * 63 - print(colored(f"{seperator}\n{seperator}", error_color)) - error_msg = f"evaluation failed in: [ {self.name} ]" +\ - f" with exception:\n {e}" - print(colored(error_msg, error_color)) - print(colored(seperator, error_color)) - print(colored("parent values are:", error_color)) - for p in self.parents_tuple: - print(colored( - f"=============\n{p.name}\n{p.last_res}", - error_color)) - print(colored(f"{seperator}\n{seperator}", error_color)) - a, b, tb = sys.exc_info() - tt = traceback.extract_tb(tb) - last_call = max([i for i, t in enumerate(tt) - if "in call" in str(t)]) - traceback_msg = \ - ''.join(traceback.format_list(tt[last_call+1:])) - print(colored(traceback_msg, error_color)) - - # traceback.print_exception(a,b,tb) - - global_printed.b = True - - if debug or not topcall: - raise - else: - return "EVALUATION FAILURE" - - self.last_w, self.last_res = w, res - return res + def __init__(self, parents_tuple, parents2self, name=plain_unfinished_name, + is_toplevel_input=False, min_poss_depth=-1): + self.parents_tuple = parents_tuple + self.parents2self = parents2self + self.last_w = None + self.last_res = None + self.is_toplevel_input = is_toplevel_input + self.setname(name if not self.is_toplevel_input else "input") + self.creation_order_id = creation_order_id() + self.min_poss_depth = min_poss_depth + self._real_parents = None + self._full_parents = None + self._sorted_full_parents = None + + def setname(self, name, always_display_when_named=True): + if name is not None: + if len(name) > name_maxlen: + if isinstance(self, UnfinishedSequence): + name = plain_unfinished_sequence_name + elif isinstance(self, UnfinishedSelect): + name = plain_unfinished_select_name + else: + name = plain_unfinished_name + self.name = name + # if you set something's name, you probably want to see it + self.always_display = always_display_when_named + # return self to allow chaining with other calls and 
throwing straight + # into a return statement etc + return self + + def get_parents(self): + if None is self._real_parents: + real_parents_part1 = [ + p for p in self.parents_tuple if is_real_unfinished(p)] + other_parents = [ + p for p in self.parents_tuple if not is_real_unfinished(p)] + res = real_parents_part1 + for p in other_parents: + # recursion: branch back through all the parents of the unf, + # always stopping wherever hit something 'real' ie a select or + # a sequence + res += p.get_parents() + # nothing is made from more than one select... + assert len( + [p for p in res if isinstance(p, UnfinishedSelect)]) <= 1 + self._real_parents = set(res) + # in case someone messes with the list eg popping through it + return copy(self._real_parents) + + def _flat_compute_full_parents(self): + # TODO: take advantage of anywhere full_parents have already been + # computed, tho, otherwise no point in doing the recursion ever + explored = set() + not_explored = set([self]) + while not_explored: + p = not_explored.pop() + if p in explored: + # this may happen due to also adding things directly to + # explored sometimes + continue + if None is not p._full_parents: + # note that _full_parents include self + explored.update(p._full_parents) + else: + new_parents = p.get_parents() + explored.add(p) + not_explored.update(new_parents) + return explored + + def _recursive_compute_full_parents(self): + res = self.get_parents() # get_parents returns a copy + res.update([self]) # full parents include self + for p in self.get_parents(): + res.update(p.get_full_parents(recurse=True, trusted=True)) + return res + + def _sort_full_parents(self): + if None is self._sorted_full_parents: + self._sorted_full_parents = sorted( + self._full_parents, key=lambda unf: unf.creation_order_id) + + def get_full_parents(self, recurse=False, just_compute=False, + trusted=False): + # Note: full_parents include self + if None is self._full_parents: + if recurse: + self._full_parents = self._recursive_compute_full_parents() + else: + self._full_parents = self._flat_compute_full_parents() + # avoids recursion, and so avoids passing the max recursion + # depth + + # but now having done that we would like to store the result + # for all parents so we can take advantage of it in the future + for p in self.get_sorted_full_parents(): + p.get_full_parents(recurse=True, just_compute=True) + # have them all compute their full parents so they are + # ready for the future, but only do this in sorted order, + # so recursion is always shallow. (always gets shorted with + # self._full_parents, which is being computed here for each + # unfinished starting from the top of the computation + # graph) + if not just_compute: + if trusted: + # functions where you have checked they don't modify the + # returned result can be marked as trusted and get the true + # _full_parents + return self._full_parents + else: + # otherwise they get a copy + return copy(self._full_parents) + + def get_sorted_full_parents(self): + # could have just made get_full_parents give a sorted result, but + # wanted a function where name is already clear that result will be + # sorted, to avoid weird bugs in future. 
(especially that being not + # sorted will only affect performance, and possibly break recursion + # depth) + + if None is self._sorted_full_parents: + if None is self._full_parents: + self.get_full_parents(just_compute=True) + self._sort_full_parents() + return copy(self._sorted_full_parents) + + def call(self, w, topcall=True, just_pass_exception_up=False): + if (not isinstance(w, Iterable)) or (not w): + raise RASPTypeError( + "RASP sequences/selectors expect non-empty iterables, got: " + + str(w)) + global_printed.b = False + if w == self.last_w: + return self.last_res # don't print same calculation multiple times + + else: + if self.is_toplevel_input: + res = w + self.last_w, self.last_res = w, w + else: + try: + j_p_e_u = just_pass_exception_up + if topcall: + # before doing the main call, evaluate all parents + # (in order of dependencies, attainable by using + # creation_order_id attribute), this avoids a deep + # recursion: every element that is evaluated only has + # to go back as far as its own 'real' (i.e., s-op or + # selector) parents to hit something that has already + # been evaluated, and then those will not recurse + # further back as they use memoization + for unf in self.get_sorted_full_parents(): + # evaluate + unf.call(w, topcall=False, + just_pass_exception_up=j_p_e_u) + args = tuple(p.call(w, topcall=False, + just_pass_exception_up=j_p_e_u) + for p in self.parents_tuple) + res = self.parents2self(*args) + except Exception as e: + if just_pass_exception_up: + raise e + if isinstance(e, RASPTypeError): + raise e + if not global_printed.b: + seperator = "=" * 63 + print(colored(f"{seperator}\n{seperator}", + error_color)) + error_msg = f"evaluation failed in: [ {self.name} ]" +\ + f" with exception:\n {e}" + print(colored(error_msg, error_color)) + print(colored(seperator, error_color)) + print(colored("parent values are:", error_color)) + for p in self.parents_tuple: + print(colored( + f"=============\n{p.name}\n{p.last_res}", + error_color)) + print(colored(f"{seperator}\n{seperator}", + error_color)) + a, b, tb = sys.exc_info() + tt = traceback.extract_tb(tb) + last_call = max([i for i, t in enumerate(tt) + if "in call" in str(t)]) + traceback_msg = \ + ''.join(traceback.format_list(tt[last_call + 1:])) + print(colored(traceback_msg, error_color)) + + # traceback.print_exception(a,b,tb) + + global_printed.b = True + + if debug or not topcall: + raise + else: + return "EVALUATION FAILURE" + + self.last_w, self.last_res = w, res + return res class UnfinishedSequence(Unfinished): - def __init__(self, parents_tuple, parents2self, - name=plain_unfinished_sequence_name, - elementwise_function=None, default=None, min_poss_depth=0, - from_zipmap=False, output_index=-1, - definitely_uses_identity_function=False): - # min_poss_depth=0 starts all of the base sequences (eg indices) off - # right. - - # might have got none from some default value, fix it before continuing - # because later things eg DrawCompFlow will expect name to be str - if name is None: - name = plain_unfinished_sequence_name - super(UnfinishedSequence, self).__init__(parents_tuple, - parents2self, name=name, - min_poss_depth=min_poss_depth) - # can be inferred (by seeing if there are parent selects), but this is - # simple enough. 
helpful for rendering comp flow visualisations - self.from_zipmap = from_zipmap - # useful for analysis later - self.elementwise_function = elementwise_function - self.output_index = output_index - # useful for analysis later - self.default = default - self.definitely_uses_identity_function = \ - definitely_uses_identity_function - self.never_display = False - self._constant = False - - def __str__(self): - id = str(self.creation_order_id) - return "UnfinishedSequence object, name: " + self.name + " id: " + id - - def mark_as_constant(self): - self._constant = True - return self - - def is_constant(self): - return self._constant + def __init__(self, parents_tuple, parents2self, + name=plain_unfinished_sequence_name, + elementwise_function=None, default=None, min_poss_depth=0, + from_zipmap=False, output_index=-1, + definitely_uses_identity_function=False): + # min_poss_depth=0 starts all of the base sequences (eg indices) off + # right. + + # might have got none from some default value, fix it before continuing + # because later things eg DrawCompFlow will expect name to be str + if name is None: + name = plain_unfinished_sequence_name + super(UnfinishedSequence, self).__init__(parents_tuple, + parents2self, name=name, + min_poss_depth=min_poss_depth) + # can be inferred (by seeing if there are parent selects), but this is + # simple enough. helpful for rendering comp flow visualisations + self.from_zipmap = from_zipmap + # useful for analysis later + self.elementwise_function = elementwise_function + self.output_index = output_index + # useful for analysis later + self.default = default + self.definitely_uses_identity_function = \ + definitely_uses_identity_function + self.never_display = False + self._constant = False + + def __str__(self): + id = str(self.creation_order_id) + return "UnfinishedSequence object, name: " + self.name + " id: " + id + + def mark_as_constant(self): + self._constant = True + return self + + def is_constant(self): + return self._constant class UnfinishedSelect(Unfinished): - def __init__(self, parents_tuple, parents2self, - name=plain_unfinished_select_name, compare_string=None, - min_poss_depth=-1, q_vars=None, k_vars=None, - orig_selector=None): # selects should be told their depth, - # -1 will warn of problems properly - if name is None: # as in unfinishedsequence, some other function might - # have passed in a None somewhere - name = plain_unfinished_select_name # so fix before a print goes - # wrong - super(UnfinishedSelect, self).__init__(parents_tuple, - parents2self, name=name, - min_poss_depth=min_poss_depth) - self.compare_string = str( - self.creation_order_id) if compare_string is None \ - else compare_string - # they're not really optional i just dont want to add more mess to the - # func - assert None not in [q_vars, k_vars] - self.q_vars = q_vars # don't actually need them, but useful for - self.k_vars = k_vars # drawing comp flow - # use compare string for comparison/uniqueness rather than overloading - # __eq__ of unfinishedselect, to avoid breaking things in unknown - # locations, and to be able to put selects in dictionaries and stuff - # (overloading __eq__ makes an object unhasheable unless i guess you - # overload the hash too?). 
need these comparisons for optimisations in - # analysis eg if two selects are identical they can be same head - self.orig_selector = orig_selector # for comfortable compositions of - # selectors - - def __str__(self): - id = str(self.creation_order_id) - return "UnfinishedSelect object, name: " + self.name + " id: " + id + def __init__(self, parents_tuple, parents2self, + name=plain_unfinished_select_name, compare_string=None, + min_poss_depth=-1, q_vars=None, k_vars=None, + orig_selector=None): # selects should be told their depth, + # -1 will warn of problems properly + if name is None: # as in unfinishedsequence, some other function might + # have passed in a None somewhere + name = plain_unfinished_select_name # so fix before a print goes + # wrong + super(UnfinishedSelect, self).__init__(parents_tuple, + parents2self, name=name, + min_poss_depth=min_poss_depth) + self.compare_string = str( + self.creation_order_id) if compare_string is None \ + else compare_string + # they're not really optional i just dont want to add more mess to the + # func + assert None not in [q_vars, k_vars] + self.q_vars = q_vars # don't actually need them, but useful for + self.k_vars = k_vars # drawing comp flow + # use compare string for comparison/uniqueness rather than overloading + # __eq__ of unfinishedselect, to avoid breaking things in unknown + # locations, and to be able to put selects in dictionaries and stuff + # (overloading __eq__ makes an object unhasheable unless i guess you + # overload the hash too?). need these comparisons for optimisations in + # analysis eg if two selects are identical they can be same head + self.orig_selector = orig_selector # for comfortable compositions of + # selectors + + def __str__(self): + id = str(self.creation_order_id) + return "UnfinishedSelect object, name: " + self.name + " id: " + id # as opposed to intermediate unfinisheds like tuples of sequences def is_real_unfinished(unf): - return isinstance(unf, UnfinishedSequence) \ - or isinstance(unf, UnfinishedSelect) + return isinstance(unf, UnfinishedSequence) \ + or isinstance(unf, UnfinishedSelect) # some tiny bit of sugar that fits here: def is_sequence_of_unfinishedseqs(seqs): - if not isinstance(seqs, Iterable): - return False - return False not in [isinstance(seq, UnfinishedSequence) for seq in seqs] + if not isinstance(seqs, Iterable): + return False + return False not in [isinstance(seq, UnfinishedSequence) for seq in seqs] class BareBonesFunctionalSupportException(Exception): - def __init__(self, m): - Exception.__init__(self, m) + def __init__(self, m): + Exception.__init__(self, m) def to_tuple_of_unfinishedseqs(seqs): - if is_sequence_of_unfinishedseqs(seqs): - return tuple(seqs) - if isinstance(seqs, UnfinishedSequence): - return (seqs,) - print(colored(f"seqs: {seqs}", general_color)) - raise BareBonesFunctionalSupportException( - "input to select/aggregate not an unfinished sequence or sequence of" - + " unfinished sequences") + if is_sequence_of_unfinishedseqs(seqs): + return tuple(seqs) + if isinstance(seqs, UnfinishedSequence): + return (seqs,) + print(colored(f"seqs: {seqs}", general_color)) + raise BareBonesFunctionalSupportException( + "input to select/aggregate not an unfinished sequence or sequence" + + "of unfinished sequences") def tup2tup(*x): - return tuple([*x]) + return tuple([*x]) class UnfinishedSequencesTuple(Unfinished): - def __init__(self, parents_tuple, parents2self=None): - # sequence tuples only exist in here, user doesn't 'see' them. 
can have - # lots of default values they're just a convenience for me - if parents2self is None: # just sticking a bunch of unfinished - # sequences together into one thing for reasons - parents2self = tup2tup - parents_tuple = to_tuple_of_unfinishedseqs(parents_tuple) - assert is_sequence_of_unfinishedseqs( - parents_tuple) and isinstance(parents_tuple, tuple) - # else - probably creating several sequences at once from one aggregate - super(UnfinishedSequencesTuple, self).__init__( - parents_tuple, parents2self, name="plain unfinished tuple") - - def __add__(self, other): - assert isinstance(other, UnfinishedSequencesTuple) - assert self.parents2self is tup2tup - assert other.parents2self is tup2tup - return UnfinishedSequencesTuple(self.parents_tuple+other.parents_tuple) + def __init__(self, parents_tuple, parents2self=None): + # sequence tuples only exist in here, user doesn't 'see' them. can have + # lots of default values they're just a convenience for me + if parents2self is None: # just sticking a bunch of unfinished + # sequences together into one thing for reasons + parents2self = tup2tup + parents_tuple = to_tuple_of_unfinishedseqs(parents_tuple) + assert is_sequence_of_unfinishedseqs( + parents_tuple) and isinstance(parents_tuple, tuple) + # else - probably creating several sequences at once from one aggregate + super(UnfinishedSequencesTuple, self).__init__( + parents_tuple, parents2self, name="plain unfinished tuple") + + def __add__(self, other): + assert isinstance(other, UnfinishedSequencesTuple) + assert self.parents2self is tup2tup + assert other.parents2self is tup2tup + return UnfinishedSequencesTuple(self.parents_tuple + + other.parents_tuple) _input = Unfinished((), None, is_toplevel_input=True) # and now, the actual exposed functions indices = UnfinishedSequence((_input,), lambda w: Sequence( - list(range(len(w)))), name=plain_indices) + list(range(len(w)))), name=plain_indices) tokens_str = UnfinishedSequence((_input,), lambda w: Sequence( - list(map(str, w))), name=plain_tokens+"_str") + list(map(str, w))), name=plain_tokens + "_str") tokens_int = UnfinishedSequence((_input,), lambda w: Sequence( - list(map(int, w))), name=plain_tokens+"_int") + list(map(int, w))), name=plain_tokens + "_int") tokens_float = UnfinishedSequence((_input,), lambda w: Sequence( - list(map(float, w))), name=plain_tokens+"_float") + list(map(float, w))), name=plain_tokens + "_float") tokens_bool = UnfinishedSequence((_input,), lambda w: Sequence( - list(map(bool, w))), name=plain_tokens+"_bool") + list(map(bool, w))), name=plain_tokens + "_bool") tokens_asis = UnfinishedSequence( - (_input,), lambda w: Sequence(w), name=plain_tokens+"_asis") + (_input,), lambda w: Sequence(w), name=plain_tokens + "_asis") base_tokens = [tokens_str, tokens_int, tokens_float, tokens_bool, tokens_asis] def _min_poss_depth(unfs): - if isinstance(unfs, Unfinished): # got single unfinished and not iterable - # of them - unfs = [unfs] - # max b/c cant go less deep than deepest - return max([u.min_poss_depth for u in unfs]+[0]) - # add that 0 thing so list is never empty and max complains. + if isinstance(unfs, Unfinished): # got single unfinished and not iterable + # of them + unfs = [unfs] + # max b/c cant go less deep than deepest + return max([u.min_poss_depth for u in unfs] + [0]) + # add that 0 thing so list is never empty and max complains. 
def tupleise(v): - if isinstance(v, tuple) or isinstance(v, list): - return tuple(v) - return (v,) + if isinstance(v, tuple) or isinstance(v, list): + return tuple(v) + return (v,) def select(q_vars, k_vars, selector, name=None, compare_string=None): - if None is name: - name = "plain select" - # potentially here check the qvars all reference the same input sequence as - # each other and same for the kvars, technically dont *have* to but is - # helpful for the user so consider maybe adding a tiny bit of mess here - # (including markings inside sequences and selectors so they know which - # index they're gathering to and from) to allow it - - # we're ok with getting a single q or k var, not in a tuple, - # but important to fix it before '+' on two UnfinishedSequences - # (as opposed to two tuples) sends everything sideways - q_vars = tupleise(q_vars) - k_vars = tupleise(k_vars) - - # attn layer is one after values it needs to be calculated - new_depth = _min_poss_depth(q_vars+k_vars)+1 - res = UnfinishedSelect((_input, # need input seq length to create select - # of correct size - UnfinishedSequencesTuple(q_vars), - UnfinishedSequencesTuple(k_vars)), - lambda input_seq, qv, kv: _select( - len(input_seq), qv, kv, selector), - name=name, compare_string=compare_string, - min_poss_depth=new_depth, q_vars=q_vars, - k_vars=k_vars, orig_selector=selector) - return res + if None is name: + name = "plain select" + # potentially here check the qvars all reference the same input sequence as + # each other and same for the kvars, technically dont *have* to but is + # helpful for the user so consider maybe adding a tiny bit of mess here + # (including markings inside sequences and selectors so they know which + # index they're gathering to and from) to allow it + + # we're ok with getting a single q or k var, not in a tuple, + # but important to fix it before '+' on two UnfinishedSequences + # (as opposed to two tuples) sends everything sideways + q_vars = tupleise(q_vars) + k_vars = tupleise(k_vars) + + # attn layer is one after values it needs to be calculated + new_depth = _min_poss_depth(q_vars + k_vars) + 1 + res = UnfinishedSelect((_input, # need input seq length to create select + # of correct size + UnfinishedSequencesTuple(q_vars), + UnfinishedSequencesTuple(k_vars)), + lambda input_seq, qv, kv: _select( + len(input_seq), qv, kv, selector), + name=name, compare_string=compare_string, + min_poss_depth=new_depth, q_vars=q_vars, + k_vars=k_vars, orig_selector=selector) + return res def _compose_selects(select1, select2, compose_op=None, name=None, - compare_string=None): - nq1 = len(select1.q_vars) - nq2 = len(select2.q_vars)+nq1 - nk1 = len(select1.k_vars)+nq2 - - def new_selector(*qqkk): - q1 = qqkk[:nq1] - q2 = qqkk[nq1:nq2] - k1 = qqkk[nq2:nk1] - k2 = qqkk[nk1:] - return compose_op(select1.orig_selector(*q1, *k1), - select2.orig_selector(*q2, *k2)) - return select(select1.q_vars+select2.q_vars, - select1.k_vars+select2.k_vars, - new_selector, name=name, compare_string=compare_string) + compare_string=None): + nq1 = len(select1.q_vars) + nq2 = len(select2.q_vars) + nq1 + nk1 = len(select1.k_vars) + nq2 + + def new_selector(*qqkk): + q1 = qqkk[:nq1] + q2 = qqkk[nq1:nq2] + k1 = qqkk[nq2:nk1] + k2 = qqkk[nk1:] + return compose_op(select1.orig_selector(*q1, *k1), + select2.orig_selector(*q2, *k2)) + return select(select1.q_vars + select2.q_vars, + select1.k_vars + select2.k_vars, + new_selector, name=name, compare_string=compare_string) def _compose_select(select1, compose_op=None, name=None, 
compare_string=None): - def new_selector(*qk): - return compose_op(select1.orig_selector(*qk)) - return select(select1.q_vars, - select1.k_vars, - new_selector, name=name, compare_string=compare_string) + def new_selector(*qk): + return compose_op(select1.orig_selector(*qk)) + return select(select1.q_vars, + select1.k_vars, + new_selector, name=name, compare_string=compare_string) def not_select(select, name=None, compare_string=None): - return _compose_select(select, lambda a: not a, - name=name, compare_string=compare_string) + return _compose_select(select, lambda a: not a, + name=name, compare_string=compare_string) def and_selects(select1, select2, name=None, compare_string=None): - return _compose_selects(select1, select2, lambda a, b: a and b, - name=name, compare_string=compare_string) + return _compose_selects(select1, select2, lambda a, b: a and b, + name=name, compare_string=compare_string) def or_selects(select1, select2, name=None, compare_string=None): - return _compose_selects(select1, select2, lambda a, b: a or b, - name=name, compare_string=compare_string) + return _compose_selects(select1, select2, lambda a, b: a or b, + name=name, compare_string=compare_string) def format_output(parents_tuple, parents2res, name, elementwise_function=None, - default=None, min_poss_depth=0, from_zipmap=False, - definitely_uses_identity_function=False): - def_uses = definitely_uses_identity_function - return UnfinishedSequence(parents_tuple, parents2res, - elementwise_function=elementwise_function, - default=default, name=name, - min_poss_depth=min_poss_depth, - from_zipmap=from_zipmap, - definitely_uses_identity_function=def_uses) + default=None, min_poss_depth=0, from_zipmap=False, + definitely_uses_identity_function=False): + def_uses = definitely_uses_identity_function + return UnfinishedSequence(parents_tuple, parents2res, + elementwise_function=elementwise_function, + default=default, name=name, + min_poss_depth=min_poss_depth, + from_zipmap=from_zipmap, + definitely_uses_identity_function=def_uses) def get_identity_function(num_params): - def identity1(a): - return a + def identity1(a): + return a - def identityx(*a): - return a - return identity1 if num_params == 1 else identityx + def identityx(*a): + return a + return identity1 if num_params == 1 else identityx def zipmap(sequences_tuple, elementwise_function, - name=plain_unfinished_sequence_name): - sequences_tuple = tupleise(sequences_tuple) - unfinished_parents_tuple = UnfinishedSequencesTuple( - sequences_tuple) # this also takes care of turning the - # value in sequences_tuple to indeed a tuple of sequences and not eg a - # single sequence which will cause weird behaviour later - - parents_tuple = (_input, unfinished_parents_tuple) - def parents2res(w, vt): return _zipmap(len(w), vt, elementwise_function) - # feedforward doesn't increase layer - min_poss_depth = _min_poss_depth(sequences_tuple) - # new assumption, to be revised later: can do arbitrary zipmap even before - # first feed-forward, i.e. in build up to first attention. truth is can do - # 'simple' zipmap towards first attention (no xor, but yes things like - # 'and' or 'indicator for ==' or whatever) based on initial linear - # translation done for Q,K in attention (not deep enough for xor, but deep - # enough for simple stuff) alongside use of initial embedding. 
honestly - # literally can just put everything in initial embedding if need it so bad - # its the first layer and its zipmap its only a function of the token and - # indices, so long as its not computing any weird combination between them - # you can do it in the embedding - # if len(sequences_tuple)>0: - # min_poss_depth = max(min_poss_depth,1) # except for the very specific - # # case where it is the very first thing to be done, in which case we do - # # have to go through one layer to get to the first feedforward. - # # the 'if' is there to rule out increasing when doing a feedforward on - # # nothing, ie, when making a constant. constants are allowed to be - # # created on layer 0, they're part of the embedding or the weights that - # # will use them later or whatever, it's fine - - # at least as deep as needed MVs, but no deeper cause FF - # (which happens at end of layer) - return format_output(parents_tuple, parents2res, name, - min_poss_depth=min_poss_depth, - elementwise_function=elementwise_function, - from_zipmap=True) + name=plain_unfinished_sequence_name): + sequences_tuple = tupleise(sequences_tuple) + unfinished_parents_tuple = UnfinishedSequencesTuple( + sequences_tuple) # this also takes care of turning the + # value in sequences_tuple to indeed a tuple of sequences and not eg a + # single sequence which will cause weird behaviour later + + parents_tuple = (_input, unfinished_parents_tuple) + + def parents2res(w, vt): + return _zipmap(len(w), vt, elementwise_function) + + # feedforward doesn't increase layer + min_poss_depth = _min_poss_depth(sequences_tuple) + # new assumption, to be revised later: can do arbitrary zipmap even before + # first feed-forward, i.e. in build up to first attention. truth is can do + # 'simple' zipmap towards first attention (no xor, but yes things like + # 'and' or 'indicator for ==' or whatever) based on initial linear + # translation done for Q,K in attention (not deep enough for xor, but deep + # enough for simple stuff) alongside use of initial embedding. honestly + # literally can just put everything in initial embedding if need it so bad + # its the first layer and its zipmap its only a function of the token and + # indices, so long as its not computing any weird combination between them + # you can do it in the embedding + # if len(sequences_tuple)>0: + # min_poss_depth = max(min_poss_depth,1) # except for the very specific + # # case where it is the very first thing to be done, in which case we do + # # have to go through one layer to get to the first feedforward. + # # the 'if' is there to rule out increasing when doing a feedforward on + # # nothing, ie, when making a constant. 
constants are allowed to be + # # created on layer 0, they're part of the embedding or the weights that + # # will use them later or whatever, it's fine + + # at least as deep as needed MVs, but no deeper cause FF + # (which happens at end of layer) + return format_output(parents_tuple, parents2res, name, + min_poss_depth=min_poss_depth, + elementwise_function=elementwise_function, + from_zipmap=True) def aggregate(select, sequences_tuple, elementwise_function=None, - default=None, name=plain_unfinished_sequence_name): - sequences_tuple = tupleise(sequences_tuple) - definitely_uses_identity_function = None is elementwise_function - if definitely_uses_identity_function: - elementwise_function = get_identity_function(len(sequences_tuple)) - unfinished_parents_tuple = UnfinishedSequencesTuple(sequences_tuple) - parents_tuple = (select, unfinished_parents_tuple) - def parents2res(s, vt): return _aggregate( - s, vt, elementwise_function, default=default) - def_uses = definitely_uses_identity_function - - # at least as deep as needed attention and at least one deeper than needed - # MVs - return format_output(parents_tuple, parents2res, name, - elementwise_function=elementwise_function, - default=default, - min_poss_depth=max(_min_poss_depth( - sequences_tuple)+1, select.min_poss_depth), - definitely_uses_identity_function=def_uses) + default=None, name=plain_unfinished_sequence_name): + sequences_tuple = tupleise(sequences_tuple) + definitely_uses_identity_function = None is elementwise_function + if definitely_uses_identity_function: + elementwise_function = get_identity_function(len(sequences_tuple)) + unfinished_parents_tuple = UnfinishedSequencesTuple(sequences_tuple) + parents_tuple = (select, unfinished_parents_tuple) + + def parents2res(s, vt): + return _aggregate(s, vt, elementwise_function, default=default) + + def_uses = definitely_uses_identity_function + + # at least as deep as needed attention and at least one deeper than needed + # MVs + min_poss_depth = max(_min_poss_depth(sequences_tuple) + 1, + select.min_poss_depth) + return format_output(parents_tuple, parents2res, name, + elementwise_function=elementwise_function, + default=default, + min_poss_depth=min_poss_depth, + definitely_uses_identity_function=def_uses) # up to here was just plain transformer 'assembly'. any addition is a lie @@ -572,18 +579,18 @@ def parents2res(s, vt): return _aggregate( def UnfinishedSequenceFunc(f): - setattr(UnfinishedSequence, f.__name__, f) + setattr(UnfinishedSequence, f.__name__, f) def UnfinishedFunc(f): - setattr(Unfinished, f.__name__, f) + setattr(Unfinished, f.__name__, f) @UnfinishedSequenceFunc def allow_suppressing_display(self): - self.always_display = False - return self # return self to allow chaining with other calls and throwing - # straight into a return statement etc + self.always_display = False + return self # return self to allow chaining with other calls and throwing + # straight into a return statement etc # later, we will overload == for unfinished sequences, such that it always # returns another unfinished sequence. 
unfortunately this creates the following @@ -595,15 +602,15 @@ def allow_suppressing_display(self): def guarded_compare(seq1, seq2): - if isinstance(seq1, UnfinishedSequence) \ - or isinstance(seq2, UnfinishedSequence): - return seq1 is seq2 - return seq1 == seq2 + if isinstance(seq1, UnfinishedSequence) \ + or isinstance(seq2, UnfinishedSequence): + return seq1 is seq2 + return seq1 == seq2 def guarded_contains(ll, a): - if isinstance(a, Unfinished): - return True in [(a is e) for e in ll] - else: - ll = [e for e in ll if not isinstance(e, Unfinished)] - return a in ll + if isinstance(a, Unfinished): + return True in [(a is e) for e in ll] + else: + ll = [e for e in ll if not isinstance(e, Unfinished)] + return a in ll diff --git a/RASP_support/REPL.py b/RASP_support/REPL.py index 50cd3a4..3ed347b 100644 --- a/RASP_support/REPL.py +++ b/RASP_support/REPL.py @@ -6,7 +6,7 @@ from .Environment import Environment, UndefinedVariable, ReservedName from .FunctionalSupport import UnfinishedSequence, UnfinishedSelect, Unfinished from .Evaluator import Evaluator, NamedVal, NamedValList, JustVal, \ - RASPFunction, ArgsError, RASPTypeError, RASPValueError + RASPFunction, ArgsError, RASPTypeError, RASPValueError from .Support import Select, Sequence, lazy_type_check from termcolor import colored from .colors import error_color, values_color, general_color @@ -15,586 +15,593 @@ class ResultToPrint: - def __init__(self, res, to_print): - self.res, self.print = res, to_print + def __init__(self, res, to_print): + self.res, self.print = res, to_print class LazyPrint: - def __init__(self, *a, **kw): - self.a, self.kw = a, kw + def __init__(self, *a, **kw): + self.a, self.kw = a, kw - def print(self): - print(*self.a, **self.kw) + def print(self): + print(*self.a, **self.kw) class StopException(Exception): - def __init__(self): - super().__init__() + def __init__(self): + super().__init__() DEBUG = False def debprint(*a, **kw): - if DEBUG: - coloredprint(*a, **kw) + if DEBUG: + coloredprint(*a, **kw) class ReturnExample: - def __init__(self, subset): - self.subset = subset + def __init__(self, subset): + self.subset = subset class LoadError(Exception): - def __init__(self, msg): - super().__init__(msg) + def __init__(self, msg): + super().__init__(msg) def is_comment(line): - if not isinstance(line, str): - return False - return line.strip().startswith("#") + if not isinstance(line, str): + return False + return line.strip().startswith("#") def formatstr(res): - if isinstance(res, str): - return "\""+res+"\"" - return str(res) + if isinstance(res, str): + return "\"" + res + "\"" + return str(res) class REPL: - def __init__(self): - self.env = Environment(name="console") - self.sequence_running_example = "hello" - self.selector_running_example = "hello" - self.sequence_prints_verbose = False - self.show_sequence_examples = True - self.show_selector_examples = True - self.results_to_print = [] - self.print_welcome() - self.load_base_libraries_and_make_base_env() - - def load_base_libraries_and_make_base_env(self): - self.silent = True - # base env: the env from which every load begins - self.base_env = self.env.snapshot() - # bootstrap base_env with current (basically empty except indices etc) - # env, then load the base libraries to build the actual base env - # make the library-loaded variables and functions not-overwriteable - self.env.storing_in_constants = True - for lib in ["RASP_support/rasplib"]: - self.run_given_line("load \"" + lib + "\";") - self.base_env = self.env.snapshot() - 
self.env.storing_in_constants = False - self.run_given_line("tokens=tokens_str;") - self.base_env = self.env.snapshot() - self.silent = False - - def set_running_example(self, example, which="both"): - if which in ["both", ENCODER_NAME]: - self.sequence_running_example = example - if which in ["both", "selector"]: - self.selector_running_example = example - - def print_welcome(self): - print(colored("RASP 0.1", general_color)) - print(colored("running example is:", general_color), - colored(self.sequence_running_example, values_color)) - - def print_just_val(self, justval): - val = justval.val - if None is val: - return - if isinstance(val, Select): - print(colored("\t = ", general_color)) - print_select(val.created_from_input, val) - elif isinstance(val, Sequence) and self.sequence_prints_verbose: - print(colored("\t = ", general_color), end="") - print_seq(val.created_from_input, val, still_on_prev_line=True) - else: - print(colored("\t = ", general_color), - colored(str(val).replace("\n", "\n\t\t\t"), values_color)) - - def print_named_val(self, name, val, ntabs=0, extra_first_pref=""): - pref = "\t"*ntabs - if (None is name) and isinstance(val, Unfinished): - name = val.name - if isinstance(val, UnfinishedSequence): - print(pref, - colored(extra_first_pref, general_color), - colored(" "+ENCODER_NAME+":", general_color), - colored(name, general_color)) - if self.show_sequence_examples: - if self.sequence_prints_verbose: - print(colored(f"{pref} \t Example:", general_color), - end="") - optional_exampledesc =\ - colored(name + "(", general_color) +\ - colored(formatstr(self.sequence_running_example), - values_color) +\ - colored(") =", general_color) - print_seq(self.selector_running_example, - val.call(self.sequence_running_example), - still_on_prev_line=True, - extra_pref=pref, - lastpref_if_shortprint=optional_exampledesc) - else: - print(colored(f"{pref} \t Example: {name}(", - general_color) + - colored(formatstr(self.sequence_running_example), values_color) + - colored(") =", general_color), - val.call(self.sequence_running_example)) - elif isinstance(val, UnfinishedSelect): - print(colored(pref, general_color), - colored(extra_first_pref, general_color), - colored(f" selector: {name}", general_color)) - if self.show_selector_examples: - print(colored(f"{pref} \t Example:", general_color)) - print_select(self.selector_running_example, val.call( - self.selector_running_example), extra_pref=pref) - elif isinstance(val, RASPFunction): - print(colored(f"{pref} {extra_first_pref} ", general_color) + - colored(str(val), general_color)) - elif isinstance(val, list): - named = " list: "+((name+" = ") if name is not None else "") - print(colored(f"{pref} {extra_first_pref} {named}", - general_color), end="") - flat = True not in [isinstance(v, list) or isinstance( - v, dict) or isinstance(v, Unfinished) for v in val] - if flat: - print(colored(val, values_color)) - else: - print(colored(f"{pref} [", general_color)) - for v in val: - self.print_named_val(None, v, ntabs=ntabs+2) - print(colored(str(pref) + " "*(len(named) +2) + "]", - general_color)) - elif isinstance(val, dict): - named = " dict: "+((name+" = ") if name is not None else "") - print(colored(f"{pref} {extra_first_pref} {named}", - general_color), end="") - flat = True not in [isinstance(val[v], list) or isinstance( - val[v], dict) or isinstance(val[v], Unfinished) for v in val] - if flat: - print(colored(val, values_color)) - else: - print(colored(str(pref) + " {", general_color)) - for v in val: - self.print_named_val(None, 
val[v], ntabs=ntabs + 3, - extra_first_pref=formatstr(v) + " : ") - print(colored(str(pref) + " "*(len(named) + 2) + "}", - general_color)) - - else: - namestr = (name + " = ") if name is not None else "" - print(colored(f"{pref} value: {namestr}", general_color), - colored(formatstr(val), values_color)) - - def print_example(self, nres): - if nres.subset in ["both", ENCODER_NAME]: - print(colored("\t"+ENCODER_NAME+" example:", general_color), - colored(formatstr(self.sequence_running_example), values_color)) - if nres.subset in ["both", "selector"]: - print(colored("\tselector example:", general_color), - colored(formatstr(self.selector_running_example), values_color)) - - def print_result(self, rp): - if self.silent: - return - if isinstance(rp, LazyPrint): - return rp.print() - # a list of multiple ResultToPrint s -- probably the result of a - # multi-assignment - if isinstance(rp, list): - for v in rp: - self.print_result(v) - return - if not rp.print: - return - res = rp.res - if isinstance(res, NamedVal): - self.print_named_val(res.name, res.val) - elif isinstance(res, ReturnExample): - self.print_example(res) - elif isinstance(res, JustVal): - self.print_just_val(res) - - def evaluate_replstatement(self, ast): - if ast.setExample(): - return ResultToPrint(self.setExample(ast.setExample()), False) - if ast.showExample(): - return ResultToPrint(self.showExample(ast.showExample()), True) - if ast.toggleExample(): - return ResultToPrint(self.toggleExample(ast.toggleExample()), - False) - if ast.toggleSeqVerbose(): - return ResultToPrint(self.toggleSeqVerbose(ast.toggleSeqVerbose()), - False) - if ast.exit(): - raise StopException() - - def toggleSeqVerbose(self, ast): - switch = ast.switch.text - self.sequence_prints_verbose = switch == "on" - - def toggleExample(self, ast): - subset = ast.subset - subset = "both" if not subset else subset.text - switch = ast.switch.text - examples_on = switch == "on" - if subset in ["both", ENCODER_NAME]: - self.show_sequence_examples = examples_on - if subset in ["both", "selector"]: - self.show_selector_examples = examples_on - - def showExample(self, ast): - subset = ast.subset - subset = "both" if not subset else subset.text - return ReturnExample(subset) - - def setExample(self, ast): - example = Evaluator(self.env, self).evaluateExpr(ast.example) - if not isinstance(example, Iterable): - raise RASPTypeError("example not iterable: "+str(example)) - subset = ast.subset - subset = "both" if not subset else subset.text - self.set_running_example(example, subset) - return ReturnExample(subset) - - def loadFile(self, ast, calling_env=None): - if None is calling_env: - calling_env = self.env - libname = ast.filename.text[1:-1] - filename = libname + ".rasp" - try: - with open(filename, "r") as f: - prev_example_settings = self.show_sequence_examples, \ - self.show_selector_examples - self.show_sequence_examples = False - self.show_selector_examples = False - self.run(fromfile=f, - env=Environment(name=libname, - parent_env=self.base_env, - stealing_env=calling_env), - store_prints=True) - self.filter_and_dump_prints() - self.show_sequence_examples, self.show_selector_examples = \ - prev_example_settings - except FileNotFoundError: - raise LoadError("could not find file: "+filename) - - def get_tree(self, fromfile=None): - try: - return LineReader(fromfile=fromfile).get_input_tree() - except AntlrException as e: - print(colored(f"\t!! 
antlr exception: {e.msg} \t-- ignoring input", - error_color)) - return None - - def run_given_line(self, line): - try: - tree = LineReader(given_line=line).get_input_tree() - if isinstance(tree, Stop): - return None - rp = self.evaluate_tree(tree) - if isinstance(rp, LazyPrint): - # error messages get raised, but ultimately have to be printed - # somewhere if not caught? idk - rp.print() - except AntlrException as e: - print(colored(f"\t!! REPL failed to run initiating line: {line}", - error_color)) - print(colored(f"\t --got antlr exception: {e.msg}", - error_color)) - return None - - def assigned_to_top(self, res, env): - if env is self.env: - return True - # we are now definitely inside some file, the question is whether we - # have taken the result and kept it in the top level too, i.e., whether - # we have imported a non-private value. checking whether it is also in - # self.env, even identical, will not tell us much as it may have been - # here and the same already. so we have to replicate the logic here. - if not isinstance(res, NamedVal): - return False # only namedvals get set to begin with - if res.name.startswith("_") or (res.name == "out"): - return False - return True - - def evaluate_tree(self, tree, env=None): - if None is env: - env = self.env # otherwise, can pass custom env - # (e.g. when loading from a file, make env for that file, - # to keep that file's private (i.e. underscore-prefixed) variables - # to itself) - if None is tree: - return ResultToPrint(None, False) - try: - if tree.replstatement(): - return self.evaluate_replstatement(tree.replstatement()) - elif tree.raspstatement(): - res = Evaluator(env, self).evaluate(tree.raspstatement()) - if isinstance(res, NamedValList): - return [ResultToPrint(r, self.assigned_to_top(r, env)) for - r in res.nvs] - return ResultToPrint(res, self.assigned_to_top(res, env)) - except (UndefinedVariable, ReservedName) as e: - return LazyPrint(colored(f"\t\t!!ignoring input:\n\t {e}", error_color)) - except NotImplementedError: - return LazyPrint( - colored(f"not implemented this command yet! ignoring", error_color)) - except (ArgsError, RASPTypeError, LoadError, RASPValueError) as e: - return LazyPrint(colored(f"\t\t!!ignoring input:\n\t {e}", error_color)) - # if not replstatement or raspstatement, then comment - return ResultToPrint(None, False) - - def filter_and_dump_prints(self): - # TODO: some error messages are still rising up and getting printed - # before reaching this position :( - def filter_named_val_reps(rps): - # do the filtering. 
no namedvallists here - those are converted - # into a list of ResultToPrint s containing NamedVal s immediately - # after receiving them in evaluate_tree - res = [] - names = set() - # go backwards - want to print the last occurence of each named - # item, not first, so filter works backwards - for r in rps[::-1]: - if isinstance(r.res, NamedVal): - if r.res.name in names: - continue - names.add(r.res.name) - res.append(r) - return res[::-1] # flip back forwards - - if True not in [isinstance(v, LazyPrint) for - v in self.results_to_print]: - self.results_to_print = filter_named_val_reps( - self.results_to_print) - # if isinstance(res,NamedVal): - # self.print_named_val(res.name,res.val) - # - # print all that needs to be printed: - for r in self.results_to_print: - if isinstance(r, LazyPrint): - r.print() - else: - self.print_result(r) - # clear the list - self.results_to_print = [] - - def run(self, fromfile=None, env=None, store_prints=False): - def careful_print(*a, **kw): - if store_prints: - self.results_to_print.append(LazyPrint(*a, **kw)) - else: - print(*a, **kw) - while True: - try: - tree = self.get_tree(fromfile) - if isinstance(tree, Stop): - break - rp = self.evaluate_tree(tree, env) - if store_prints: - if isinstance(rp, list): - # multiple results given - a multi-assignment - self.results_to_print += rp - else: - self.results_to_print.append(rp) - else: - self.print_result(rp) - except RASPTypeError as e: - msg = "\t!!statement executed, but result fails on evaluation:" - msg += "\n\t\t" - toprint = colored(f"{msg} {e}", error_color) - careful_print(toprint) - except EOFError: - careful_print("") - break - except StopException: - break - except KeyboardInterrupt: - careful_print("") # makes newline - except Exception as e: - if DEBUG: - raise e - careful_print(colored(f"something went wrong: {e}", - error_color)) + def __init__(self): + self.env = Environment(name="console") + self.sequence_running_example = "hello" + self.selector_running_example = "hello" + self.sequence_prints_verbose = False + self.show_sequence_examples = True + self.show_selector_examples = True + self.results_to_print = [] + self.print_welcome() + self.load_base_libraries_and_make_base_env() + + def load_base_libraries_and_make_base_env(self): + self.silent = True + # base env: the env from which every load begins + self.base_env = self.env.snapshot() + # bootstrap base_env with current (basically empty except indices etc) + # env, then load the base libraries to build the actual base env + # make the library-loaded variables and functions not-overwriteable + self.env.storing_in_constants = True + for lib in ["RASP_support/rasplib"]: + self.run_given_line("load \"" + lib + "\";") + self.base_env = self.env.snapshot() + self.env.storing_in_constants = False + self.run_given_line("tokens=tokens_str;") + self.base_env = self.env.snapshot() + self.silent = False + + def set_running_example(self, example, which="both"): + if which in ["both", ENCODER_NAME]: + self.sequence_running_example = example + if which in ["both", "selector"]: + self.selector_running_example = example + + def print_welcome(self): + print(colored("RASP 0.1", general_color)) + print(colored("running example is:", general_color), + colored(self.sequence_running_example, values_color)) + + def print_just_val(self, justval): + val = justval.val + if None is val: + return + if isinstance(val, Select): + print(colored("\t = ", general_color)) + print_select(val.created_from_input, val) + elif isinstance(val, Sequence) and 
self.sequence_prints_verbose: + print(colored("\t = ", general_color), end="") + print_seq(val.created_from_input, val, still_on_prev_line=True) + else: + print(colored("\t = ", general_color), + colored(str(val).replace("\n", "\n\t\t\t"), values_color)) + + def print_named_val(self, name, val, ntabs=0, extra_first_pref=""): + pref = "\t" * ntabs + if (None is name) and isinstance(val, Unfinished): + name = val.name + if isinstance(val, UnfinishedSequence): + print(pref, + colored(extra_first_pref, general_color), + colored(" " + ENCODER_NAME + ":", general_color), + colored(name, general_color)) + if self.show_sequence_examples: + if self.sequence_prints_verbose: + print(colored(f"{pref} \t Example:", general_color), + end="") + optional_exampledesc =\ + colored(name + "(", general_color) +\ + colored(formatstr(self.sequence_running_example), + values_color) +\ + colored(") =", general_color) + print_seq(self.selector_running_example, + val.call(self.sequence_running_example), + still_on_prev_line=True, + extra_pref=pref, + lastpref_if_shortprint=optional_exampledesc) + else: + print(colored(f"{pref} \t Example: {name}(", + general_color) + + colored(formatstr(self.sequence_running_example), + values_color) + + colored(") =", general_color), + val.call(self.sequence_running_example)) + elif isinstance(val, UnfinishedSelect): + print(colored(pref, general_color), + colored(extra_first_pref, general_color), + colored(f" selector: {name}", general_color)) + if self.show_selector_examples: + print(colored(f"{pref} \t Example:", general_color)) + print_select(self.selector_running_example, val.call( + self.selector_running_example), extra_pref=pref) + elif isinstance(val, RASPFunction): + print(colored(f"{pref} {extra_first_pref} ", general_color) + + colored(str(val), general_color)) + elif isinstance(val, list): + named = " list: " + ((name + " = ") if name is not None else "") + print(colored(f"{pref} {extra_first_pref} {named}", + general_color), end="") + flat = True not in [isinstance(v, list) or isinstance( + v, dict) or isinstance(v, Unfinished) for v in val] + if flat: + print(colored(val, values_color)) + else: + print(colored(f"{pref} [", general_color)) + for v in val: + self.print_named_val(None, v, ntabs=ntabs + 2) + print(colored(str(pref) + " " * (len(named) + 2) + "]", + general_color)) + elif isinstance(val, dict): + named = " dict: " + ((name + " = ") if name is not None else "") + print(colored(f"{pref} {extra_first_pref} {named}", + general_color), end="") + flat = True not in [isinstance(val[v], list) or isinstance( + val[v], dict) or isinstance(val[v], Unfinished) for v in val] + if flat: + print(colored(val, values_color)) + else: + print(colored(str(pref) + " {", general_color)) + for v in val: + self.print_named_val(None, val[v], ntabs=ntabs + 3, + extra_first_pref=formatstr(v) + " : ") + print(colored(str(pref) + " " * (len(named) + 2) + "}", + general_color)) + + else: + namestr = (name + " = ") if name is not None else "" + print(colored(f"{pref} value: {namestr}", general_color), + colored(formatstr(val), values_color)) + + def print_example(self, nres): + if nres.subset in ["both", ENCODER_NAME]: + print(colored("\t" + ENCODER_NAME + " example:", general_color), + colored(formatstr(self.sequence_running_example), + values_color)) + if nres.subset in ["both", "selector"]: + print(colored("\tselector example:", general_color), + colored(formatstr(self.selector_running_example), + values_color)) + + def print_result(self, rp): + if self.silent: + return + if 
isinstance(rp, LazyPrint): + return rp.print() + # a list of multiple ResultToPrint s -- probably the result of a + # multi-assignment + if isinstance(rp, list): + for v in rp: + self.print_result(v) + return + if not rp.print: + return + res = rp.res + if isinstance(res, NamedVal): + self.print_named_val(res.name, res.val) + elif isinstance(res, ReturnExample): + self.print_example(res) + elif isinstance(res, JustVal): + self.print_just_val(res) + + def evaluate_replstatement(self, ast): + if ast.setExample(): + return ResultToPrint(self.setExample(ast.setExample()), False) + if ast.showExample(): + return ResultToPrint(self.showExample(ast.showExample()), True) + if ast.toggleExample(): + return ResultToPrint(self.toggleExample(ast.toggleExample()), + False) + if ast.toggleSeqVerbose(): + return ResultToPrint(self.toggleSeqVerbose(ast.toggleSeqVerbose()), + False) + if ast.exit(): + raise StopException() + + def toggleSeqVerbose(self, ast): + switch = ast.switch.text + self.sequence_prints_verbose = switch == "on" + + def toggleExample(self, ast): + subset = ast.subset + subset = "both" if not subset else subset.text + switch = ast.switch.text + examples_on = switch == "on" + if subset in ["both", ENCODER_NAME]: + self.show_sequence_examples = examples_on + if subset in ["both", "selector"]: + self.show_selector_examples = examples_on + + def showExample(self, ast): + subset = ast.subset + subset = "both" if not subset else subset.text + return ReturnExample(subset) + + def setExample(self, ast): + example = Evaluator(self.env, self).evaluateExpr(ast.example) + if not isinstance(example, Iterable): + raise RASPTypeError("example not iterable: " + str(example)) + subset = ast.subset + subset = "both" if not subset else subset.text + self.set_running_example(example, subset) + return ReturnExample(subset) + + def loadFile(self, ast, calling_env=None): + if None is calling_env: + calling_env = self.env + libname = ast.filename.text[1:-1] + filename = libname + ".rasp" + try: + with open(filename, "r") as f: + prev_example_settings = self.show_sequence_examples, \ + self.show_selector_examples + self.show_sequence_examples = False + self.show_selector_examples = False + self.run(fromfile=f, + env=Environment(name=libname, + parent_env=self.base_env, + stealing_env=calling_env), + store_prints=True) + self.filter_and_dump_prints() + self.show_sequence_examples, self.show_selector_examples = \ + prev_example_settings + except FileNotFoundError: + raise LoadError("could not find file: " + filename) + + def get_tree(self, fromfile=None): + try: + return LineReader(fromfile=fromfile).get_input_tree() + except AntlrException as e: + print(colored(f"\t!! antlr exception: {e.msg} \t-- ignoring input", + error_color)) + return None + + def run_given_line(self, line): + try: + tree = LineReader(given_line=line).get_input_tree() + if isinstance(tree, Stop): + return None + rp = self.evaluate_tree(tree) + if isinstance(rp, LazyPrint): + # error messages get raised, but ultimately have to be printed + # somewhere if not caught? idk + rp.print() + except AntlrException as e: + print(colored(f"\t!! REPL failed to run initiating line: {line}", + error_color)) + print(colored(f"\t --got antlr exception: {e.msg}", + error_color)) + return None + + def assigned_to_top(self, res, env): + if env is self.env: + return True + # we are now definitely inside some file, the question is whether we + # have taken the result and kept it in the top level too, i.e., whether + # we have imported a non-private value. 
checking whether it is also in + # self.env, even identical, will not tell us much as it may have been + # here and the same already. so we have to replicate the logic here. + if not isinstance(res, NamedVal): + return False # only namedvals get set to begin with + if res.name.startswith("_") or (res.name == "out"): + return False + return True + + def evaluate_tree(self, tree, env=None): + if None is env: + env = self.env # otherwise, can pass custom env + # (e.g. when loading from a file, make env for that file, + # to keep that file's private (i.e. underscore-prefixed) variables + # to itself) + if None is tree: + return ResultToPrint(None, False) + try: + if tree.replstatement(): + return self.evaluate_replstatement(tree.replstatement()) + elif tree.raspstatement(): + res = Evaluator(env, self).evaluate(tree.raspstatement()) + if isinstance(res, NamedValList): + return [ResultToPrint(r, self.assigned_to_top(r, env)) for + r in res.nvs] + return ResultToPrint(res, self.assigned_to_top(res, env)) + except (UndefinedVariable, ReservedName) as e: + return LazyPrint(colored(f"\t\t!!ignoring input:\n\t {e}", + error_color)) + except NotImplementedError: + return LazyPrint( + colored(f"not implemented this command yet! ignoring", + error_color)) + except (ArgsError, RASPTypeError, LoadError, RASPValueError) as e: + return LazyPrint(colored(f"\t\t!!ignoring input:\n\t {e}", + error_color)) + # if not replstatement or raspstatement, then comment + return ResultToPrint(None, False) + + def filter_and_dump_prints(self): + # TODO: some error messages are still rising up and getting printed + # before reaching this position :( + def filter_named_val_reps(rps): + # do the filtering. no namedvallists here - those are converted + # into a list of ResultToPrint s containing NamedVal s immediately + # after receiving them in evaluate_tree + res = [] + names = set() + # go backwards - want to print the last occurence of each named + # item, not first, so filter works backwards + for r in rps[::-1]: + if isinstance(r.res, NamedVal): + if r.res.name in names: + continue + names.add(r.res.name) + res.append(r) + return res[::-1] # flip back forwards + + if True not in [isinstance(v, LazyPrint) for + v in self.results_to_print]: + self.results_to_print = filter_named_val_reps( + self.results_to_print) + # if isinstance(res,NamedVal): + # self.print_named_val(res.name,res.val) + # + # print all that needs to be printed: + for r in self.results_to_print: + if isinstance(r, LazyPrint): + r.print() + else: + self.print_result(r) + # clear the list + self.results_to_print = [] + + def run(self, fromfile=None, env=None, store_prints=False): + def careful_print(*a, **kw): + if store_prints: + self.results_to_print.append(LazyPrint(*a, **kw)) + else: + print(*a, **kw) + while True: + try: + tree = self.get_tree(fromfile) + if isinstance(tree, Stop): + break + rp = self.evaluate_tree(tree, env) + if store_prints: + if isinstance(rp, list): + # multiple results given - a multi-assignment + self.results_to_print += rp + else: + self.results_to_print.append(rp) + else: + self.print_result(rp) + except RASPTypeError as e: + msg = "\t!!statement executed, but result fails on evaluation:" + msg += "\n\t\t" + toprint = colored(f"{msg} {e}", error_color) + careful_print(toprint) + except EOFError: + careful_print("") + break + except StopException: + break + except KeyboardInterrupt: + careful_print("") # makes newline + except Exception as e: + if DEBUG: + raise e + careful_print(colored(f"something went wrong: {e}", + 
error_color)) class AntlrException(Exception): - def __init__(self, msg): - self.msg = msg + def __init__(self, msg): + self.msg = msg class InputNotFinished(Exception): - def __init__(self): - pass + def __init__(self): + pass class MyErrorListener(ErrorListener): - def __init__(self): - super(MyErrorListener, self).__init__() - - def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e): - if offendingSymbol and offendingSymbol.text == "": - raise InputNotFinished() - if msg.startswith("missing ';' at"): - raise InputNotFinished() - # TODO: why did this do nothing? - # if "mismatched input" in msg: - # a = str(offendingSymbol) - # b = a[a.find("=")+2:] - # c = b[:b.find(",<")-1] - ae = AntlrException(msg) - ae.recognizer = recognizer - ae.offendingSymbol = offendingSymbol - ae.line = line - ae.column = column - ae.msg = msg - ae.e = e - raise ae - - # def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, - # ambigAlts, configs): - # raise AntlrException("ambiguity") - - # def reportAttemptingFullContext(self, recognizer, dfa, startIndex, - # stopIndex, conflictingAlts, configs): - # we're ok with this: happens with func defs it seems - - # def reportContextSensitivity(self, recognizer, dfa, startIndex, - # stopIndex, prediction, configs): - # we're ok with this: happens with func defs it seems + def __init__(self): + super(MyErrorListener, self).__init__() + + def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e): + if offendingSymbol and offendingSymbol.text == "": + raise InputNotFinished() + if msg.startswith("missing ';' at"): + raise InputNotFinished() + # TODO: why did this do nothing? + # if "mismatched input" in msg: + # a = str(offendingSymbol) + # b = a[a.find("=")+2:] + # c = b[:b.find(",<")-1] + ae = AntlrException(msg) + ae.recognizer = recognizer + ae.offendingSymbol = offendingSymbol + ae.line = line + ae.column = column + ae.msg = msg + ae.e = e + raise ae + + # def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, + # ambigAlts, configs): + # raise AntlrException("ambiguity") + + # def reportAttemptingFullContext(self, recognizer, dfa, startIndex, + # stopIndex, conflictingAlts, configs): + # we're ok with this: happens with func defs it seems + + # def reportContextSensitivity(self, recognizer, dfa, startIndex, + # stopIndex, prediction, configs): + # we're ok with this: happens with func defs it seems class Stop: - def __init__(self): - pass + def __init__(self): + pass class LineReader: - def __init__(self, prompt=">>", fromfile=None, given_line=None): - self.fromfile = fromfile - self.given_line = given_line - self.prompt = prompt + " " - self.cont_prompt = "."*len(prompt)+" " - - def str_to_antlr_parser(self, s): - antlrinput = InputStream(s) - lexer = RASPLexer(antlrinput) - lexer.removeErrorListeners() - lexer.addErrorListener(MyErrorListener()) - stream = CommonTokenStream(lexer) - parser = RASPParser(stream) - parser.removeErrorListeners() - parser.addErrorListener(MyErrorListener()) - return parser - - def read_line(self, continuing=False, nest_depth=0): - prompt = self.cont_prompt if continuing else self.prompt - if self.fromfile is not None: - res = self.fromfile.readline() - # python files return "" on last line (as opposed to "\n" on empty - # lines) - if not res: - return Stop() - return res - if self.given_line is not None: - res = self.given_line - self.given_line = Stop() - return res - else: - return input(prompt+(" "*nest_depth)) - - def get_input_tree(self): - pythoninput = "" - 
multiline = False - while True: - nest_depth = pythoninput.split().count("def") - newinput = self.read_line(continuing=multiline, - nest_depth=nest_depth) - if isinstance(newinput, Stop): # input stream ended - return Stop() - if is_comment(newinput): - # don't let comments get in and ruin things somehow - newinput = "" - # don't replace newlines here! this is how in-function comments get - # broken - pythoninput += newinput - parser = self.str_to_antlr_parser(pythoninput) - try: - res = parser.r().statement() - if isinstance(res, list): - # TODO: this seems to happen when there's ambiguity. figure - # out what is going on!! - assert len(res) == 1 - res = res[0] - return res - except InputNotFinished: - multiline = True - pythoninput += " " + def __init__(self, prompt=">>", fromfile=None, given_line=None): + self.fromfile = fromfile + self.given_line = given_line + self.prompt = prompt + " " + self.cont_prompt = "." * len(prompt) + " " + + def str_to_antlr_parser(self, s): + antlrinput = InputStream(s) + lexer = RASPLexer(antlrinput) + lexer.removeErrorListeners() + lexer.addErrorListener(MyErrorListener()) + stream = CommonTokenStream(lexer) + parser = RASPParser(stream) + parser.removeErrorListeners() + parser.addErrorListener(MyErrorListener()) + return parser + + def read_line(self, continuing=False, nest_depth=0): + prompt = self.cont_prompt if continuing else self.prompt + if self.fromfile is not None: + res = self.fromfile.readline() + # python files return "" on last line (as opposed to "\n" on empty + # lines) + if not res: + return Stop() + return res + if self.given_line is not None: + res = self.given_line + self.given_line = Stop() + return res + else: + return input(prompt + (" " * nest_depth)) + + def get_input_tree(self): + pythoninput = "" + multiline = False + while True: + nest_depth = pythoninput.split().count("def") + newinput = self.read_line(continuing=multiline, + nest_depth=nest_depth) + if isinstance(newinput, Stop): # input stream ended + return Stop() + if is_comment(newinput): + # don't let comments get in and ruin things somehow + newinput = "" + # don't replace newlines here! this is how in-function comments get + # broken + pythoninput += newinput + parser = self.str_to_antlr_parser(pythoninput) + try: + res = parser.r().statement() + if isinstance(res, list): + # TODO: this seems to happen when there's ambiguity. figure + # out what is going on!! 
+ assert len(res) == 1 + res = res[0] + return res + except InputNotFinished: + multiline = True + pythoninput += " " def print_seq(example, seq, still_on_prev_line=False, extra_pref="", - lastpref_if_shortprint=""): - if len(set(seq.get_vals())) == 1: - print(extra_pref if not still_on_prev_line else "", - lastpref_if_shortprint, - colored(str(seq), values_color), end=" ") - # when there is only one value, it's nicer to just print that than the - # full list, verbosity be damned - print(colored("[skipped full display: identical values]", general_color)) - return - if still_on_prev_line: - print("") - - seq = seq.get_vals() - - def cleanboolslist(seq): - if isinstance(seq[0], bool): - tstr = "T" if seq.count(True) <= seq.count(False) else "" - fstr = "F" if seq.count(False) <= seq.count(True) else "" - return [tstr if v else fstr for v in seq] - else: - return seq - - example = cleanboolslist(example) - seqtype = lazy_type_check(seq) - seq = cleanboolslist(seq) - example = [str(v) for v in example] - seq = [str(v) for v in seq] - maxlen = max(len(v) for v in example+seq) - - def neatline(seq): - def padded(s): - return " "*(maxlen-len(s))+s - return " ".join(padded(v) for v in seq) - print(extra_pref, colored("\t\tinput: ", general_color), - colored(neatline(example), values_color), "\t", - colored("("+lazy_type_check(example)+"s)", general_color)) - print(extra_pref, colored("\t\toutput: ", general_color), - colored(neatline(seq), values_color), "\t", - colored("("+seqtype+"s)", general_color)) + lastpref_if_shortprint=""): + if len(set(seq.get_vals())) == 1: + print(extra_pref if not still_on_prev_line else "", + lastpref_if_shortprint, + colored(str(seq), values_color), end=" ") + # when there is only one value, it's nicer to just print that than the + # full list, verbosity be damned + print(colored("[skipped full display: identical values]", + general_color)) + return + if still_on_prev_line: + print("") + + seq = seq.get_vals() + + def cleanboolslist(seq): + if isinstance(seq[0], bool): + tstr = "T" if seq.count(True) <= seq.count(False) else "" + fstr = "F" if seq.count(False) <= seq.count(True) else "" + return [tstr if v else fstr for v in seq] + else: + return seq + + example = cleanboolslist(example) + seqtype = lazy_type_check(seq) + seq = cleanboolslist(seq) + example = [str(v) for v in example] + seq = [str(v) for v in seq] + maxlen = max(len(v) for v in example + seq) + + def neatline(seq): + def padded(s): + return " " * (maxlen - len(s)) + s + return " ".join(padded(v) for v in seq) + print(extra_pref, colored("\t\tinput: ", general_color), + colored(neatline(example), values_color), "\t", + colored("(" + lazy_type_check(example) + "s)", general_color)) + print(extra_pref, colored("\t\toutput: ", general_color), + colored(neatline(seq), values_color), "\t", + colored("(" + seqtype + "s)", general_color)) def print_select(example, select, extra_pref=""): - # .replace("\n","\n\t\t\t") - def nice_matrix_line(m): - return " ".join("1" if v else " " for v in m) - print(colored(extra_pref, general_color), "\t\t\t ", - colored(" ".join(str(v) for v in example), values_color)) - matrix = select.get_vals() - [print(colored(extra_pref, general_color), "\t\t\t", - colored(v, values_color), - colored("|", general_color), - colored(nice_matrix_line(matrix[m]), values_color)) - for v, m in zip(example, matrix)] + # .replace("\n","\n\t\t\t") + def nice_matrix_line(m): + return " ".join("1" if v else " " for v in m) + print(colored(extra_pref, general_color), "\t\t\t ", + colored(" 
".join(str(v) for v in example), values_color)) + matrix = select.get_vals() + [print(colored(extra_pref, general_color), "\t\t\t", + colored(v, values_color), + colored("|", general_color), + colored(nice_matrix_line(matrix[m]), values_color)) + for v, m in zip(example, matrix)] if __name__ == "__main__": - REPL().run() + REPL().run() # (set debug in this file to True) @@ -605,10 +612,10 @@ def nice_matrix_line(m): # import REPL # REPL.runner() def runner(): - a = REPL() - try: - a.run() - except Exception as e: - print(e) - return a, e - return a, None + a = REPL() + try: + a.run() + except Exception as e: + print(e) + return a, e + return a, None diff --git a/RASP_support/Sugar.py b/RASP_support/Sugar.py index da80d51..248c22e 100644 --- a/RASP_support/Sugar.py +++ b/RASP_support/Sugar.py @@ -11,44 +11,44 @@ def _apply_unary_op(self, f): - return zipmap(self, f) + return zipmap(self, f) def _apply_binary_op(self, other, f): - def seq_and_other_op(self, other, f): - return zipmap(self, lambda a: f(a, other)) + def seq_and_other_op(self, other, f): + return zipmap(self, lambda a: f(a, other)) - def seq_and_seq_op(self, other_seq, f): - return zipmap((self, other_seq), f) - if isinstance(other, _UnfinishedSequence): - return seq_and_seq_op(self, other, f) - else: - return seq_and_other_op(self, other, f) + def seq_and_seq_op(self, other_seq, f): + return zipmap((self, other_seq), f) + if isinstance(other, _UnfinishedSequence): + return seq_and_seq_op(self, other, f) + else: + return seq_and_other_op(self, other, f) add_ops(_UnfinishedSequence, _apply_unary_op, _apply_binary_op) def _addname(seq, name, default_name, always_display_when_named=True): - if name is None: - res = seq.setname(default_name, - always_display_when_named=always_display_when_named) - res = res.allow_suppressing_display() - else: - res = seq.setname(name, - always_display_when_named=always_display_when_named) - return res + if name is None: + res = seq.setname(default_name, + always_display_when_named=always_display_when_named) + res = res.allow_suppressing_display() + else: + res = seq.setname(name, + always_display_when_named=always_display_when_named) + return res full_s = select((), (), lambda: True, name="full average", - compare_string="full average") + compare_string="full average") def tplconst(v, name=None): - return _addname(zipmap((), lambda: v), name, "constant: " + str(v), - always_display_when_named=False).mark_as_constant() - # always_display_when_named = False : constants aren't worth displaying, - # but still going to name them in background, in case I change my mind + return _addname(zipmap((), lambda: v), name, "constant: " + str(v), + always_display_when_named=False).mark_as_constant() + # always_display_when_named = False : constants aren't worth displaying, + # but still going to name them in background, in case I change my mind # allow suppressing display for bool, not, and, or : all of these would have # been boring operators if only python let me overload them @@ -58,45 +58,45 @@ def tplconst(v, name=None): def toseq(seq): - if not isinstance(seq, _UnfinishedSequence): - seq = tplconst(seq, str(seq)) - return seq + if not isinstance(seq, _UnfinishedSequence): + seq = tplconst(seq, str(seq)) + return seq def asbool(seq): - res = zipmap(seq, lambda a: bool(a)) - return _addname(res, None, "bool(" + seq.name + ")") - # would do res = seq==True but it seems this has different behaviour to - # bool eg 'bool(2)' is True but '2==True' returns False + res = zipmap(seq, lambda a: bool(a)) + return 
_addname(res, None, "bool(" + seq.name + ")") + # would do res = seq==True but it seems this has different behaviour to + # bool eg 'bool(2)' is True but '2==True' returns False def tplnot(seq, name=None): - # this one does correct conversion using asbool and then we really can just - # do == False - pep8hack = False # this avoids violating E712 of PEP8 - res = asbool(seq) == pep8hack - return _addname(res, name, "( not " + str(seq.name) + " )") + # this one does correct conversion using asbool and then we really can just + # do == False + pep8hack = False # this avoids violating E712 of PEP8 + res = asbool(seq) == pep8hack + return _addname(res, name, "( not " + str(seq.name) + " )") def _num_trues(left, right): - l, r = toseq(left), toseq(right) - return (1 * asbool(l)) + (1 * asbool(r)) + l, r = toseq(left), toseq(right) + return (1 * asbool(l)) + (1 * asbool(r)) def quickname(v): - if isinstance(v, _Unfinished): - return v.name - else: - return str(v) + if isinstance(v, _Unfinished): + return v.name + else: + return str(v) def tpland(left, right): - res = _num_trues(left, right) == 2 - return _addname(res, None, "( " + quickname(left) + " and " - + quickname(right) + ")") + res = _num_trues(left, right) == 2 + return _addname(res, None, "( " + quickname(left) + " and " + + quickname(right) + ")") def tplor(left, right): - res = _num_trues(left, right) >= 1 - return _addname(res, None, "( " + quickname(left) + " or " - + quickname(right) + ")") + res = _num_trues(left, right) >= 1 + return _addname(res, None, "( " + quickname(left) + " or " + + quickname(right) + ")") diff --git a/RASP_support/Support.py b/RASP_support/Support.py index a6de676..2285438 100644 --- a/RASP_support/Support.py +++ b/RASP_support/Support.py @@ -6,25 +6,25 @@ class RASPError(Exception): - def __init__(self, *a): - super().__init__(" ".join([str(b) for b in a])) + def __init__(self, *a): + super().__init__(" ".join([str(b) for b in a])) class RASPTypeError(RASPError): - def __init__(self, *a): - super().__init__(*a) + def __init__(self, *a): + super().__init__(*a) def clean_val(num, digits=3): # taken from my helper functions - res = round(num, digits) - if digits == 0: - res = int(res) - return res + res = round(num, digits) + if digits == 0: + res = int(res) + return res class SupportException(Exception): - def __init__(self, m): - Exception.__init__(self, m) + def __init__(self, m): + Exception.__init__(self, m) TBANNED = "banned" @@ -33,248 +33,251 @@ def __init__(self, m): NUMTYPES = [TNAME[int], TNAME[float]] sorted_typenames_list = sorted(list(TNAME.values())) legal_types_list_string = ", ".join( - sorted_typenames_list[:-1])+" or "+sorted_typenames_list[-1] + sorted_typenames_list[:-1]) + " or " + sorted_typenames_list[-1] def is_in_types(v, tlist): - for t in tlist: - if isinstance(v, t): - return True - return False + for t in tlist: + if isinstance(v, t): + return True + return False def lazy_type_check(vals): - legal_val_types = [str, bool, int, float] - number_types = [int, float] + legal_val_types = [str, bool, int, float] + number_types = [int, float] - # all vals are same, legal, type: - for t in legal_val_types: - b = [isinstance(v, t) for v in vals] - if False not in b: - return TNAME[t] + # all vals are same, legal, type: + for t in legal_val_types: + b = [isinstance(v, t) for v in vals] + if False not in b: + return TNAME[t] - # allow vals to also be mixed integers and ints, treat those as floats - # (but don't actually change the ints to floats, want neat printouts) - b = [is_in_types(v, 
number_types) for v in vals] - if False not in b: - return TNAME[float] + # allow vals to also be mixed integers and ints, treat those as floats + # (but don't actually change the ints to floats, want neat printouts) + b = [is_in_types(v, number_types) for v in vals] + if False not in b: + return TNAME[float] - # from here it's all bad, but lets have some clear error messages - b = [is_in_types(v, legal_val_types) for v in vals] - if False not in b: - return TMISMATCHED # all legal types, but mismatched - else: - return TBANNED + # from here it's all bad, but lets have some clear error messages + b = [is_in_types(v, legal_val_types) for v in vals] + if False not in b: + return TMISMATCHED # all legal types, but mismatched + else: + return TBANNED class Sequence: - def __init__(self, vals): - self.type = lazy_type_check(vals) - if self.type == TMISMATCHED: - raise RASPTypeError( - "attempted to create sequence with vals of different types:" - + f"\n\t\t {vals}") - if self.type == TBANNED: - raise RASPTypeError( - "attempted to create sequence with illegal val types " - + f"(vals must be {legal_types_list_string}):\n\t\t {vals}") - self._vals = vals - - def __str__(self): - # return "Sequence"+str([small_str(v) for v in self._vals]) - if (len(set(self._vals)) == 1) and (len(self._vals) > 1): - res = "["+small_str(self._vals[0])+"]*"+str(len(self._vals)) - else: - res = "["+", ".join(small_str(v) for v in self._vals)+"]" - return colored(res, values_color) + \ - colored(" ("+self.type+"s)", general_color) - - def __repr__(self): - return str(self) - - def __len__(self): - return len(self._vals) - - def get_vals(self): - return deepcopy(self._vals) + def __init__(self, vals): + self.type = lazy_type_check(vals) + if self.type == TMISMATCHED: + raise RASPTypeError( + "attempted to create sequence with vals of different types:" + + f"\n\t\t {vals}") + if self.type == TBANNED: + raise RASPTypeError( + "attempted to create sequence with illegal val types " + + f"(vals must be {legal_types_list_string}):\n\t\t {vals}") + self._vals = vals + + def __str__(self): + # return "Sequence"+str([small_str(v) for v in self._vals]) + if (len(set(self._vals)) == 1) and (len(self._vals) > 1): + res = "[" + small_str(self._vals[0]) + "]*" + str(len(self._vals)) + else: + res = "[" + ", ".join(small_str(v) for v in self._vals) + "]" + return colored(res, values_color) + \ + colored(" (" + self.type + "s)", general_color) + + def __repr__(self): + return str(self) + + def __len__(self): + return len(self._vals) + + def get_vals(self): + return deepcopy(self._vals) def dims_match(seqs, expected_dim): - return False not in [expected_dim == len(seq) for seq in seqs] + return False not in [expected_dim == len(seq) for seq in seqs] class Select: - def __init__(self, n, q_vars, k_vars, f): - self.n = n - self.makeselect(q_vars, k_vars, f) - self.niceprint = None - - def get_vals(self): - if self.select is None: - self.makeselect() - return deepcopy(self.select) - - def makeselect(self, q_vars=None, k_vars=None, f=None): - if None is q_vars: - assert (None is k_vars) and (None is f) - q_vars = (Sequence(self.target_index),) - k_vars = (Sequence(list(range(self.n))),) - def f(t, i): return t == i - self.select = {i: [f(*get(q_vars, i), *get(k_vars, j)) - for j in range(self.n)] - for i in range(self.n)} # outputs of f should be - # True or False. 
j goes along input dim, i along output - - def __str__(self): - self.get_vals() - if None is self.niceprint: - d = {i: list(map(int, self.select[i])) for i in self.select} - self.niceprint = str(self.niceprint) - if len(str(d)) > 40: - starter = "\n" - self.niceprint = pprint.pformat(d) - else: - starter = "" - self.niceprint = str(d) - self.niceprint = starter + self.niceprint - return self.niceprint - - def __repr__(self): - return str(self) + def __init__(self, n, q_vars, k_vars, f): + self.n = n + self.makeselect(q_vars, k_vars, f) + self.niceprint = None + + def get_vals(self): + if self.select is None: + self.makeselect() + return deepcopy(self.select) + + def makeselect(self, q_vars=None, k_vars=None, f=None): + if None is q_vars: + assert (None is k_vars) and (None is f) + q_vars = (Sequence(self.target_index),) + k_vars = (Sequence(list(range(self.n))),) + + def f(t, i): + return t == i + + self.select = {i: [f(*get(q_vars, i), *get(k_vars, j)) + for j in range(self.n)] + for i in range(self.n)} # outputs of f should be + # True or False. j goes along input dim, i along output + + def __str__(self): + self.get_vals() + if None is self.niceprint: + d = {i: list(map(int, self.select[i])) for i in self.select} + self.niceprint = str(self.niceprint) + if len(str(d)) > 40: + starter = "\n" + self.niceprint = pprint.pformat(d) + else: + starter = "" + self.niceprint = str(d) + self.niceprint = starter + self.niceprint + return self.niceprint + + def __repr__(self): + return str(self) def select(n, q_vars, k_vars, f): - return Select(n, q_vars, k_vars, f) + return Select(n, q_vars, k_vars, f) # applying selects or feedforward (map) def aggregate(select, k_vars, func, default=None): - return to_sequences(apply_average_select(select, k_vars, func, default)) + return to_sequences(apply_average_select(select, k_vars, func, default)) def to_sequences(results_by_index): - def totup(r): - if not isinstance(r, tuple): - return (r,) - return r - # convert scalar results to tuples of length 1 - results_by_index = list(map(totup, results_by_index)) - # one list (sequence) per output value - results_by_output_val = list(zip(*results_by_index)) - res = tuple(map(Sequence, results_by_output_val)) - if len(res) == 1: - return res[0] - else: - return res + def totup(r): + if not isinstance(r, tuple): + return (r,) + return r + # convert scalar results to tuples of length 1 + results_by_index = list(map(totup, results_by_index)) + # one list (sequence) per output value + results_by_output_val = list(zip(*results_by_index)) + res = tuple(map(Sequence, results_by_output_val)) + if len(res) == 1: + return res[0] + else: + return res def zipmap(n, k_vars, func): - # assert len(k_vars) >= 1, "dont make a whole sequence for a plain constant - # you already know the value of.." - results_by_index = [func(*get(k_vars, i)) for i in range(n)] - return to_sequences(results_by_index) + # assert len(k_vars) >= 1, "dont make a whole sequence for a plain constant + # you already know the value of.." 
+ results_by_index = [func(*get(k_vars, i)) for i in range(n)] + return to_sequences(results_by_index) def verify_default_size(default, num_output_vars): - assert num_output_vars > 0 - if num_output_vars == 1: - errnote = "aggregates on functions with single output should have" \ - + " scalar default" - assert not isinstance(default, tuple), errnote - elif num_output_vars > 1: - errnote = "for function with >1 output values, default should be" \ - + " tuple of default values, of equal length to passed" \ - + " function's output values (for function with single output" \ - + " value, default should be single value too)" - check = isinstance(default, tuple) and len(default) == num_output_vars - assert check, errnote + assert num_output_vars > 0 + if num_output_vars == 1: + errnote = "aggregates on functions with single output should have" \ + + " scalar default" + assert not isinstance(default, tuple), errnote + elif num_output_vars > 1: + errnote = "for function with >1 output values, default should be" \ + + " tuple of default values, of equal length to passed" \ + + " function's output values (for function with single output" \ + + " value, default should be single value too)" + check = isinstance(default, tuple) and len(default) == num_output_vars + assert check, errnote def apply_average_select(select, k_vars, func, default=0): - def apply_func_to_each_index(): - # kvs is list [by index] of lists [by varname] of values - kvs = [get(k_vars, i) for i in list(range(select.n))] - candidate_i = [func(*kvi) for kvi in kvs] # candidate output per index - if num_output_vars > 1: - candidates_by_varname = list(zip(*candidate_i)) - else: - # expect tuples of values for conversions in return_sequences - candidates_by_varname = (candidate_i,) - return candidates_by_varname - - def prep_default(default, num_output_vars): - if None is default: - default = 0 - # output of average is always floats, so will be converting all - # to floats here else we'll fail the lazy type check in the - # Sequences. (and float(None) doesn't 'compile' ) - # TODO: maybe just lose the lazy type check? - if not isinstance(default, tuple) and (num_output_vars > 1): - default = tuple([default]*num_output_vars) - # *specifically* in apply_average, where values have to be floats, - # allow default to be single val, - # that will be repeated for all wanted outputs - verify_default_size(default, num_output_vars) - if not isinstance(default, tuple): - # specifically with how we're going to do things here in the - # average aggregate, will help to actually have the outputs get - # passed around as tuples, even if they're scalars really. - # but do this after the size check for the scalar one so it doesn't - # get filled with weird ifs... 
this tupled scalar thing is only a - # convenience in this implementation in this here function - default = (default,) - return default - - def apply_and_average_single_index(outputs_by_varname, index, - index_scores, num_output_vars, default): - def mean(scores, vals): - n = scores.count(True) # already >0 by earlier - if n == 1: - return vals[scores.index(True)] - # else # n>1 - if not (lazy_type_check(vals) in NUMTYPES): - raise Exception( - "asked to average multiple values, but they are " - + "non-numbers: " + str(vals)) - return sum([v for s, v in zip(scores, vals) if s])*1.0/n - - num_influencers = index_scores.count(True) - if num_influencers == 0: - return default - else: - # return_sequences expects multiple outputs to be in tuple form - return tuple(mean(index_scores, o_by_i) - for o_by_i in outputs_by_varname) - num_output_vars = get_num_outputs(func(*get(k_vars, 0))) - candidates_by_varname = apply_func_to_each_index() - default = prep_default(default, num_output_vars) - means_per_index = [apply_and_average_single_index(candidates_by_varname, - i, select.select[i], - num_output_vars, default) - for i in range(select.n)] - # list (per index) of all the new variable values (per varname) - return means_per_index + def apply_func_to_each_index(): + # kvs is list [by index] of lists [by varname] of values + kvs = [get(k_vars, i) for i in list(range(select.n))] + candidate_i = [func(*kvi) for kvi in kvs] # candidate output per index + if num_output_vars > 1: + candidates_by_varname = list(zip(*candidate_i)) + else: + # expect tuples of values for conversions in return_sequences + candidates_by_varname = (candidate_i,) + return candidates_by_varname + + def prep_default(default, num_output_vars): + if None is default: + default = 0 + # output of average is always floats, so will be converting all + # to floats here else we'll fail the lazy type check in the + # Sequences. (and float(None) doesn't 'compile' ) + # TODO: maybe just lose the lazy type check? + if not isinstance(default, tuple) and (num_output_vars > 1): + default = tuple([default] * num_output_vars) + # *specifically* in apply_average, where values have to be floats, + # allow default to be single val, + # that will be repeated for all wanted outputs + verify_default_size(default, num_output_vars) + if not isinstance(default, tuple): + # specifically with how we're going to do things here in the + # average aggregate, will help to actually have the outputs get + # passed around as tuples, even if they're scalars really. + # but do this after the size check for the scalar one so it doesn't + # get filled with weird ifs... 
this tupled scalar thing is only a + # convenience in this implementation in this here function + default = (default,) + return default + + def apply_and_average_single_index(outputs_by_varname, index, + index_scores, num_output_vars, default): + def mean(scores, vals): + n = scores.count(True) # already >0 by earlier + if n == 1: + return vals[scores.index(True)] + # else # n>1 + if not (lazy_type_check(vals) in NUMTYPES): + raise Exception( + "asked to average multiple values, but they are " + + "non-numbers: " + str(vals)) + return sum([v for s, v in zip(scores, vals) if s]) * 1.0 / n + + num_influencers = index_scores.count(True) + if num_influencers == 0: + return default + else: + # return_sequences expects multiple outputs to be in tuple form + return tuple(mean(index_scores, o_by_i) + for o_by_i in outputs_by_varname) + num_output_vars = get_num_outputs(func(*get(k_vars, 0))) + candidates_by_varname = apply_func_to_each_index() + default = prep_default(default, num_output_vars) + means_per_index = [apply_and_average_single_index(candidates_by_varname, + i, select.select[i], + num_output_vars, default) + for i in range(select.n)] + # list (per index) of all the new variable values (per varname) + return means_per_index # user's responsibility to give functions that always have same number of # outputs def get_num_outputs(dummy_out): - if isinstance(dummy_out, tuple): - return len(dummy_out) - return 1 + if isinstance(dummy_out, tuple): + return len(dummy_out) + return 1 def small_str(v): - if isinstance(v, float): - return str(clean_val(v, 3)) - if isinstance(v, bool): - return "T" if v else "F" - return str(v) + if isinstance(v, float): + return str(clean_val(v, 3)) + if isinstance(v, bool): + return "T" if v else "F" + return str(v) def get(vars_list, index): # index should be within range to access - # v._vals and if not absolutely should raise an error, as it will here - # by the attempted access - res = deepcopy([v._vals[index] for v in vars_list]) - return res + # v._vals and if not absolutely should raise an error, as it will here + # by the attempted access + res = deepcopy([v._vals[index] for v in vars_list]) + return res diff --git a/RASP_support/analyse.py b/RASP_support/analyse.py index 1a023fe..be9b187 100644 --- a/RASP_support/analyse.py +++ b/RASP_support/analyse.py @@ -1,22 +1,22 @@ from .FunctionalSupport import Unfinished, UnfinishedSequence, \ - UnfinishedSelect, guarded_contains, guarded_compare, zipmap + UnfinishedSelect, guarded_contains, guarded_compare, zipmap from collections import defaultdict, Counter from copy import copy def UnfinishedFunc(f): - setattr(Unfinished, f.__name__, f) + setattr(Unfinished, f.__name__, f) @UnfinishedFunc def get_parent_sequences(self): - # for UnfinishedSequences, this should get just the tuple of sequences the - # aggregate is applied to, and I think in order (as the parents will only - # be a select and a sequencestuple, and the seqs in the sequencestuple will - # be added in order and the select will be removed in this function) + # for UnfinishedSequences, this should get just the tuple of sequences the + # aggregate is applied to, and I think in order (as the parents will only + # be a select and a sequencestuple, and the seqs in the sequencestuple will + # be added in order and the select will be removed in this function) - # i.e. drop the selects - return [p for p in self.get_parents() if isinstance(p, UnfinishedSequence)] + # i.e. 
drop the selects + return [p for p in self.get_parents() if isinstance(p, UnfinishedSequence)] Unfinished._full_seq_parents = None @@ -24,95 +24,95 @@ def get_parent_sequences(self): @UnfinishedFunc def get_full_seq_parents(self): - if self._full_seq_parents is None: - self._full_seq_parents = [u for u in self.get_full_parents() - if isinstance(u, UnfinishedSequence)] - return copy(self._full_seq_parents) + if self._full_seq_parents is None: + self._full_seq_parents = [u for u in self.get_full_parents() + if isinstance(u, UnfinishedSequence)] + return copy(self._full_seq_parents) @UnfinishedFunc def get_parent_select(self): - if not hasattr(self, "parent_select"): - real_parents = self.get_parents() - self.parent_select = next((s for s in real_parents if - isinstance(s, UnfinishedSelect)), None) - return self.parent_select + if not hasattr(self, "parent_select"): + real_parents = self.get_parents() + self.parent_select = next((s for s in real_parents if + isinstance(s, UnfinishedSelect)), None) + return self.parent_select @UnfinishedFunc def set_analysis_parent_select(self, options): - # doesn't really need to be a function but feels clearer visually to have - # it out here so i can see this variable is being registered to the - # unfinisheds - if None is self.parent_select: - self.analysis_parent_select = self.parent_select - else: - getps = (ps for ps in options if - ps.compare_string == self.get_parent_select().compare_string) - self.analysis_parent_select = next(getps, None) - errnote = "parent options given to seq: " + self.name + " did not " \ - + "include anything equivalent to actual seq's parent select (" \ - + self.get_parent_select().compare_string + ")" - assert self.analysis_parent_select is not None, errnote + # doesn't really need to be a function but feels clearer visually to have + # it out here so i can see this variable is being registered to the + # unfinisheds + if None is self.parent_select: + self.analysis_parent_select = self.parent_select + else: + getps = (ps for ps in options if + ps.compare_string == self.get_parent_select().compare_string) + self.analysis_parent_select = next(getps, None) + errnote = "parent options given to seq: " + self.name + " did not " \ + + "include anything equivalent to actual seq's parent select (" \ + + self.get_parent_select().compare_string + ")" + assert self.analysis_parent_select is not None, errnote def squeeze_selects(selects): - compstrs = set([s.compare_string for s in selects]) - if len(compstrs) == len(selects): - return selects - return [next(s for s in selects if s.compare_string == cs) - for cs in compstrs] + compstrs = set([s.compare_string for s in selects]) + if len(compstrs) == len(selects): + return selects + return [next(s for s in selects if s.compare_string == cs) + for cs in compstrs] @UnfinishedFunc def schedule(self, scheduler='best', remove_minors=False): - # recall attentions can be created on level 1 but still generate seqs on - # level 3 etc hence width is number of *seqs* with different attentions per - # level. 
- def choose_scheduler(scheduler): - if scheduler == 'best': - return 'greedy' - # TODO: implement lastminute, maybe others, and choose narrowest - # result of all options - return scheduler - scheduler = choose_scheduler(scheduler) - seq_layers = self.greedy_seq_scheduler() if scheduler == 'greedy' \ - else self.lastminute_seq_scheduler() - - if remove_minors: - for i in seq_layers: - seq_layers[i] = [seq for seq in seq_layers[i] if not seq.is_minor] - - def get_seqs_selects(seqs): - # all the selects needed to compute a set of seqs - all_selects = set(seq.get_parent_select() for seq in seqs) - # some of the seqs may not have parent matches, - # eg, indices. these will return None, which we don't want to count - all_selects -= set([None]) - return squeeze_selects(all_selects) # squeeze identical parents - - layer_selects = {i: get_seqs_selects(seq_layers[i]) for i in seq_layers} - - # mark remaining parent select after squeeze - for i in seq_layers: - for seq in seq_layers[i]: - seq.set_analysis_parent_select(layer_selects[i]) - - return seq_layers, layer_selects + # recall attentions can be created on level 1 but still generate seqs on + # level 3 etc hence width is number of *seqs* with different attentions per + # level. + def choose_scheduler(scheduler): + if scheduler == 'best': + return 'greedy' + # TODO: implement lastminute, maybe others, and choose narrowest + # result of all options + return scheduler + scheduler = choose_scheduler(scheduler) + seq_layers = self.greedy_seq_scheduler() if scheduler == 'greedy' \ + else self.lastminute_seq_scheduler() + + if remove_minors: + for i in seq_layers: + seq_layers[i] = [seq for seq in seq_layers[i] if not seq.is_minor] + + def get_seqs_selects(seqs): + # all the selects needed to compute a set of seqs + all_selects = set(seq.get_parent_select() for seq in seqs) + # some of the seqs may not have parent matches, + # eg, indices. these will return None, which we don't want to count + all_selects -= set([None]) + return squeeze_selects(all_selects) # squeeze identical parents + + layer_selects = {i: get_seqs_selects(seq_layers[i]) for i in seq_layers} + + # mark remaining parent select after squeeze + for i in seq_layers: + for seq in seq_layers[i]: + seq.set_analysis_parent_select(layer_selects[i]) + + return seq_layers, layer_selects @UnfinishedFunc def greedy_seq_scheduler(self): - all_seqs = sorted(self.get_full_seq_parents(), - key=lambda seq: seq.creation_order_id) - # sorting in order of creation automatically sorts by order of in-layer - # dependencies (i.e. things got through feedforwards), makes prints clearer - # and eventually is helpful for drawcompflow - levels = defaultdict(lambda: []) - for seq in all_seqs: - # schedule all seqs as early as possible - levels[seq.min_poss_depth].append(seq) - return levels + all_seqs = sorted(self.get_full_seq_parents(), + key=lambda seq: seq.creation_order_id) + # sorting in order of creation automatically sorts by order of in-layer + # dependencies (i.e. 
things got through feedforwards), makes prints clearer + # and eventually is helpful for drawcompflow + levels = defaultdict(lambda: []) + for seq in all_seqs: + # schedule all seqs as early as possible + levels[seq.min_poss_depth].append(seq) + return levels Unfinished.max_poss_depth_for_seq = (None, None) @@ -120,138 +120,138 @@ def greedy_seq_scheduler(self): @UnfinishedFunc def lastminute_for_seq(self, seq): - raise NotImplementedError + raise NotImplementedError @UnfinishedFunc def lastminute_seq_scheduler(self): - all_seqs = self.get_full_seq_parents() + all_seqs = self.get_full_seq_parents() @UnfinishedFunc def typestr(self): - if isinstance(self, UnfinishedSelect): - return "select" - elif isinstance(self, UnfinishedSequence): - return "seq" - else: - return "internal" + if isinstance(self, UnfinishedSelect): + return "select" + elif isinstance(self, UnfinishedSequence): + return "seq" + else: + return "internal" @UnfinishedFunc def width_and_depth(self, scheduler='greedy', loud=True, print_tree_too=False, - remove_minors=False): - seq_layers, layer_selects = self.schedule( - scheduler=scheduler, remove_minors=remove_minors) - widths = {i: len(layer_selects[i]) for i in layer_selects} - n_layers = max(seq_layers.keys()) - max_width = max(widths[i] for i in widths) - if loud: - print("analysing unfinished", self.typestr()+":", self.name) - print("using scheduler:", scheduler) - print("num layers:", n_layers, "max width:", max_width) - print("width per layer:") - print("\n".join(str(i)+"\t: "+str(widths[i]) - for i in range(1, n_layers+1))) - # start from 1 to skip layer 0, which has width 0 - # and is just the inputs (tokens and indices) - if print_tree_too: - def print_layer(i, d): - print(i, "\t:", ", ".join(seq.name for seq in d[i])) - print("==== seqs at each layer: ====") - [print_layer(i, seq_layers) for i in range(1, n_layers+1)] - print("==== selects at each layer: ====") - [print_layer(i, layer_selects) for i in range(1, n_layers+1)] - return n_layers, max_width, widths + remove_minors=False): + seq_layers, layer_selects = self.schedule( + scheduler=scheduler, remove_minors=remove_minors) + widths = {i: len(layer_selects[i]) for i in layer_selects} + n_layers = max(seq_layers.keys()) + max_width = max(widths[i] for i in widths) + if loud: + print("analysing unfinished", self.typestr() + ":", self.name) + print("using scheduler:", scheduler) + print("num layers:", n_layers, "max width:", max_width) + print("width per layer:") + print("\n".join(str(i) + "\t: " + str(widths[i]) + for i in range(1, n_layers + 1))) + # start from 1 to skip layer 0, which has width 0 + # and is just the inputs (tokens and indices) + if print_tree_too: + def print_layer(i, d): + print(i, "\t:", ", ".join(seq.name for seq in d[i])) + print("==== seqs at each layer: ====") + [print_layer(i, seq_layers) for i in range(1, n_layers + 1)] + print("==== selects at each layer: ====") + [print_layer(i, layer_selects) for i in range(1, n_layers + 1)] + return n_layers, max_width, widths @UnfinishedFunc def schedule_comp_depth(self, d): - self.scheduled_comp_depth = d + self.scheduled_comp_depth = d @UnfinishedFunc def get_all_ancestor_heads_and_ffs(self, remove_minors=False): - class Head: - def __init__(self, select, sequences, comp_depth): - self.comp_depth = comp_depth - self.name = str([m.name for m in sequences]) - self.sequences = sequences - self.select = select - seq_layers, layer_selects = self.schedule( - 'best', remove_minors=remove_minors) - - all_ffs = self.get_full_seq_parents() - if 
len(all_ffs) > 1: - # filter out non-ffs in the non-trivial case - all_ffs = [m for m in all_ffs if m.from_zipmap] - if remove_minors: - all_ffs = [ff for ff in all_ffs if not ff.is_minor] - - for i in seq_layers: - for m in seq_layers[i]: - if guarded_contains(all_ffs, m): - # mark comp depths of the ffs... drawcompflow wants to know - m.schedule_comp_depth(i) - - heads = [] - for i in layer_selects: - for s in layer_selects[i]: - seqs = [m for m in seq_layers[i] if m.analysis_parent_select == s] - heads.append(Head(s, seqs, i)) - - return heads, all_ffs + class Head: + def __init__(self, select, sequences, comp_depth): + self.comp_depth = comp_depth + self.name = str([m.name for m in sequences]) + self.sequences = sequences + self.select = select + seq_layers, layer_selects = self.schedule( + 'best', remove_minors=remove_minors) + + all_ffs = self.get_full_seq_parents() + if len(all_ffs) > 1: + # filter out non-ffs in the non-trivial case + all_ffs = [m for m in all_ffs if m.from_zipmap] + if remove_minors: + all_ffs = [ff for ff in all_ffs if not ff.is_minor] + + for i in seq_layers: + for m in seq_layers[i]: + if guarded_contains(all_ffs, m): + # mark comp depths of the ffs... drawcompflow wants to know + m.schedule_comp_depth(i) + + heads = [] + for i in layer_selects: + for s in layer_selects[i]: + seqs = [m for m in seq_layers[i] if m.analysis_parent_select == s] + heads.append(Head(s, seqs, i)) + + return heads, all_ffs @UnfinishedFunc def set_display_name(self, display_name): - self.display_name = display_name - # again just making it more visible??? that there's an attribute being set - # somewhere + self.display_name = display_name + # again just making it more visible??? that there's an attribute being set + # somewhere @UnfinishedFunc def make_display_names_for_all_parents(self, skip_minors=False): - all_unfs = self.get_full_parents() - all_seqs = [u for u in set(all_unfs) if isinstance(u, UnfinishedSequence)] - all_selects = [u for u in set(all_unfs) if isinstance(u, UnfinishedSelect)] - if skip_minors: - num_orig = len(all_seqs) - all_seqs = [seq for seq in all_seqs if not seq.is_minor] - name_counts = Counter([m.name for m in all_seqs]) - name_suff = Counter() - for m in sorted(all_seqs+all_selects, key=lambda u: u.creation_order_id): - # yes, even the non-seqs need display names, albeit for now only worry - # about repeats in the seqs and sort by creation order to get name - # suffixes with chronological (and so non-confusing) order - if name_counts[m.name] > 1: - m.set_display_name(m.name+"_"+str(name_suff[m.name])) - name_suff[m.name] += 1 - - else: - m.set_display_name(m.name) + all_unfs = self.get_full_parents() + all_seqs = [u for u in set(all_unfs) if isinstance(u, UnfinishedSequence)] + all_selects = [u for u in set(all_unfs) if isinstance(u, UnfinishedSelect)] + if skip_minors: + num_orig = len(all_seqs) + all_seqs = [seq for seq in all_seqs if not seq.is_minor] + name_counts = Counter([m.name for m in all_seqs]) + name_suff = Counter() + for m in sorted(all_seqs + all_selects, key=lambda u: u.creation_order_id): + # yes, even the non-seqs need display names, albeit for now only worry + # about repeats in the seqs and sort by creation order to get name + # suffixes with chronological (and so non-confusing) order + if name_counts[m.name] > 1: + m.set_display_name(m.name + "_" + str(name_suff[m.name])) + name_suff[m.name] += 1 + + else: + m.set_display_name(m.name) @UnfinishedFunc def note_if_seeker(self): - if not isinstance(self, UnfinishedSequence): - return + if 
not isinstance(self, UnfinishedSequence): + return - if (not self.get_parent_sequences()) \ - and (self.get_parent_select() is not None): - # no parent sequences, but yes parent select: this value is a function - # of only its parent select, i.e., a seeker (marks whether select found - # something or not) - self.is_seeker = True - self.seeker_flag = self.elementwise_function() - self.seeker_default = self._default - else: - self.is_seeker = False + if (not self.get_parent_sequences()) and \ + (self.get_parent_select() is not None): + # no parent sequences, but yes parent select: this value is a function + # of only its parent select, i.e., a seeker (marks whether select found + # something or not) + self.is_seeker = True + self.seeker_flag = self.elementwise_function() + self.seeker_default = self._default + else: + self.is_seeker = False @UnfinishedFunc def mark_all_ancestor_seekers(self): - [u.note_if_seeker() for u in self.get_full_parents()] + [u.note_if_seeker() for u in self.get_full_parents()] Unfinished._full_descendants_for_seq = (None, None) @@ -259,47 +259,47 @@ def mark_all_ancestor_seekers(self): @UnfinishedFunc def descendants_towards_seq(self, seq): - if not guarded_compare(self._full_descendants_for_seq[0], seq): + if not guarded_compare(self._full_descendants_for_seq[0], seq): - relevant = seq.get_full_parents() - res = [r for r in relevant if guarded_contains(r.get_parents(), self)] + relevant = seq.get_full_parents() + res = [r for r in relevant if guarded_contains(r.get_parents(), self)] - self._full_descendants_for_seq = (seq, res) - return self._full_descendants_for_seq[1] + self._full_descendants_for_seq = (seq, res) + return self._full_descendants_for_seq[1] @UnfinishedFunc def is_minor_comp_towards_seq(self, seq): - if not isinstance(self, UnfinishedSequence): - return False # selects are always important - if self.never_display: # priority: never over always - return True - if self.always_display: - if self.is_constant(): - print("displaying constant:", self.name) - return False - if self.is_constant(): # e.g. 1 or "a" etc, just stuff created around - # constants by REPL behind the scenes - return True - children = self.descendants_towards_seq(seq) - if len(children) > 1: - return False # this sequence was used twice -> must have been actually - # named as a real variable in the code (and not part of some bunch of - # operators) -> make it visible in the comp flow too - if len(children) == 0: - # if it's the seq itself then clearly we're very interested in it. if - # it has no children and isnt the seq then we're checking out a weird - # dangly unused leaf, we shouldn't reach such a scenario through any of - # functions we'll be using to call this one, but might as well make - # this function complete just in case we forget - return not guarded_compare(self, seq) - child = children[0] - if isinstance(child, UnfinishedSelect): - return False # this thing feeds directly into a select, lets make it - # visible - # obtained through zipmap and feeds directly into another zipmap: minor - # operation as part of something more complicated - return (child.from_zipmap and self.from_zipmap) + if not isinstance(self, UnfinishedSequence): + return False # selects are always important + if self.never_display: # priority: never over always + return True + if self.always_display: + if self.is_constant(): + print("displaying constant:", self.name) + return False + if self.is_constant(): # e.g. 
1 or "a" etc, just stuff created around + # constants by REPL behind the scenes + return True + children = self.descendants_towards_seq(seq) + if len(children) > 1: + return False # this sequence was used twice -> must have been actually + # named as a real variable in the code (and not part of some bunch of + # operators) -> make it visible in the comp flow too + if len(children) == 0: + # if it's the seq itself then clearly we're very interested in it. if + # it has no children and isnt the seq then we're checking out a weird + # dangly unused leaf, we shouldn't reach such a scenario through any of + # functions we'll be using to call this one, but might as well make + # this function complete just in case we forget + return not guarded_compare(self, seq) + child = children[0] + if isinstance(child, UnfinishedSelect): + return False # this thing feeds directly into a select, lets make it + # visible + # obtained through zipmap and feeds directly into another zipmap: minor + # operation as part of something more complicated + return (child.from_zipmap and self.from_zipmap) Unfinished.is_minor = False @@ -308,51 +308,51 @@ def is_minor_comp_towards_seq(self, seq): @UnfinishedFunc # another func just to be very explicit about an attribute that's getting set def set_minor_for_seq(self, seq): - self.is_minor = self.is_minor_comp_towards_seq(seq) + self.is_minor = self.is_minor_comp_towards_seq(seq) @UnfinishedFunc def mark_all_minor_ancestors(self): - all_ancestors = self.get_full_parents() - for a in all_ancestors: - a.set_minor_for_seq(self) + all_ancestors = self.get_full_parents() + for a in all_ancestors: + a.set_minor_for_seq(self) @UnfinishedFunc def get_nonminor_parents(self): # assumes have already marked the minor - # parents according to current interests. - # otherwise, may remain marked according to a different seq, or possibly - # all on default value (none are minor, all are important) - potentials = self.get_parents() - nonminors = [] - while potentials: - p = potentials.pop() - if not p.is_minor: - nonminors.append(p) - else: - potentials.update(p.get_parents()) - return set(nonminors) + # parents according to current interests. 
+ # otherwise, may remain marked according to a different seq, or possibly + # all on default value (none are minor, all are important) + potentials = self.get_parents() + nonminors = [] + while potentials: + p = potentials.pop() + if not p.is_minor: + nonminors.append(p) + else: + potentials.update(p.get_parents()) + return set(nonminors) @UnfinishedFunc def get_nonminor_parent_sequences(self): - return [p for p in self.get_nonminor_parents() - if isinstance(p, UnfinishedSequence)] + return [p for p in self.get_nonminor_parents() + if isinstance(p, UnfinishedSequence)] @UnfinishedFunc # gets both minor and nonminor sequences def get_immediate_parent_sequences(self): - return [p for p in self.get_parents() if isinstance(p, UnfinishedSequence)] + return [p for p in self.get_parents() if isinstance(p, UnfinishedSequence)] @UnfinishedFunc def pre_aggregate_comp(seq): - vvars = seq.get_parent_sequences() - vreal = zipmap(vvars, seq.elementwise_function) - if isinstance(vreal, tuple): # equivalently, if seq.output_index >= 0: - vreal = vreal[seq.output_index] - return vreal + vvars = seq.get_parent_sequences() + vreal = zipmap(vvars, seq.elementwise_function) + if isinstance(vreal, tuple): # equivalently, if seq.output_index >= 0: + vreal = vreal[seq.output_index] + return vreal dummyimport = None diff --git a/RASP_support/make_operators.py b/RASP_support/make_operators.py index a015308..f517d3f 100644 --- a/RASP_support/make_operators.py +++ b/RASP_support/make_operators.py @@ -3,149 +3,151 @@ # make them fully named functions instead of lambdas, even though # it's more lines, because the debug prints are so much clearer # this way + + def add_ops(Class, apply_unary_op, apply_binary_op): - def addsetname(f, opname, rev): - def f_with_setname(*a): - - assert len(a) in [1, 2] - if len(a) == 2: - a0, a1 = a if not rev else (a[1], a[0]) - name0 = a0.name if hasattr(a0, "name") else str(a0) - name1 = a1.name if hasattr(a1, "name") else str(a1) - # a0/a1 might not be a seq, just having an op on it with a seq. - name = name0 + " " + opname + " " + name1 - else: # len(a)==1 - name = opname + " " + a[0].name - # probably going to be composed with more ops, so... - name = "( " + name + " )" - return f(*a).setname(name).allow_suppressing_display() - # seqs created as parts of long sequences of operators may be - # suppressed in display, the final name of the whole composition - # will be sufficiently informative. 
Have to set always_display to - # false *after* the setname, because setname marks always_display - # as True (under assumption it is normally being called by the - # user, who must clearly be naming some variable they care about) - return f_with_setname - - def listop(f, listing_name): - setattr(Class, listing_name, f) - - def addop(opname, rev=False): - return lambda f: listop(addsetname(f, opname, rev), f.__name__) - - @addop("==") - def __eq__(self, other): - return apply_binary_op(self, other, lambda a, b: a == b) - - @addop("!=") - def __ne__(self, other): - return apply_binary_op(self, other, lambda a, b: a != b) - - @addop("<") - def __lt__(self, other): - return apply_binary_op(self, other, lambda a, b: a < b) - - @addop(">") - def __gt__(self, other): - return apply_binary_op(self, other, lambda a, b: a > b) - - @addop("<=") - def __le__(self, other): - return apply_binary_op(self, other, lambda a, b: a <= b) - - @addop(">=") - def __ge__(self, other): - return apply_binary_op(self, other, lambda a, b: a >= b) - - @addop("+") - def __add__(self, other): - return apply_binary_op(self, other, lambda a, b: a+b) - - @addop("+", True) - def __radd__(self, other): - return apply_binary_op(self, other, lambda a, b: b+a) - - @addop("-") - def __sub__(self, other): - return apply_binary_op(self, other, lambda a, b: a-b) - - @addop("-", True) - def __rsub__(self, other): - return apply_binary_op(self, other, lambda a, b: b-a) - - @addop("*") - def __mul__(self, other): - return apply_binary_op(self, other, lambda a, b: a*b) - - @addop("*", True) - def __rmul__(self, other): - return apply_binary_op(self, other, lambda a, b: b*a) - - @addop("//") - def __floordiv__(self, other): - return apply_binary_op(self, other, lambda a, b: a//b) - - @addop("//", True) - def __rfloordiv__(self, other): - return apply_binary_op(self, other, lambda a, b: b//a) - - @addop("/") - def __truediv__(self, other): - return apply_binary_op(self, other, lambda a, b: a/b) - - @addop("/", True) - def __rtruediv__(self, other): - return apply_binary_op(self, other, lambda a, b: b/a) - - @addop("%") - def __mod__(self, other): - return apply_binary_op(self, other, lambda a, b: a % b) - - @addop("%", True) - def __rmod__(self, other): - return apply_binary_op(self, other, lambda a, b: b % a) - - @addop("divmod") - def __divmod__(self, other): - return apply_binary_op(self, other, lambda a, b: divmod(a, b)) - - @addop("divmod", True) - def __rdivmod__(self, other): - return apply_binary_op(self, other, lambda a, b: divmod(b, a)) - - @addop("pow") - def __pow__(self, other): - return apply_binary_op(self, other, lambda a, b: pow(a, b)) - - @addop("pow", True) - def __rpow__(self, other): - return apply_binary_op(self, other, lambda a, b: pow(b, a)) - - # skipping and, or, xor, which are bitwise and dont implement 'and' and - # 'or' but rather & and |. - # similarly skipping lshift, rshift cause who wants them. - # wish i had not, and, or primitives, but can accept that dont. - # if people really want to do 'not' they can do '==False' instead, can do a - # little macro for it in the other sugar file or whatever - - @addop("+") - def __pos__(self): - return apply_unary_op(self, lambda a: +a) - - @addop("-") - def __neg__(self): - return apply_unary_op(self, lambda a: -a) - - @addop("abs") - def __abs__(self): - return apply_unary_op(self, abs) - - @addop("round") - # not sure if python will get upset if round doesnt return an actual int - # tbh... will have to check. 
- def __round__(self): - return apply_unary_op(self, round) - - # defining floor, ceil, trunc showed up funny (green instead of blue), - # gonna go ahead and avoid + def addsetname(f, opname, rev): + def f_with_setname(*a): + + assert len(a) in [1, 2] + if len(a) == 2: + a0, a1 = a if not rev else (a[1], a[0]) + name0 = a0.name if hasattr(a0, "name") else str(a0) + name1 = a1.name if hasattr(a1, "name") else str(a1) + # a0/a1 might not be a seq, just having an op on it with a seq. + name = name0 + " " + opname + " " + name1 + else: # len(a)==1 + name = opname + " " + a[0].name + # probably going to be composed with more ops, so... + name = "( " + name + " )" + return f(*a).setname(name).allow_suppressing_display() + # seqs created as parts of long sequences of operators may be + # suppressed in display, the final name of the whole composition + # will be sufficiently informative. Have to set always_display to + # false *after* the setname, because setname marks always_display + # as True (under assumption it is normally being called by the + # user, who must clearly be naming some variable they care about) + return f_with_setname + + def listop(f, listing_name): + setattr(Class, listing_name, f) + + def addop(opname, rev=False): + return lambda f: listop(addsetname(f, opname, rev), f.__name__) + + @addop("==") + def __eq__(self, other): + return apply_binary_op(self, other, lambda a, b: a == b) + + @addop("!=") + def __ne__(self, other): + return apply_binary_op(self, other, lambda a, b: a != b) + + @addop("<") + def __lt__(self, other): + return apply_binary_op(self, other, lambda a, b: a < b) + + @addop(">") + def __gt__(self, other): + return apply_binary_op(self, other, lambda a, b: a > b) + + @addop("<=") + def __le__(self, other): + return apply_binary_op(self, other, lambda a, b: a <= b) + + @addop(">=") + def __ge__(self, other): + return apply_binary_op(self, other, lambda a, b: a >= b) + + @addop("+") + def __add__(self, other): + return apply_binary_op(self, other, lambda a, b: a + b) + + @addop("+", True) + def __radd__(self, other): + return apply_binary_op(self, other, lambda a, b: b + a) + + @addop("-") + def __sub__(self, other): + return apply_binary_op(self, other, lambda a, b: a - b) + + @addop("-", True) + def __rsub__(self, other): + return apply_binary_op(self, other, lambda a, b: b - a) + + @addop("*") + def __mul__(self, other): + return apply_binary_op(self, other, lambda a, b: a * b) + + @addop("*", True) + def __rmul__(self, other): + return apply_binary_op(self, other, lambda a, b: b * a) + + @addop("//") + def __floordiv__(self, other): + return apply_binary_op(self, other, lambda a, b: a // b) + + @addop("//", True) + def __rfloordiv__(self, other): + return apply_binary_op(self, other, lambda a, b: b // a) + + @addop("/") + def __truediv__(self, other): + return apply_binary_op(self, other, lambda a, b: a / b) + + @addop("/", True) + def __rtruediv__(self, other): + return apply_binary_op(self, other, lambda a, b: b / a) + + @addop("%") + def __mod__(self, other): + return apply_binary_op(self, other, lambda a, b: a % b) + + @addop("%", True) + def __rmod__(self, other): + return apply_binary_op(self, other, lambda a, b: b % a) + + @addop("divmod") + def __divmod__(self, other): + return apply_binary_op(self, other, lambda a, b: divmod(a, b)) + + @addop("divmod", True) + def __rdivmod__(self, other): + return apply_binary_op(self, other, lambda a, b: divmod(b, a)) + + @addop("pow") + def __pow__(self, other): + return apply_binary_op(self, other, lambda a, b: 
pow(a, b)) + + @addop("pow", True) + def __rpow__(self, other): + return apply_binary_op(self, other, lambda a, b: pow(b, a)) + + # skipping and, or, xor, which are bitwise and dont implement 'and' and + # 'or' but rather & and |. + # similarly skipping lshift, rshift cause who wants them. + # wish i had not, and, or primitives, but can accept that dont. + # if people really want to do 'not' they can do '==False' instead, can do a + # little macro for it in the other sugar file or whatever + + @addop("+") + def __pos__(self): + return apply_unary_op(self, lambda a: +a) + + @addop("-") + def __neg__(self): + return apply_unary_op(self, lambda a: -a) + + @addop("abs") + def __abs__(self): + return apply_unary_op(self, abs) + + @addop("round") + # not sure if python will get upset if round doesnt return an actual int + # tbh... will have to check. + def __round__(self): + return apply_unary_op(self, round) + + # defining floor, ceil, trunc showed up funny (green instead of blue), + # gonna go ahead and avoid diff --git a/tests/make_tgts.py b/tests/make_tgts.py index c460797..519b9cd 100644 --- a/tests/make_tgts.py +++ b/tests/make_tgts.py @@ -4,13 +4,13 @@ testpath = "tests" -inpath = testpath+"/in" -outpath = testpath+"/out" -tgtpath = testpath+"/tgt" -libtestspath = testpath+"/broken_libs" -libspath = libtestspath+"/lib" -libtgtspath = libtestspath+"/tgt" -liboutspath = libtestspath+"/out" +inpath = testpath + "/in" +outpath = testpath + "/out" +tgtpath = testpath + "/tgt" +libtestspath = testpath + "/broken_libs" +libspath = libtestspath + "/lib" +libtgtspath = libtestspath + "/tgt" +liboutspath = libtestspath + "/out" curr_path_marker = "[current]" @@ -20,76 +20,76 @@ def things_in_path(path): - if not os.path.exists(path): - return [] - return [p for p in os.listdir(path) if not p == ".DS_Store"] + if not os.path.exists(path): + return [] + return [p for p in os.listdir(path) if not p == ".DS_Store"] def joinpath(*a): - return "/".join(a) + return "/".join(a) for p in [tgtpath, libtgtspath]: - if not os.path.exists(p): - os.makedirs(p) + if not os.path.exists(p): + os.makedirs(p) all_names = things_in_path(inpath) def fix_file_paths(filename, curr_path_marker): - mypath = os.path.abspath(".") + mypath = os.path.abspath(".") - with open(filename, "r") as f: - filecontents = "".join(f) + with open(filename, "r") as f: + filecontents = "".join(f) - filecontents = filecontents.replace(mypath, curr_path_marker) + filecontents = filecontents.replace(mypath, curr_path_marker) - with open(filename, "w") as f: - print(filecontents, file=f) + with open(filename, "w") as f: + print(filecontents, file=f) def run_input(name): - os.system("python3 "+REPL_PATH+" <"+inpath+"/"+name+" >"+tgtpath+"/"+name) - fix_file_paths(tgtpath+"/"+name, curr_path_marker) + os.system(f"python3 {REPL_PATH} <{inpath}/{name} >{tgtpath}/{name}") + fix_file_paths(tgtpath + "/" + name, curr_path_marker) def run_inputs(): - print("making the target outputs!") - for n in all_names: - run_input(n) + print("making the target outputs!") + for n in all_names: + run_input(n) def run_broken_lib(lib): - os.system("cp "+joinpath(libspath, lib)+" "+RASPLIB_PATH) - os.system("python3 "+REPL_PATH+" <"+joinpath(libtestspath, - "empty.txt") + " >"+joinpath(libtgtspath, lib)) + os.system("cp " + joinpath(libspath, lib) + " " + RASPLIB_PATH) + readpath = joinpath(libtestspath, "empty.txt") + writepath = joinpath(libtgtspath, lib) + os.system("python3 " + REPL_PATH + " <" + readpath + " >" + writepath) real_rasplib_safe_place = 
"make_tgts_helper/temp" safe_rasplib_name = "safe_rasplib.rasp" +rasplib_save_loc = joinpath(real_rasplib_safe_place, safe_rasplib_name) def save_rasplib(): - if not os.path.exists(real_rasplib_safe_place): - os.makedirs(real_rasplib_safe_place) - os.system("mv "+RASPLIB_PATH+" " + - joinpath(real_rasplib_safe_place, safe_rasplib_name)) + if not os.path.exists(real_rasplib_safe_place): + os.makedirs(real_rasplib_safe_place) + os.system("mv " + RASPLIB_PATH + " " + rasplib_save_loc) def restore_rasplib(): - os.system("mv "+joinpath(real_rasplib_safe_place, - safe_rasplib_name)+" "+RASPLIB_PATH) + os.system("mv " + rasplib_save_loc + " " + RASPLIB_PATH) def run_broken_libs(): - print("making the broken lib targets!") - save_rasplib() - all_libs = things_in_path(libspath) - for lib in all_libs: - run_broken_lib(lib) - restore_rasplib() + print("making the broken lib targets!") + save_rasplib() + all_libs = things_in_path(libspath) + for lib in all_libs: + run_broken_lib(lib) + restore_rasplib() if __name__ == "__main__": - run_inputs() - run_broken_libs() + run_inputs() + run_broken_libs() diff --git a/tests/test_all.py b/tests/test_all.py index a9d9f0d..aad5649 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,70 +1,72 @@ import os from make_tgts import fix_file_paths, curr_path_marker, joinpath, \ - things_in_path, inpath, outpath, tgtpath, libtestspath, libspath, \ - libtgtspath, liboutspath, save_rasplib, restore_rasplib + things_in_path, inpath, outpath, tgtpath, libtestspath, libspath, \ + libtgtspath, liboutspath, save_rasplib, restore_rasplib def check_equal(f1, f2): - res = os.system("diff "+f1+" "+f2) - return res == 0 # 0 = diff found no differences + res = os.system("diff " + f1 + " " + f2) + return res == 0 # 0 = diff found no differences for p in [outpath, liboutspath]: - if not os.path.exists(p): - os.makedirs(p) + if not os.path.exists(p): + os.makedirs(p) def run_input(name): - os.system("python3 -m RASP_support <" + - joinpath(inpath, name)+" >"+joinpath(outpath, name)) - fix_file_paths(joinpath(outpath, name), curr_path_marker) - return check_equal(joinpath(outpath, name), joinpath(tgtpath, name)) + readpath = joinpath(inpath, name) + writepath = joinpath(outpath, name) + os.system("python3 -m RASP_support <" + readpath + " >" + writepath) + fix_file_paths(writepath, curr_path_marker) + return check_equal(writepath, joinpath(tgtpath, name)) def run_inputs(): - all_names = things_in_path(inpath) - passed = True - for n in all_names: - success = run_input(n) - print("input", n, "passed:", success) - if not success: - passed = False - return passed + all_names = things_in_path(inpath) + passed = True + for n in all_names: + success = run_input(n) + print("input", n, "passed:", success) + if not success: + passed = False + return passed def test_broken_lib(lib): - inlib, outlib = lib, lib.replace(".rasp", ".txt") - os.system("cp "+joinpath(libspath, inlib)+" RASP_support/rasplib.rasp") - os.system("python3 -m RASP_support <" + joinpath(libtestspath, - "empty.txt") + " >" + joinpath(liboutspath, outlib)) - return check_equal(joinpath(liboutspath, outlib), - joinpath(libtgtspath, outlib)) + inlib, outlib = lib, lib.replace(".rasp", ".txt") + os.system("cp " + joinpath(libspath, inlib) + " RASP_support/rasplib.rasp") + readpath = joinpath(libtestspath, "empty.txt") + writepath = joinpath(liboutspath, outlib) + os.system("python3 -m RASP_support <" + readpath + " >" + writepath) + return check_equal( + joinpath(liboutspath, outlib), joinpath(libtgtspath, outlib)) def 
run_broken_libs(): - save_rasplib() - all_libs = things_in_path(libspath) - passed = True - for lib in all_libs: - success = test_broken_lib(lib) - print("lib", lib, "passed (i.e., properly errored):", success) - if not success: - passed = False - restore_rasplib() - return passed + save_rasplib() + all_libs = things_in_path(libspath) + passed = True + for lib in all_libs: + success = test_broken_lib(lib) + print("lib", lib, "passed (i.e., properly errored):", success) + if not success: + passed = False + restore_rasplib() + return passed if __name__ == "__main__": - passed_inputs = run_inputs() - print("passed all inputs:", passed_inputs) - print("=====\n\n=====") - passed_broken_libs = run_broken_libs() - print("properly reports broken libs:", passed_broken_libs) - print("=====\n\n=====") + passed_inputs = run_inputs() + print("passed all inputs:", passed_inputs) + print("=====\n\n=====") + passed_broken_libs = run_broken_libs() + print("properly reports broken libs:", passed_broken_libs) + print("=====\n\n=====") - passed_everything = False not in [passed_inputs, passed_broken_libs] - print("=====\npassed everything:", passed_everything) - if passed_everything: - exit(0) - else: - exit(1) + passed_everything = False not in [passed_inputs, passed_broken_libs] + print("=====\npassed everything:", passed_everything) + if passed_everything: + exit(0) + else: + exit(1)
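
For reviewers skimming the Support.py hunks above, a minimal usage sketch of the reindented primitives (Sequence, select, aggregate, zipmap) may help confirm the behaviour is unchanged by the whitespace conversion. This sketch is not part of the patch; the sequence values and the identity select are illustrative only, and it assumes it is run from the repository root so that the RASP_support package is importable.

# sketch.py -- hypothetical example, not included in this diff
from RASP_support.Support import Sequence, select, aggregate, zipmap

n = 4
indices = Sequence(list(range(n)))        # [0, 1, 2, 3]
tokens = Sequence(["a", "b", "c", "d"])

# identity select: position i attends to position j iff i == j
ident = select(n, (indices,), (indices,), lambda q, k: q == k)

# with exactly one influencer per row, the average-aggregate simply copies
# each token through the select
copied = aggregate(ident, (tokens,), lambda t: t)
print(copied.get_vals())                  # ['a', 'b', 'c', 'd']

# zipmap applies an elementwise function at every position
flags = zipmap(n, (indices,), lambda i: i % 2 == 0)
print(flags.get_vals())                   # [True, False, True, False]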