From 99c0a2fbb092f5ceff92851457bc37e98e7de6ab Mon Sep 17 00:00:00 2001 From: Tim Miller Date: Fri, 27 Sep 2024 18:59:17 -0400 Subject: [PATCH] Updated black version and ran black. --- .pre-commit-config.yaml | 2 +- src/cnlpt/train_system.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6841d79c..5c2701e6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black-pre-commit-mirror - rev: 23.7.0 + rev: 24.8.0 hooks: - id: black - repo: https://github.com/pycqa/flake8 diff --git a/src/cnlpt/train_system.py b/src/cnlpt/train_system.py index 2c12a444..b8dc34f9 100644 --- a/src/cnlpt/train_system.py +++ b/src/cnlpt/train_system.py @@ -545,16 +545,16 @@ def main( # steps per epoch factors in gradient accumulation steps (as compared to batches_per_epoch above which doesn't) steps_per_epoch = int(total_steps // training_args.num_train_epochs) training_args.eval_steps = steps_per_epoch // training_args.evals_per_epoch - training_args.evaluation_strategy = ( - training_args.eval_strategy - ) = IntervalStrategy.STEPS + training_args.evaluation_strategy = training_args.eval_strategy = ( + IntervalStrategy.STEPS + ) # This will save model per epoch # training_args.save_strategy = IntervalStrategy.EPOCH elif training_args.do_eval: logger.info("Evaluation strategy not specified so evaluating every epoch") - training_args.evaluation_strategy = ( - training_args.eval_strategy - ) = IntervalStrategy.EPOCH + training_args.evaluation_strategy = training_args.eval_strategy = ( + IntervalStrategy.EPOCH + ) current_prediction_packet = deque() @@ -662,9 +662,9 @@ def compute_metrics_fn(p: EvalPrediction): "w", ) as f: config_dict = model_args.to_dict() - config_dict[ - "label_dictionary" - ] = dataset.get_labels() + config_dict["label_dictionary"] = ( + dataset.get_labels() + ) config_dict["task_names"] = task_names json.dump(config_dict, f) for task_ind, task_name in enumerate(metrics):