Skip to content

Commit

Permalink
Fixes duplicate outcomes when remove_duplicates=True (#148)
Browse files Browse the repository at this point in the history
* enforces unique outcomes in preprocess.py; bumps version number

* adds test for remove_duplicates in test_preprocess.py
  • Loading branch information
kuchenrolle authored and derNarr committed Mar 2, 2018
1 parent 9b47606 commit 77678f9
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pyndl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
__author__ = ('Konstantin Sering, Marc Weitz, '
'David-Elias Künstle, Lennard Schneider')
__author_email__ = '[email protected]'
__version__ = '0.5.0'
__version__ = '0.5.1'
__license__ = 'MIT'
__description__ = ('Naive discriminative learning implements learning and '
'classification models based on the Rescorla-Wagner '
Expand Down
35 changes: 17 additions & 18 deletions pyndl/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def ngrams_to_word(occurrences, n_chars, outfile, remove_duplicates=True):
if not ngrams or not occurrence:
continue
if remove_duplicates:
outfile.write("{}\t{}\n".format("_".join(set(ngrams)), occurrence))
else:
outfile.write("{}\t{}\n".format("_".join(ngrams), occurrence))
ngrams = set(ngrams)
occurrence = "_".join(set(occurrence.split("_")))
outfile.write("{}\t{}\n".format("_".join(ngrams), occurrence))


def process_occurrences(occurrences, outfile, *,
Expand Down Expand Up @@ -132,9 +132,9 @@ def process_occurrences(occurrences, outfile, *,
if not cues:
continue
if remove_duplicates:
outfile.write("{}\t{}\n".format("_".join(set(cues.split("_"))), outcomes))
else:
outfile.write("{}\t{}\n".format(cues, outcomes))
cues = "_".join(set(cues.split("_")))
outcomes = "_".join(set(outcomes.split("_")))
outfile.write("{}\t{}\n".format(cues, outcomes))
else:
raise NotImplementedError('cue_structure=%s is not implemented yet.' % cue_structure)

Expand Down Expand Up @@ -245,19 +245,16 @@ def gen_occurrences(words):
"""
if event_structure == 'consecutive_words':
occurrences = list()
cur_words = list()
ii = 0
while True:
if ii < len(words):
cur_words.append(words[ii])
if ii >= len(words) or ii >= number_of_words:
# remove the first word
cur_words = cur_words[1:]
# can't have more consecutive words than total words
length = min(number_of_words, len(words))
# slide window over list of words
for ii in range(1 - length, len(words)):
# no consecutive words before first word
start = max(ii, 0)
# no consecutive words after last word
end = min(ii + length, len(words))
# append (cues, outcomes) with empty outcomes
occurrences.append(("_".join(cur_words), ''))
ii += 1
if not cur_words:
break
occurrences.append(("_".join(words[start:end]), ""))
return occurrences
# for words = (A, B, C, D); before = 2, after = 1
# make: (B, A), (A_C, B), (A_B_D, C), (B_C, D)
Expand All @@ -274,6 +271,8 @@ def gen_occurrences(words):
elif event_structure == 'line':
# (cues, outcomes) with empty outcomes
return [('_'.join(words), ''), ]
else:
raise ValueError('gen_occurrences should be one of {"consecutive_words", "word_to_word", "line"}')

def process_line(line):
"""processes one line of text."""
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ max-line-length = 120

[pylint]
max-line-length = 120
good-names = nn, ii, _
good-names = nn, ii, _, jj
extension-pkg-whitelist=numpy,pyndl.ndl_parallel
ignore=pyndl/ndl_parallel
disable=E1101
Expand Down
Binary file modified tests/reference/event_file_bigrams_to_word.tab.gz
Binary file not shown.
Binary file modified tests/reference/event_file_trigrams_to_word.tab.gz
Binary file not shown.
Binary file modified tests/reference/event_file_trigrams_to_word_line_based.tab.gz
Binary file not shown.
47 changes: 47 additions & 0 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,53 @@ def test_bigrams_to_word():
os.remove(event_file)


def test_remove_duplicates():
    """Event files built with ``remove_duplicates=True`` must keep the same
    events as without it, only with repeated cues/outcomes collapsed.

    Creates the same event file twice (once keeping, once removing
    duplicates), then checks line by line that deduplication never changes
    the *set* of cues/outcomes of an event, only removes repetitions.
    """
    event_file_noduplicates = os.path.join(TEST_ROOT, "temp/event_file_bigrams_to_word_noduplicates.tab.gz")
    event_file_duplicates = os.path.join(TEST_ROOT, "temp/event_file_bigrams_to_word_duplicates.tab.gz")
    # build both variants of the event file from the same corpus
    for target_file, deduplicate in ((event_file_duplicates, False),
                                     (event_file_noduplicates, True)):
        create_event_file(RESOURCE_FILE, target_file,
                          context_structure="document",
                          event_structure="consecutive_words",
                          event_options=(3, ),
                          cue_structure="bigrams_to_word",
                          remove_duplicates=deduplicate)

    with gzip.open(event_file_noduplicates, "rt") as deduped_file:
        deduped_lines = deduped_file.readlines()
    with gzip.open(event_file_duplicates, "rt") as raw_file:
        raw_lines = raw_file.readlines()

    # deduplication must neither drop nor add whole events
    assert len(deduped_lines) == len(raw_lines)

    shrunk_cues = 0
    shrunk_outcomes = 0
    for deduped_line, raw_line in zip(deduped_lines, raw_lines):
        cues, outcomes = deduped_line.strip().split('\t')
        cues = sorted(cues.split('_'))
        outcomes = sorted(outcomes.split('_'))
        ref_cues, ref_outcomes = raw_line.strip().split('\t')
        ref_cues = sorted(ref_cues.split('_'))
        ref_outcomes = sorted(ref_outcomes.split('_'))
        # count how many events actually lost tokens to deduplication
        if len(cues) != len(ref_cues):
            shrunk_cues += 1
        if len(outcomes) != len(ref_outcomes):
            shrunk_outcomes += 1
        # the deduplicated file must contain no repeated tokens ...
        assert len(cues) == len(set(cues))
        assert len(outcomes) == len(set(outcomes))
        # ... and collapsing the raw tokens must yield the same sets
        assert set(cues) == set(ref_cues)
        assert set(outcomes) == set(ref_outcomes)
    # known counts of duplicate-containing events in the test corpus
    assert shrunk_cues == 1098
    assert shrunk_outcomes == 66

    os.remove(event_file_noduplicates)
    os.remove(event_file_duplicates)


def test_word_to_word():
event_file = os.path.join(TEST_ROOT, "temp/event_file_word_to_word.tab.gz")
reference_file = os.path.join(TEST_ROOT, "reference/event_file_word_to_word.tab.gz")
Expand Down

0 comments on commit 77678f9

Please sign in to comment.