Skip to content

Commit

Permalink
Merge pull request #159 from nicolay-r/0.21.0-rc
Browse files Browse the repository at this point in the history
0.21.0 rc
  • Loading branch information
nicolay-r authored Aug 15, 2021
2 parents 989d984 + 14ea801 commit 9b83f93
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 32 deletions.
7 changes: 4 additions & 3 deletions tests/contrib/bert/test_output.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
from os.path import join, dirname, abspath
import sys
import unittest

Expand All @@ -13,9 +14,9 @@

class TestOutputFormatters(unittest.TestCase):

__input_samples_filepath = u"data/test_sample_3l.tsv.gz"

__google_bert_output_filepath_sample = u"data/test_google_bert_output_3l.tsv"
__current_dir = dirname(__file__)
__input_samples_filepath = join(__current_dir, u"data/test_sample_3l.tsv.gz")
__google_bert_output_filepath_sample = join(__current_dir, u"data/test_google_bert_output_3l.tsv")

def test_google_bert_output_formatter(self):
row_ids_provider = MultipleIDProvider()
Expand Down
14 changes: 8 additions & 6 deletions tests/contrib/networks/test_samples_iter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from os.path import dirname, join

import pandas as pd
import gzip
import sys
Expand All @@ -11,17 +13,18 @@
from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig
from arekit.contrib.networks.sample import InputSample

from tests.contrib.networks.labels import TestThreeLabelScaler


class TestSamplesIteration(unittest.TestCase):

__show_examples = False
__show_shifted_examples = False

def __get_local_dir(self, local_filepath):
return join(dirname(__file__), local_filepath)

def test_check_all_samples(self):
vocab_filepath = u"test_data/vocab.txt.gz"
samples_filepath = u"test_data/sample-train.tsv.gz"
vocab_filepath = self.__get_local_dir(u"test_data/vocab.txt.gz")
samples_filepath = self.__get_local_dir(u"test_data/sample-train.tsv.gz")
words_vocab = self.__read_vocab(vocab_filepath)
config = DefaultNetworkConfig()
config.modify_terms_per_context(50)
Expand Down Expand Up @@ -78,11 +81,10 @@ def __test_core(self, words_vocab, config, samples_filepath):
assert(isinstance(samples_filepath, unicode))

samples = []
labels_scaler = TestThreeLabelScaler()
for i, row in enumerate(self.__iter_tsv_gzip(input_file=samples_filepath)):

# Perform row parsing process.
row = ParsedSampleRow(row, labels_scaler=labels_scaler)
row = ParsedSampleRow(row)

subj_ind = row.SubjectIndex
obj_ind = row.ObjectIndex
Expand Down
6 changes: 3 additions & 3 deletions tests/contrib/networks/test_tf_ctx_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

sys.path.append('../../../')

from arekit.common.languages.ru.pos_service import PartOfSpeechTypesService
from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig

from tests.contrib.networks.tf_networks.supported import get_supported
from tests.contrib.networks.tf_networks.utils import init_config

from arekit.common.languages.ru.pos_service import PartOfSpeechTypesService
from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig


class TestContextNetworkCompilation(unittest.TestCase):

Expand Down
26 changes: 15 additions & 11 deletions tests/contrib/networks/test_tf_ctx_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@

sys.path.append('../../../')

from tests.contrib.networks.labels import TestThreeLabelScaler
from tests.contrib.networks.tf_networks.supported import get_supported
from tests.contrib.networks.tf_networks.utils import init_config

from arekit.common.experiment.data_type import DataType
from arekit.common.labels.scaler import BaseLabelScaler

from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig
from arekit.contrib.networks.sample import InputSample
from arekit.contrib.networks.core.feeding.bags.bag import Bag
from arekit.contrib.networks.core.feeding.batch.base import MiniBatch
from arekit.contrib.networks.core.nn import NeuralNetwork

from tests.contrib.networks.tf_networks.supported import get_supported
from tests.contrib.networks.tf_networks.utils import init_config


class TestContextNetworkFeeding(unittest.TestCase):

Expand All @@ -29,16 +31,13 @@ def init_session():
return sess

@staticmethod
def create_minibatch(config, labels_count):
def __create_minibatch(config, labels_scaler):
assert(isinstance(config, DefaultNetworkConfig))
assert(isinstance(labels_count, int))

l_uint_min = 0
l_uint_max = labels_count - 1
assert(isinstance(labels_scaler, BaseLabelScaler))

bags = []
for i in range(config.BagsPerMinibatch):
uint_label = random.randint(l_uint_min, l_uint_max)
uint_label = random.randint(0, labels_scaler.LabelsCount)
bag = Bag(uint_label=uint_label)
for j in range(config.BagSize):
bag.add_sample(InputSample._generate_test(config))
Expand All @@ -48,16 +47,18 @@ def create_minibatch(config, labels_count):

@staticmethod
def run_feeding(network, network_config, create_minibatch_func, logger,
labels_scaler,
display_hidden_values=True,
display_idp_values=True):
assert(isinstance(network, NeuralNetwork))
assert(isinstance(network_config, DefaultNetworkConfig))
assert(isinstance(labels_scaler, BaseLabelScaler))
assert(callable(create_minibatch_func))

init_config(network_config)
# Init network.
network.compile(config=network_config, reset_graph=True, graph_seed=42)
minibatch = create_minibatch_func(config=network_config, labels_count=3)
minibatch = create_minibatch_func(config=network_config, labels_scaler=labels_scaler)

network_optimiser = network_config.Optimiser.minimize(network.Cost)

Expand Down Expand Up @@ -107,11 +108,14 @@ def test(self):
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)

labels_scaler = TestThreeLabelScaler()

for cfg, network in get_supported():
logger.debug("Feed to the network: {}".format(type(network)))
self.run_feeding(network=network,
network_config=cfg,
create_minibatch_func=self.create_minibatch,
create_minibatch_func=self.__create_minibatch,
labels_scaler=labels_scaler,
logger=logger)


Expand Down
20 changes: 12 additions & 8 deletions tests/contrib/networks/test_tf_mi_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@

sys.path.append('../../../')

from arekit.common.labels.scaler import BaseLabelScaler
from tests.contrib.networks.labels import TestThreeLabelScaler
from tests.contrib.networks.test_tf_ctx_feed import TestContextNetworkFeeding
from tests.contrib.networks.tf_networks.supported import get_supported

from arekit.contrib.networks.multi.configurations.att_self import AttSelfOverSentencesConfig
from arekit.contrib.networks.multi.architectures.att_self import AttSelfOverSentences
from arekit.contrib.networks.core.feeding.bags.bag import Bag
from arekit.contrib.networks.core.feeding.batch.multi import MultiInstanceMiniBatch

from arekit.contrib.networks.multi.configurations.max_pooling import MaxPoolingOverSentencesConfig
from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig
from arekit.contrib.networks.sample import InputSample
from arekit.contrib.networks.multi.architectures.max_pooling import MaxPoolingOverSentences

from tests.contrib.networks.labels import TestNeutralLabel
from tests.contrib.networks.test_tf_ctx_feed import TestContextNetworkFeeding
from tests.contrib.networks.tf_networks.supported import get_supported
from arekit.common.labels.scaler import BaseLabelScaler


class TestMultiInstanceFeed(unittest.TestCase):
Expand All @@ -27,12 +28,12 @@ def __create_minibatch(config, labels_scaler):
assert(isinstance(config, DefaultNetworkConfig))
assert(isinstance(labels_scaler, BaseLabelScaler))
bags = []
label = TestNeutralLabel()
no_label = labels_scaler.get_no_label_instance()
empty_sample = InputSample.create_empty(terms_per_context=config.TermsPerContext,
frames_per_context=config.FramesPerContext,
synonyms_per_context=config.SynonymsPerContext)
for i in range(config.BagsPerMinibatch):
bag = Bag(label)
bag = Bag(labels_scaler.label_to_uint(no_label))
for j in range(config.BagSize):
bag.add_sample(empty_sample)
bags.append(bag)
Expand All @@ -43,13 +44,15 @@ def __create_minibatch(config, labels_scaler):
def multiinstances_supported(ctx_config, ctx_network):
return [
(MaxPoolingOverSentencesConfig(ctx_config), MaxPoolingOverSentences(ctx_network)),
# (AttSelfOverSentencesConfig(ctx_config), AttSelfOverSentences(ctx_network))
(AttSelfOverSentencesConfig(ctx_config), AttSelfOverSentences(ctx_network))
]

def test(self):
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

labels_scaler = TestThreeLabelScaler()

for ctx_config, ctx_network in get_supported():
for config, network in self.multiinstances_supported(ctx_config, ctx_network):
logger.info(type(network))
Expand All @@ -58,6 +61,7 @@ def test(self):
network_config=config,
create_minibatch_func=self.__create_minibatch,
logger=logger,
labels_scaler=labels_scaler,
display_idp_values=False)


Expand Down
6 changes: 6 additions & 0 deletions tests/run_gen_release_notes.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash
docker run -it --rm -v "$(pwd)":/usr/local/src/your-app/ githubchangeloggenerator/github-changelog-generator \
-u nicolay-r \
-p AREkit \
--token <GITHUB_TOKEN> \
--since-tag v0.20.5-rc
2 changes: 1 addition & 1 deletion tests/run_test_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Performing a quick library installation.
# https://stackoverflow.com/questions/19048732/python-setup-py-develop-vs-install
sudo pip install -e ../ --no-deps
pip install -e ../ --no-deps

# Run all unit tests.
python2.7 -m unittest discover .

0 comments on commit 9b83f93

Please sign in to comment.