From 7b9fc922c3af0463380a1afa373581da8e83019a Mon Sep 17 00:00:00 2001
From: Felix Hieber
Date: Thu, 12 Jul 2018 17:21:08 +0200
Subject: [PATCH] Make rouge metrics available in sockeye.evaluate CLI (#471)

* Make rouge metrics available in sockeye.evaluate CLI

* Add chrf to optimizable metrics

* fix

* compute rouge in system tests

* fix
---
 CHANGELOG.md         |   5 ++
 sockeye/__init__.py  |   2 +-
 sockeye/arguments.py |  22 +++++---
 sockeye/constants.py |   8 +--
 sockeye/evaluate.py  | 120 +++++++++++++++++++++++--------------------
 sockeye/train.py     |   3 --
 sockeye/training.py  |   4 +-
 test/common.py       |   4 +-
 8 files changed, 93 insertions(+), 75 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e82f879d7..d976a829a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
 
+## [1.18.35]
+### Added
+- ROUGE scores are now available in `sockeye-evaluate`.
+- Enabled CHRF as an early-stopping metric.
+
 ## [1.18.34]
 ### Added
 - Added support for `--beam-search-stop first` for decoding jobs with `--batch-size > 1`.
diff --git a/sockeye/__init__.py b/sockeye/__init__.py
index ce8b15e89..de6152484 100644
--- a/sockeye/__init__.py
+++ b/sockeye/__init__.py
@@ -11,4 +11,4 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-__version__ = '1.18.34'
+__version__ = '1.18.35'
diff --git a/sockeye/arguments.py b/sockeye/arguments.py
index 063217a2e..8e2556232 100644
--- a/sockeye/arguments.py
+++ b/sockeye/arguments.py
@@ -21,9 +21,9 @@
 import yaml
 from typing import Any, Callable, Dict, List, Tuple, Optional
 
-from sockeye.lr_scheduler import LearningRateSchedulerFixedStep
 from . import constants as C
 from . import data_io
+from .lr_scheduler import LearningRateSchedulerFixedStep
 
 
 class ConfigArgumentParser(argparse.ArgumentParser):
@@ -32,7 +32,7 @@ class ConfigArgumentParser(argparse.ArgumentParser):
 
     The option --config is added automatically and expects a YAML serialized
     dictionary, similar to the return value of parse_args(). Command line
-    parameters have precendence over config file values. Usage should be
+    parameters have precedence over config file values. Usage should be
     transparent, just substitute argparse.ArgumentParser with this class.
 
     Extended from
@@ -1051,7 +1051,8 @@ def add_max_output_cli_args(params):
     params.add_argument('--max-output-length',
                         type=int,
                         default=None,
-                        help='Maximum number of words to generate during translation. If None, it will be computed automatically. Default: %(default)s.')
+                        help='Maximum number of words to generate during translation. '
+                             'If None, it will be computed automatically. Default: %(default)s.')
 
 
 def add_inference_args(params):
@@ -1104,11 +1105,13 @@ def add_inference_args(params):
                                type=float,
                                default=0,
                                help='Pruning threshold for beam search. All hypotheses with scores not within '
-                                    'this amount of the best finished hypothesis are discarded (0 = off). Default: %(default)s.')
+                                    'this amount of the best finished hypothesis are discarded (0 = off). '
+                                    'Default: %(default)s.')
     decode_params.add_argument('--beam-search-stop',
                                choices=[C.BEAM_SEARCH_STOP_ALL, C.BEAM_SEARCH_STOP_FIRST],
                                default=C.BEAM_SEARCH_STOP_ALL,
-                               help='Stopping criteria. Quit when (all) hypotheses are finished or when a finished hypothesis is in (first) position. Default: %(default)s.')
+                               help='Stopping criteria. Quit when (all) hypotheses are finished '
+                                    'or when a finished hypothesis is in (first) position. Default: %(default)s.')
     decode_params.add_argument('--batch-size',
                                type=int_greater_or_equal(1),
                                default=1,
@@ -1160,7 +1163,8 @@ def add_inference_args(params):
     decode_params.add_argument('--avoid-list',
                                type=str,
                                default=None,
-                               help="Specify a file containing phrases (pre-processed, one per line) to block from the output. Default: %(default)s.")
+                               help="Specify a file containing phrases (pre-processed, one per line) to block "
+                                    "from the output. Default: %(default)s.")
     decode_params.add_argument('--strip-unknown-words',
                                action='store_true',
                                default=False,
@@ -1191,6 +1195,7 @@ def add_inference_args(params):
                         help='EXPERIMENTAL: may be changed or removed in future. Overrides training dtype of '
                              'encoders and decoders during inference. Default: %(default)s')
 
+
 def add_evaluate_args(params):
     eval_params = params.add_argument_group("Evaluate parameters")
     eval_params.add_argument('--references', '-r',
@@ -1201,9 +1206,10 @@ def add_evaluate_args(params):
                              type=file_or_stdin(),
                              default=[sys.stdin],
                              nargs='+',
-                             help="File(s) with hypotheses. If none will read from stdin. Default: %(default)s.")
+                             help="File(s) with hypotheses. If none will read from stdin. Default: stdin.")
     eval_params.add_argument('--metrics',
                              nargs='+',
+                             choices=C.EVALUATE_METRICS,
                              default=[C.BLEU, C.CHRF],
                              help='List of metrics to compute. Default: %(default)s.')
     eval_params.add_argument('--sentence', '-s',
@@ -1212,7 +1218,7 @@ def add_evaluate_args(params):
     eval_params.add_argument('--offset',
                              type=float,
                              default=0.01,
-                             help="Numerical value of the offset of zero n-gram counts. Default: %(default)s.")
+                             help="Numerical value of the offset of zero n-gram counts for BLEU. Default: %(default)s.")
     eval_params.add_argument('--not-strict', '-n',
                              action="store_true",
                              help="Do not fail if number of hypotheses does not match number of references. "
diff --git a/sockeye/constants.py b/sockeye/constants.py
index aac152044..85870967a 100644
--- a/sockeye/constants.py
+++ b/sockeye/constants.py
@@ -348,9 +348,11 @@
 ROUGE_L_VAL = ROUGEL + "-val"
 AVG_TIME = "avg-sec-per-sent-val"
 DECODING_TIME = "decode-walltime-val"
-METRICS = [PERPLEXITY, ACCURACY, BLEU, ROUGE1]
-METRIC_MAXIMIZE = {ACCURACY: True, BLEU: True, ROUGE1: True, PERPLEXITY: False}
-METRIC_WORST = {ACCURACY: 0.0, BLEU: 0.0, ROUGE1: 0.0, PERPLEXITY: np.inf}
+METRICS = [PERPLEXITY, ACCURACY, BLEU, CHRF, ROUGE1]
+METRIC_MAXIMIZE = {ACCURACY: True, BLEU: True, CHRF: True, ROUGE1: True, PERPLEXITY: False}
+METRIC_WORST = {ACCURACY: 0.0, BLEU: 0.0, CHRF: 0.0, ROUGE1: 0.0, PERPLEXITY: np.inf}
+METRICS_REQUIRING_DECODER = [BLEU, CHRF, ROUGE1, ROUGE2, ROUGEL]
+EVALUATE_METRICS = [BLEU, CHRF, ROUGE1, ROUGE2, ROUGEL]
 
 # loss
 CROSS_ENTROPY = 'cross-entropy'
diff --git a/sockeye/evaluate.py b/sockeye/evaluate.py
index adfd2556a..1a4a105b7 100644
--- a/sockeye/evaluate.py
+++ b/sockeye/evaluate.py
@@ -12,21 +12,23 @@
 # permissions and limitations under the License.
 
 """
-Evaluation CLI. Prints corpus BLEU
+Evaluation CLI.
 """
 import argparse
 import logging
 import sys
-import numpy as np
-from typing import Iterable, Optional
 from collections import defaultdict
+from functools import partial
+from typing import Callable, Iterable, Dict, List, Tuple, Optional
+
+import numpy as np
 
 from contrib import sacrebleu, rouge
-from sockeye.log import setup_main_logger, log_sockeye_version
 from . import arguments
 from . import constants as C
 from . import data_io
 from . import utils
+from .log import setup_main_logger, log_sockeye_version
 
 logger = setup_main_logger(__name__, file_logging=False)
@@ -54,35 +56,39 @@ def raw_corpus_chrf(hypotheses: Iterable[str], references: Iterable[str]) -> flo
     return sacrebleu.corpus_chrf(hypotheses, references, order=sacrebleu.CHRF_ORDER, beta=sacrebleu.CHRF_BETA,
                                  remove_whitespace=True)
 
+
 def raw_corpus_rouge1(hypotheses: Iterable[str], references: Iterable[str]) -> float:
-	"""
-	Simple wrapper around ROUGE-1 implementation.
+    """
+    Simple wrapper around ROUGE-1 implementation.
+
+    :param hypotheses: Hypotheses stream.
+    :param references: Reference stream.
+    :return: ROUGE-1 score as float between 0 and 1.
+    """
+    return rouge.rouge_1(hypotheses, references)
 
-	:param hypotheses: Hypotheses stream.
-	:param references: Reference stream.
-	:return: ROUGE-1 score as float between 0 and 1.
-	"""
-	return rouge.rouge_1(hypotheses, references)
 
 
 def raw_corpus_rouge2(hypotheses: Iterable[str], references: Iterable[str]) -> float:
-	"""
-	Simple wrapper around ROUGE-2 implementation.
+    """
+    Simple wrapper around ROUGE-2 implementation.
+
+    :param hypotheses: Hypotheses stream.
+    :param references: Reference stream.
+    :return: ROUGE-2 score as float between 0 and 1.
+    """
+    return rouge.rouge_2(hypotheses, references)
 
-	:param hypotheses: Hypotheses stream.
-	:param references: Reference stream.
-	:return: ROUGE-2 score as float between 0 and 1.
-	"""
-	return rouge.rouge_2(hypotheses, references)
 
 
 def raw_corpus_rougel(hypotheses: Iterable[str], references: Iterable[str]) -> float:
-	"""
-	Simple wrapper around ROUGE-1 implementation.
+    """
+    Simple wrapper around ROUGE-L implementation.
+
+    :param hypotheses: Hypotheses stream.
+    :param references: Reference stream.
+    :return: ROUGE-L score as float between 0 and 1.
+    """
+    return rouge.rouge_l(hypotheses, references)
 
-	:param hypotheses: Hypotheses stream.
-	:param references: Reference stream.
-	:return: ROUGE-L score as float between 0 and 1.
-	"""
-	return rouge.rouge_l(hypotheses, references)
 
 
 def main():
     params = argparse.ArgumentParser(description='Evaluate translations by calculating metrics with '
@@ -103,55 +109,57 @@ def main():
     references = [' '.join(e) for e in data_io.read_content(args.references)]
     all_hypotheses = [[h.strip() for h in hypotheses] for hypotheses in args.hypotheses]
-    metrics = args.metrics
-    logger.info("%d hypotheses | %d references", len(all_hypotheses), len(references))
-
     if not args.not_strict:
         for hypotheses in all_hypotheses:
             utils.check_condition(len(hypotheses) == len(references),
                                   "Number of hypotheses (%d) and references (%d) does not match." % (len(hypotheses),
                                                                                                      len(references)))
-    metric_info = []
-    for metric in metrics:
-        metric_info.append("%s\t(s_opt)" % metric)
+
+    logger.info("%d hypothesis set(s) | %d hypotheses | %d references",
+                len(all_hypotheses), len(all_hypotheses[0]), len(references))
+
+    metric_info = ["%s\t(s_opt)" % name for name in args.metrics]
     logger.info("\t".join(metric_info))
 
+    metrics = []  # type: List[Tuple[str, Callable]]
+    for name in args.metrics:
+        if name == C.BLEU:
+            func = partial(raw_corpus_bleu, offset=args.offset)
+        elif name == C.CHRF:
+            func = raw_corpus_chrf
+        elif name == C.ROUGE1:
+            func = raw_corpus_rouge1
+        elif name == C.ROUGE2:
+            func = raw_corpus_rouge2
+        elif name == C.ROUGEL:
+            func = raw_corpus_rougel
+        else:
+            raise ValueError("Unknown metric %s." % name)
+        metrics.append((name, func))
+
     if not args.sentence:
-        scores = defaultdict(list)
+        scores = defaultdict(list)  # type: Dict[str, List[float]]
         for hypotheses in all_hypotheses:
-            for metric in metrics:
-                if metric == C.BLEU:
-                    score = raw_corpus_bleu(hypotheses, references, args.offset)
-                elif metric == C.CHRF:
-                    score = raw_corpus_chrf(hypotheses, references)
-                else:
-                    raise ValueError("Unknown metric %s." % metric)
-                scores[metric].append(score)
+            for name, metric in metrics:
+                scores[name].append(metric(hypotheses, references))
         _print_mean_std_score(metrics, scores)
     else:
         for hypotheses in all_hypotheses:
             for h, r in zip(hypotheses, references):
-                scores = defaultdict(list)
-                for metric in metrics:
-                    if metric == C.BLEU:
-                        score = raw_corpus_bleu([h], [r], args.offset)
-                    elif metric == C.CHRF:
-                        score = raw_corpus_chrf(h, r)
-                    else:
-                        raise ValueError("Unknown metric %s." % metric)
-                    scores[metric].append(score)
+                scores = defaultdict(list)  # type: Dict[str, List[float]]
+                for name, metric in metrics:
+                    scores[name].append(metric([h], [r]))
                 _print_mean_std_score(metrics, scores)
 
 
-def _print_mean_std_score(metrics, scores):
-    scores_mean_std = []
-    for metric in metrics:
-        if len(scores[metric]) > 1:
-            score_mean = np.asscalar(np.mean(scores[metric]))
-            score_std = np.asscalar(np.std(scores[metric], ddof=1))
+def _print_mean_std_score(metrics: List[Tuple[str, Callable]], scores: Dict[str, List[float]]):
+    scores_mean_std = []  # type: List[str]
+    for name, _ in metrics:
+        if len(scores[name]) > 1:
+            score_mean = np.asscalar(np.mean(scores[name]))
+            score_std = np.asscalar(np.std(scores[name], ddof=1))
             scores_mean_std.append("%.3f\t%.3f" % (score_mean, score_std))
         else:
-            score = scores[metric][0]
+            score = scores[name][0]
             scores_mean_std.append("%.3f\t(-)" % score)
     print("\t".join(scores_mean_std))
diff --git a/sockeye/train.py b/sockeye/train.py
index d549518b0..3cea988f6 100644
--- a/sockeye/train.py
+++ b/sockeye/train.py
@@ -78,9 +78,6 @@ def check_arg_compatibility(args: argparse.Namespace):
 
     :param args: Arguments as returned by argparse.
     """
-    check_condition(args.optimized_metric == C.BLEU or args.optimized_metric == C.ROUGE1 or args.optimized_metric in args.metrics,
-                    "Must optimize either BLEU, ROUGE or one of tracked metrics (--metrics)")
-
     if args.encoder == C.TRANSFORMER_TYPE:
         check_condition(args.transformer_model_size[0] == args.num_embed[0],
                         "Source embedding size must match transformer model size: %s vs. %s"
diff --git a/sockeye/training.py b/sockeye/training.py
index 006ef55d8..de64ea2f0 100644
--- a/sockeye/training.py
+++ b/sockeye/training.py
@@ -905,8 +905,8 @@ def _check_args(self,
         utils.check_condition(early_stopping_metric in C.METRICS,
                               "Unsupported early-stopping metric: %s" % early_stopping_metric)
-        if early_stopping_metric == C.BLEU:
-            utils.check_condition(cp_decoder is not None, "%s requires CheckpointDecoder" % C.BLEU)
+        if early_stopping_metric in C.METRICS_REQUIRING_DECODER:
+            utils.check_condition(cp_decoder is not None, "%s requires CheckpointDecoder" % early_stopping_metric)
 
     def _save_params(self):
         """
diff --git a/test/common.py b/test/common.py
index 9cd2ce5d1..405e3808d 100644
--- a/test/common.py
+++ b/test/common.py
@@ -459,11 +459,11 @@ def run_train_translate(train_params: str,
     if restrict_lexicon:
         bleu_restrict = raw_corpus_bleu(hypotheses=hypotheses, references=references, offset=0.01)
 
-    # Run BLEU cli
+    # Run evaluate cli
    eval_params = "{} {} ".format(sockeye.evaluate.__file__,
                                  _EVAL_PARAMS_COMMON.format(hypotheses=out_path,
                                                             references=test_target_path,
-                                                             metrics="bleu chrf",
+                                                             metrics="bleu chrf rouge1",
                                                             quiet=quiet_arg), )
    with patch.object(sys, "argv", eval_params.split()):
        sockeye.evaluate.main()
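
Usage sketch (illustrative, not part of the commit): after this patch, every metric
wrapper in sockeye/evaluate.py shares the (hypotheses, references) calling convention,
which is what lets main() dispatch over (name, callable) pairs; only raw_corpus_bleu
additionally takes the smoothing offset exposed on the CLI as --offset. The snippet
below assumes sockeye >= 1.18.35 is installed with contrib.rouge importable, and the
sentences are made-up data.

    # Python: calling the evaluate wrappers programmatically.
    from sockeye.evaluate import (raw_corpus_bleu, raw_corpus_chrf,
                                  raw_corpus_rouge1, raw_corpus_rougel)

    hypotheses = ["the cat sat on the mat", "a dog barks"]
    references = ["the cat is on the mat", "the dog barks"]

    # BLEU takes the zero n-gram count offset (CLI flag --offset).
    print(raw_corpus_bleu(hypotheses, references, 0.01))
    # CHRF and the ROUGE wrappers take only (hypotheses, references);
    # the ROUGE scores are floats between 0 and 1.
    print(raw_corpus_chrf(hypotheses, references))
    print(raw_corpus_rouge1(hypotheses, references))
    print(raw_corpus_rougel(hypotheses, references))

The same scores are available from the command line via the entry point named in the
changelog, e.g. `sockeye-evaluate --references refs.txt --hypotheses hyps.txt
--metrics bleu chrf rouge1` (file names here are placeholders); this is the metric
set the system test in test/common.py now exercises.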