diff --git a/src/cnlpt/cnlp_args.py b/src/cnlpt/cnlp_args.py index 969a3abe..0c87336e 100644 --- a/src/cnlpt/cnlp_args.py +++ b/src/cnlpt/cnlp_args.py @@ -174,7 +174,7 @@ class ModelArguments: ) }, ) - ignore_existing_classifers: bool = field( + ignore_existing_classifiers: bool = field( default=False, metadata={ "help": ( diff --git a/src/cnlpt/train_system.py b/src/cnlpt/train_system.py index d239fcc8..9b5f676b 100644 --- a/src/cnlpt/train_system.py +++ b/src/cnlpt/train_system.py @@ -143,15 +143,6 @@ def main( model_name = model_args.model hierarchical = model_name == 'hier' - if ( - hierarchical - and (model_args.keep_existing_classifiers == model_args.ignore_existing_classifers) # XNOR - ): - raise ValueError( - "For hierarchical model, one of --keep_existing_classifiers or " - "--ignore_existing_classifers flags should be selected." - ) - # Setup logging logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -260,6 +251,13 @@ def main( freeze=training_args.freeze, ) else: + if ( + hierarchical + and (model_args.keep_existing_classifiers == model_args.ignore_existing_classifiers) # XNOR + ): + raise ValueError( + "For continued training of a cnlpt hierarchical model, one of --keep_existing_classifiers or --ignore_existing_classifiers flags should be selected." + ) # use a checkpoint from an existing model AutoModel.register(CnlpConfig, HierarchicalModel)