From 9bdaedaa5d211c1ec8191a12d321cc7299cfa3a7 Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Mon, 22 Jul 2024 14:37:04 +0100 Subject: [PATCH 01/42] Creating debug_log and transferring in depth log info into there --- mace/cli/run_train.py | 8 ++++---- mace/data/utils.py | 12 ++++++------ mace/tools/scripts_utils.py | 4 ++-- mace/tools/utils.py | 31 +++++++++++++++++++++++-------- 4 files changed, 35 insertions(+), 20 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index aecc3f71..a66df5b2 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -84,7 +84,7 @@ def run(args: argparse.Namespace) -> None: logging.info(f"MACE version: {mace.__version__}") except AttributeError: logging.info("Cannot find MACE version, please install MACE via pip") - logging.info(f"Configuration: {args}") + logging.debug(f"Configuration: {args}") tools.set_default_dtype(args.default_dtype) device = tools.init_device(args.device) @@ -406,7 +406,7 @@ def run(args: argparse.Namespace) -> None: len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - logging.info(f"Hidden irreps: {args.hidden_irreps}") + logging.debug(f"Hidden irreps: {args.hidden_irreps}") model_config = dict( r_max=args.r_max, @@ -670,9 +670,9 @@ def run(args: argparse.Namespace) -> None: for group in optimizer.param_groups: group["lr"] = args.lr - logging.info(model) + logging.debug(model) logging.info(f"Number of parameters: {tools.count_parameters(model)}") - logging.info(f"Optimizer: {optimizer}") + logging.debug(f"Optimizer: {optimizer}") if args.wandb: logging.info("Using Weights and Biases for logging") diff --git a/mace/data/utils.py b/mace/data/utils.py index c870d6ed..07321bc3 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -203,7 +203,7 @@ def load_from_xyz( ) -> Tuple[Dict[int, float], Configurations]: atoms_list = ase.io.read(file_path, index=":") if energy_key == "energy": - logging.info( + logging.warning( "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_energy'. You need to use --energy_key='REF_energy', to tell the key name chosen." ) energy_key = "REF_energy" @@ -211,10 +211,10 @@ def load_from_xyz( try: atoms.info["REF_energy"] = atoms.get_potential_energy() except Exception as e: # pylint: disable=W0703 - logging.warning(f"Failed to extract energy: {e}") + logging.error(f"Failed to extract energy: {e}") atoms.info["REF_energy"] = None if forces_key == "forces": - logging.info( + logging.warning( "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_forces'. You need to use --forces_key='REF_forces', to tell the key name chosen." ) forces_key = "REF_forces" @@ -222,10 +222,10 @@ def load_from_xyz( try: atoms.arrays["REF_forces"] = atoms.get_forces() except Exception as e: # pylint: disable=W0703 - logging.warning(f"Failed to extract forces: {e}") + logging.error(f"Failed to extract forces: {e}") atoms.arrays["REF_forces"] = None if stress_key == "stress": - logging.info( + logging.warning( "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_stress'. You need to use --stress_key='REF_stress', to tell the key name chosen." ) stress_key = "REF_stress" @@ -298,7 +298,7 @@ def compute_average_E0s( for i, z in enumerate(z_table.zs): atomic_energies_dict[z] = E0s[i] except np.linalg.LinAlgError: - logging.warning( + logging.error( "Failed to compute E0s using least squares regression, using the same for all atoms" ) atomic_energies_dict = {} diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index cc7b3929..39b053fe 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -128,10 +128,10 @@ def print_git_commit(): repo = git.Repo(search_parent_directories=True) commit = repo.head.commit.hexsha - logging.info(f"Current Git commit: {commit}") + logging.debug(f"Current Git commit: {commit}") return commit except Exception as e: # pylint: disable=W0703 - logging.info(f"Error accessing Git repository: {e}") + logging.debug(f"Error accessing Git repository: {e}") return "None" diff --git a/mace/tools/utils.py b/mace/tools/utils.py index 65190108..762d9880 100644 --- a/mace/tools/utils.py +++ b/mace/tools/utils.py @@ -52,27 +52,42 @@ def setup_logger( directory: Optional[str] = None, rank: Optional[int] = 0, ): + # Create a logger logger = logging.getLogger() - logger.setLevel(level) + logger.setLevel(logging.DEBUG) # Set to DEBUG to capture all levels + # Create formatters formatter = logging.Formatter( "%(asctime)s.%(msecs)03d %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) + # Add filter for rank + logger.addFilter(lambda _: rank == 0) + + # Create console handler ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(level) ch.setFormatter(formatter) logger.addHandler(ch) - logger.addFilter(lambda _: (rank == 0)) - - if (directory is not None) and (tag is not None): + if directory is not None and tag is not None: os.makedirs(name=directory, exist_ok=True) - path = os.path.join(directory, tag + ".log") - fh = logging.FileHandler(path) - fh.setFormatter(formatter) - logger.addHandler(fh) + # Create file handler for non-debug logs + main_log_path = os.path.join(directory, f"{tag}.log") + fh_main = logging.FileHandler(main_log_path) + fh_main.setLevel(level) + fh_main.setFormatter(formatter) + logger.addHandler(fh_main) + + # Create file handler for debug logs + debug_log_path = os.path.join(directory, f"{tag}_debug.log") + fh_debug = logging.FileHandler(debug_log_path) + fh_debug.setLevel(logging.DEBUG) + fh_debug.setFormatter(formatter) + fh_debug.addFilter(lambda record: record.levelno >= logging.DEBUG) + logger.addHandler(fh_debug) class AtomicNumberTable: From 6e2db59c385d63bf70482428215d2407b66c7d2a Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Mon, 22 Jul 2024 15:13:58 +0100 Subject: [PATCH 02/42] Average number of neighbours rounded to int. Warning added if number is suggestive of unusual data. --- mace/cli/run_train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index a66df5b2..3058a9f6 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -350,7 +350,10 @@ def run(args: argparse.Namespace) -> None: args.avg_num_neighbors = (num_neighbors / num_graphs).item() else: args.avg_num_neighbors = avg_num_neighbors - logging.info(f"Average number of neighbors: {args.avg_num_neighbors}") + if args.avg_num_neighbors < 2 or args.avg_num_neighbors > 100: + logging.warning(f"Unusual average number of neighbors: {int(args.avg_num_neighbors)}") + else: + logging.info(f"Average number of neighbors: {int(args.avg_num_neighbors)}") # Selecting outputs compute_virials = False From 9e6161ddcb1260c0b1a1cfc6b34418f8a1ed30f4 Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Mon, 22 Jul 2024 15:54:55 +0100 Subject: [PATCH 03/42] Changing default key for energy/forces/stress to REF_energy/REF_forces/REF_stress. Rephrased warning of using energy/force/stress as keys --- mace/data/utils.py | 24 ++++++++++++------------ mace/tools/arg_parser.py | 6 +++--- mace/tools/scripts_utils.py | 6 +++--- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/mace/data/utils.py b/mace/data/utils.py index 07321bc3..a8f7b8dc 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -72,9 +72,9 @@ def random_train_valid_split( def config_from_atoms_list( atoms_list: List[ase.Atoms], - energy_key="energy", - forces_key="forces", - stress_key="stress", + energy_key="REF_energy", + forces_key="REF_forces", + stress_key="REF_stress", virials_key="virials", dipole_key="dipole", charges_key="charges", @@ -103,9 +103,9 @@ def config_from_atoms_list( def config_from_atoms( atoms: ase.Atoms, - energy_key="energy", - forces_key="forces", - stress_key="stress", + energy_key="REF_energy", + forces_key="REF_forces", + stress_key="REF_stress", virials_key="virials", dipole_key="dipole", charges_key="charges", @@ -192,9 +192,9 @@ def test_config_types( def load_from_xyz( file_path: str, config_type_weights: Dict, - energy_key: str = "energy", - forces_key: str = "forces", - stress_key: str = "stress", + energy_key: str = "REF_energy", + forces_key: str = "REF_forces", + stress_key: str = "REF_stress", virials_key: str = "virials", dipole_key: str = "dipole", charges_key: str = "charges", @@ -204,7 +204,7 @@ def load_from_xyz( atoms_list = ase.io.read(file_path, index=":") if energy_key == "energy": logging.warning( - "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_energy'. You need to use --energy_key='REF_energy', to tell the key name chosen." + "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name." ) energy_key = "REF_energy" for atoms in atoms_list: @@ -215,7 +215,7 @@ def load_from_xyz( atoms.info["REF_energy"] = None if forces_key == "forces": logging.warning( - "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_forces'. You need to use --forces_key='REF_forces', to tell the key name chosen." + "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name." ) forces_key = "REF_forces" for atoms in atoms_list: @@ -226,7 +226,7 @@ def load_from_xyz( atoms.arrays["REF_forces"] = None if stress_key == "stress": logging.warning( - "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_stress'. You need to use --stress_key='REF_stress', to tell the key name chosen." + "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name." ) stress_key = "REF_stress" for atoms in atoms_list: diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 893203aa..83796052 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -334,13 +334,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--energy_key", help="Key of reference energies in training xyz", type=str, - default="energy", + default="REF_energy", ) parser.add_argument( "--forces_key", help="Key of reference forces in training xyz", type=str, - default="forces", + default="REF_forces", ) parser.add_argument( "--virials_key", @@ -352,7 +352,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--stress_key", help="Key of reference stress in training xyz", type=str, - default="stress", + default="REF_stress", ) parser.add_argument( "--dipole_key", diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 39b053fe..f1aa76e6 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -36,9 +36,9 @@ def get_dataset_from_xyz( test_path: str = None, seed: int = 1234, keep_isolated_atoms: bool = False, - energy_key: str = "energy", - forces_key: str = "forces", - stress_key: str = "stress", + energy_key: str = "REF_energy", + forces_key: str = "REF_forces", + stress_key: str = "REF_stress", virials_key: str = "virials", dipole_key: str = "dipoles", charges_key: str = "charges", From bf1a3a59c36d60e8de9f168ac24b17206913defc Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Mon, 22 Jul 2024 17:01:21 +0100 Subject: [PATCH 04/42] Epoch None loss will only appear in the debug log --- mace/cli/run_train.py | 2 +- mace/tools/train.py | 36 ++++++++++++++++++++---------------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 3058a9f6..36e6ad31 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -594,7 +594,7 @@ def run(args: argparse.Namespace) -> None: args.start_swa = max(1, args.max_num_epochs // 4 * 3) else: if args.start_swa > args.max_num_epochs: - logging.info( + logging.warning( f"Start swa must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}" ) args.start_swa = max(1, args.max_num_epochs // 4 * 3) diff --git a/mace/tools/train.py b/mace/tools/train.py index 7ebf3ce1..74a5645b 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -45,11 +45,15 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): eval_metrics["mode"] = "eval" eval_metrics["epoch"] = epoch logger.log(eval_metrics) + if epoch is None: + logging_level=logging.DEBUG + else: + logging_level=logging.INFO if log_errors == "PerAtomRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + logging.log(level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -58,8 +62,8 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3" + logging.log(level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -68,38 +72,38 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV" + logging.log(level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV" ) elif log_errors == "TotalRMSE": error_e = eval_metrics["rmse_e"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + logging.log(level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" ) elif log_errors == "PerAtomMAE": error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" + logging.log(level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" ) elif log_errors == "TotalMAE": error_e = eval_metrics["mae_e"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" + logging.log(level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" ) elif log_errors == "DipoleRMSE": error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye" + logging.log(level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye" ) elif log_errors == "EnergyDipoleRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye" + logging.log(level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye" ) From deeb7bcc8573bbdb0adc9909134944a5e21b0b59 Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Tue, 23 Jul 2024 14:04:37 +0100 Subject: [PATCH 05/42] Log number of energy and force values loaded with each data set --- mace/cli/run_train.py | 17 +++++++++-------- mace/tools/scripts_utils.py | 16 ++++++++++------ mace/tools/train.py | 4 ++++ 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 36e6ad31..c9b6c218 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -131,7 +131,9 @@ def run(args: argparse.Namespace) -> None: args.avg_num_neighbors = statistics["avg_num_neighbors"] args.compute_avg_num_neighbors = False args.E0s = statistics["atomic_energies"] - + + logging.info("") + logging.info("===========LOADING INPUT DATA===========") # Data preparation if args.train_file.endswith(".xyz"): if args.valid_file is not None: @@ -154,11 +156,7 @@ def run(args: argparse.Namespace) -> None: charges_key=args.charges_key, keep_isolated_atoms=args.keep_isolated_atoms, ) - - logging.info( - f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " - f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]" - ) + else: atomic_energies_dict = None @@ -286,7 +284,8 @@ def run(args: argparse.Namespace) -> None: num_workers=args.num_workers, generator=torch.Generator().manual_seed(args.seed), ) - + logging.info("") + logging.info("===========MODEL DETAILS===========") if args.loss == "weighted": loss_fn = modules.WeightedEnergyForcesLoss( energy_weight=args.energy_weight, forces_weight=args.forces_weight @@ -666,6 +665,7 @@ def run(args: argparse.Namespace) -> None: if opt_start_epoch is not None: start_epoch = opt_start_epoch + ema: Optional[ExponentialMovingAverage] = None if args.ema: ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay) @@ -726,7 +726,8 @@ def run(args: argparse.Namespace) -> None: train_sampler=train_sampler, rank=rank, ) - + logging.info("") + logging.info("===========RESULTS===========") logging.info("Computing metrics for training, validation, and test sets") all_data_loaders = { diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index f1aa76e6..f2a61e13 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -57,7 +57,7 @@ def get_dataset_from_xyz( keep_isolated_atoms=keep_isolated_atoms, ) logging.info( - f"Loaded {len(all_train_configs)} training configurations from '{train_path}'" + f"Loaded {len(all_train_configs)} training configurations [{np.sum([1 if config.energy else 0 for config in all_train_configs])} energy, {np.sum([config.forces.size for config in all_train_configs])} forces] from '{train_path}'" ) if valid_path is not None: _, valid_configs = data.load_from_xyz( @@ -72,16 +72,16 @@ def get_dataset_from_xyz( extract_atomic_energies=False, ) logging.info( - f"Loaded {len(valid_configs)} validation configurations from '{valid_path}'" + f"Loaded {len(valid_configs)} validation configurations [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces] from '{valid_path}'" ) train_configs = all_train_configs else: - logging.info( - "Using random %s%% of training set for validation", 100 * valid_fraction - ) train_configs, valid_configs = data.random_train_valid_split( all_train_configs, valid_fraction, seed ) + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation [{len(valid_configs)} configurations, {np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" + ) test_configs = [] if test_path is not None: @@ -99,8 +99,12 @@ def get_dataset_from_xyz( # create list of tuples (config_type, list(Atoms)) test_configs = data.test_config_types(all_test_configs) logging.info( - f"Loaded {len(all_test_configs)} test configurations from '{test_path}'" + f"Loaded {len(all_test_configs)} test configurations from '{test_path}':" ) + logging.info( + f"{'; '.join([f'{name}: {len(test_configs)} configs, {np.sum([1 if config.energy else 0 for config in test_configs])} energy, {np.sum([config.forces.size for config in test_configs])} forces' for name, test_configs in test_configs])}" + ) + return ( SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs), atomic_energies_dict, diff --git a/mace/tools/train.py b/mace/tools/train.py index 74a5645b..5b74ecf1 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -141,8 +141,12 @@ def train( if log_wandb: import wandb + if max_grad_norm is not None: logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}") + + logging.info("") + logging.info("===========TRAINING===========") logging.info("Started training") epoch = start_epoch From 1824dde0f0ceb9bf6a55107c11b9e5dcb720707f Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Tue, 30 Jul 2024 19:14:47 +0100 Subject: [PATCH 06/42] Set default eval_interval=1 to print loss at every epoch, rephrase print at epoch None --- mace/cli/preprocess_data.py | 2 +- mace/tools/arg_parser.py | 2 +- mace/tools/train.py | 9 +++++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index 5c198ec4..7aa11d94 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -211,7 +211,7 @@ def run(args: argparse.Namespace): atomic_energies: np.ndarray = np.array( [atomic_energies_dict[z] for z in z_table.zs] ) - logging.info(f"Atomic energies: {atomic_energies.tolist()}") + logging.info(f"Atomic Energies: {atomic_energies.tolist()}") _inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process] avg_num_neighbors, mean, std=pool_compute_stats(_inputs) logging.info(f"Average number of neighbors: {avg_num_neighbors}") diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 83796052..7277d2b4 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -534,7 +534,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: default=True, ) parser.add_argument( - "--eval_interval", help="evaluate model every epochs", type=int, default=2 + "--eval_interval", help="evaluate model every epochs", type=int, default=1 ) parser.add_argument( "--keep_checkpoints", diff --git a/mace/tools/train.py b/mace/tools/train.py index 5b74ecf1..cd04f010 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -45,15 +45,16 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): eval_metrics["mode"] = "eval" eval_metrics["epoch"] = epoch logger.log(eval_metrics) + logging_level=logging.INFO if epoch is None: - logging_level=logging.DEBUG + initial_phrase = "Initial loss on validation set:" else: - logging_level=logging.INFO + initial_phrase = f"Epoch {epoch}:" if log_errors == "PerAtomRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 logging.log(level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + msg=f"{initial_phrase} loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -147,7 +148,7 @@ def train( logging.info("") logging.info("===========TRAINING===========") - logging.info("Started training") + logging.info("Started training, reporting errors on validation set") epoch = start_epoch # # log validation loss before _any_ training From 1f4e7eac32f80f9131b22a182edcf0903e687bc8 Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Tue, 30 Jul 2024 22:31:01 +0100 Subject: [PATCH 07/42] Made Epoch None vs Epoch 0 message more concise --- mace/tools/train.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mace/tools/train.py b/mace/tools/train.py index cd04f010..bbe7b616 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -46,10 +46,8 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): eval_metrics["epoch"] = epoch logger.log(eval_metrics) logging_level=logging.INFO - if epoch is None: - initial_phrase = "Initial loss on validation set:" - else: - initial_phrase = f"Epoch {epoch}:" + initial_phrase = "Initial loss on validation set:" if epoch is None else f"Epoch {epoch}:" + if log_errors == "PerAtomRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 From 13af3f11f6e4a61fce55bdca9c9120fab02e94ad Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Wed, 31 Jul 2024 18:20:52 +0100 Subject: [PATCH 08/42] Change Atomic Numbers and Energies loggings --- mace/cli/run_train.py | 11 ++++++----- mace/tools/scripts_utils.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index c9b6c218..950e2c96 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -179,12 +179,11 @@ def run(args: argparse.Namespace) -> None: assert isinstance(zs_list, list) z_table = tools.get_atomic_number_table_from_zs(zs_list) # yapf: enable - logging.info(z_table) + logging.info(f"Atomic Numbers used: {z_table.zs}") if atomic_energies_dict is None or len(atomic_energies_dict) == 0: if args.E0s.lower() == "foundation": assert args.foundation_model is not None - logging.info("Using atomic energies from foundation model") z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) @@ -194,6 +193,7 @@ def run(args: argparse.Namespace) -> None: ].item() for z in z_table.zs } + logging.info(f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}") else: if args.train_file.endswith(".xyz"): atomic_energies_dict = get_atomic_energies( @@ -224,8 +224,8 @@ def run(args: argparse.Namespace) -> None: atomic_energies: np.ndarray = np.array( [atomic_energies_dict[z] for z in z_table.zs] ) - logging.info(f"Atomic energies: {atomic_energies.tolist()}") - + logging.info(f"Atomic Energies used [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table.zs])}") + if args.train_file.endswith(".xyz"): train_set = [ data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) @@ -368,7 +368,8 @@ def run(args: argparse.Namespace) -> None: "stress": args.compute_stress, "dipoles": compute_dipole, } - logging.info(f"Selected the following outputs: {output_args}") + + logging.info(f"Selected the following values to use and report: {[report for report, value in output_args.items() if value]}") if args.scaling == "no_scaling": args.std = 1.0 diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index f2a61e13..776c0909 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -291,7 +291,7 @@ def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module: def get_atomic_energies(E0s, train_collection, z_table) -> dict: if E0s is not None: logging.info( - "Atomic Energies not in training file, using command line argument E0s" + "Isolated Atomic Energies (E0s) not in training file, using command line argument" ) if E0s.lower() == "average": logging.info( From e5933ab4ef9cc1cbe3dc5caebb0ce09418477373 Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Wed, 31 Jul 2024 18:35:33 +0100 Subject: [PATCH 09/42] Change precision of average number of neighbours --- mace/cli/run_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 950e2c96..d385ff76 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -350,9 +350,9 @@ def run(args: argparse.Namespace) -> None: else: args.avg_num_neighbors = avg_num_neighbors if args.avg_num_neighbors < 2 or args.avg_num_neighbors > 100: - logging.warning(f"Unusual average number of neighbors: {int(args.avg_num_neighbors)}") + logging.warning(f"Unusual average number of neighbors: {args.avg_num_neighbors:.1f}") else: - logging.info(f"Average number of neighbors: {int(args.avg_num_neighbors)}") + logging.info(f"Average number of neighbors: {args.avg_num_neighbors:.1f}") # Selecting outputs compute_virials = False From 09b67e111ac9f0a465a9700c1d4a41430cd2a36d Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Wed, 31 Jul 2024 18:37:14 +0100 Subject: [PATCH 10/42] Change stress, dipole and charges default keys to REF_stress, REF_dipole and REF_charges --- mace/data/utils.py | 18 +++++++++--------- mace/tools/arg_parser.py | 18 +++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/mace/data/utils.py b/mace/data/utils.py index a8f7b8dc..66020d52 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -75,9 +75,9 @@ def config_from_atoms_list( energy_key="REF_energy", forces_key="REF_forces", stress_key="REF_stress", - virials_key="virials", - dipole_key="dipole", - charges_key="charges", + virials_key="REF_virials", + dipole_key="REF_dipole", + charges_key="REF_charges", config_type_weights: Dict[str, float] = None, ) -> Configurations: """Convert list of ase.Atoms into Configurations""" @@ -106,9 +106,9 @@ def config_from_atoms( energy_key="REF_energy", forces_key="REF_forces", stress_key="REF_stress", - virials_key="virials", - dipole_key="dipole", - charges_key="charges", + virials_key="REF_virials", + dipole_key="REF_dipole", + charges_key="REF_charges", config_type_weights: Dict[str, float] = None, ) -> Configuration: """Convert ase.Atoms to Configuration""" @@ -195,9 +195,9 @@ def load_from_xyz( energy_key: str = "REF_energy", forces_key: str = "REF_forces", stress_key: str = "REF_stress", - virials_key: str = "virials", - dipole_key: str = "dipole", - charges_key: str = "charges", + virials_key: str = "REF_virials", + dipole_key: str = "REF_dipole", + charges_key: str = "REF_charges", extract_atomic_energies: bool = False, keep_isolated_atoms: bool = False, ) -> Tuple[Dict[int, float], Configurations]: diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 7277d2b4..71d9036e 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -346,7 +346,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--virials_key", help="Key of reference virials in training xyz", type=str, - default="virials", + default="REF_virials", ) parser.add_argument( "--stress_key", @@ -358,13 +358,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--dipole_key", help="Key of reference dipoles in training xyz", type=str, - default="dipole", + default="REF_dipole", ) parser.add_argument( "--charges_key", help="Key of atomic charges in training xyz", type=str, - default="charges", + default="REF_charges", ) # Loss and optimization @@ -674,37 +674,37 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser: "--energy_key", help="Key of reference energies in training xyz", type=str, - default="energy", + default="REF_energy", ) parser.add_argument( "--forces_key", help="Key of reference forces in training xyz", type=str, - default="forces", + default="REF_forces", ) parser.add_argument( "--virials_key", help="Key of reference virials in training xyz", type=str, - default="virials", + default="REF_virials", ) parser.add_argument( "--stress_key", help="Key of reference stress in training xyz", type=str, - default="stress", + default="REF_stress", ) parser.add_argument( "--dipole_key", help="Key of reference dipoles in training xyz", type=str, - default="dipole", + default="REF_dipole", ) parser.add_argument( "--charges_key", help="Key of atomic charges in training xyz", type=str, - default="charges", + default="REF_charges", ) parser.add_argument( "--atomic_numbers", From 21444fcfc0a10353aeeae65232e55d71bd9a1871 Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Thu, 15 Aug 2024 16:11:02 +0200 Subject: [PATCH 11/42] Created check_args() to allow checking the input flag values' sensibility, flag any inconsistency and fix them, such as using all three of hidden_irreps, num_channels and max_L --- mace/cli/run_train.py | 18 +++++++++++++++--- mace/tools/__init__.py | 3 ++- mace/tools/arg_parser.py | 17 ++++++++++++++++- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index d385ff76..94f62d3f 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -55,6 +55,7 @@ def run(args: argparse.Namespace) -> None: """ This script runs the training/fine tuning for mace """ + args, input_log_messages = tools.check_args(args) tag = tools.get_tag(name=args.name, seed=args.seed) if args.distributed: try: @@ -74,6 +75,16 @@ def run(args: argparse.Namespace) -> None: # Setup tools.set_seeds(args.seed) tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank) + logging.info("===========CHECKING SETTINGS===========") + for message, level in input_log_messages: + if level == "debug": + logging.debug(message) + elif level == "warning": + logging.warning(message) + elif level == "error": + logging.error(message) + else: + logging.info(message) if args.distributed: torch.cuda.set_device(local_rank) @@ -380,7 +391,7 @@ def run(args: argparse.Namespace) -> None: ) # Build model if args.foundation_model is not None and args.model in ["MACE", "ScaleShiftMACE"]: - logging.info("Building model") + logging.debug("Building model") model_config_foundation = extract_config_mace_model(model_foundation) model_config_foundation["atomic_numbers"] = z_table.zs model_config_foundation["num_elements"] = len(z_table) @@ -395,7 +406,7 @@ def run(args: argparse.Namespace) -> None: args.model = "FoundationMACE" model_config = model_config_foundation # pylint else: - logging.info("Building model") + logging.debug("Building model") if args.num_channels is not None and args.max_L is not None: assert args.num_channels > 0, "num_channels must be positive integer" assert args.max_L >= 0, "max_L must be non-negative integer" @@ -409,7 +420,8 @@ def run(args: argparse.Namespace) -> None: len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - logging.debug(f"Hidden irreps: {args.hidden_irreps}") + logging.info(f"Hidden irreps: {args.hidden_irreps} (Number of channel: {args.num_channels}, max_L: {args.max_L})") + model_config = dict( r_max=args.r_max, diff --git a/mace/tools/__init__.py b/mace/tools/__init__.py index 80375590..a1b10cb7 100644 --- a/mace/tools/__init__.py +++ b/mace/tools/__init__.py @@ -1,4 +1,4 @@ -from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser +from .arg_parser import build_default_arg_parser, check_args, build_preprocess_arg_parser from .cg import U_matrix_real from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState from .finetuning_utils import load_foundations @@ -39,6 +39,7 @@ "to_numpy", "to_one_hot", "build_default_arg_parser", + "check_args", "set_seeds", "init_device", "setup_logger", diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 71d9036e..4853fc70 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -7,8 +7,23 @@ import argparse import os from typing import Optional +from e3nn import o3 +def check_args(args): + """ + Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing + the (potentially) modified args and a list of log messages. + """ + log_messages = [] + # Check if hidden_irreps, num_channels and max_L are consistent + if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: + args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 + elif args.hidden_irreps is not None and args.num_channels is not None and args.max_L is not None: + args.hidden_irreps = o3.Irreps((args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)).sort().irreps.simplify()) + log_messages.append((f"Both hidden_irreps, num_channels and max_L are specified. Using num_channels and max_L to create hidden irreps: {args.hidden_irreps}.","info")) + return args, log_messages + def build_default_arg_parser() -> argparse.ArgumentParser: try: import configargparse @@ -183,7 +198,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--hidden_irreps", help="irreps for hidden node states", type=str, - default="128x0e + 128x1o", + default=None, ) # add option to specify irreps by channel number and max L parser.add_argument( From bb7fd9ffa4d620f7c1db3428a0b0fdae8146b8dc Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Thu, 15 Aug 2024 18:42:42 +0200 Subject: [PATCH 12/42] Creating work_dir to allow simultaneous setting of all directories. If any of the subdirectories is specified, that will overwrite the place of that dir. --- mace/tools/arg_parser.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 4853fc70..5aaf048e 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -22,6 +22,19 @@ def check_args(args): args.hidden_irreps = o3.Irreps((args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)).sort().irreps.simplify()) log_messages.append((f"Both hidden_irreps, num_channels and max_L are specified. Using num_channels and max_L to create hidden irreps: {args.hidden_irreps}.","info")) + # Use work_dir for all other directories as well, unless they were specified by the user + if args.work_dir != ".": + if args.log_dir ==None: + args.log_dir = os.path.join(args.work_dir, "logs") + if args.model_dir == None: + args.model_dir = args.work_dir + if args.checkpoints_dir == None : + args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") + if args.results_dir == None: + args.results_dir = os.path.join(args.work_dir, "results") + if args.downloads_dir == None: + args.downloads_dir = os.path.join(args.work_dir, "downloads") + return args, log_messages def build_default_arg_parser() -> argparse.ArgumentParser: @@ -46,22 +59,22 @@ def build_default_arg_parser() -> argparse.ArgumentParser: # Directories parser.add_argument( - "--log_dir", help="directory for log files", type=str, default="logs" + "--work_dir", help="set directory for all files and folders", type=str, default="." ) parser.add_argument( - "--model_dir", help="directory for final model", type=str, default="." + "--log_dir", help="directory for log files", type=str, default=None ) parser.add_argument( - "--checkpoints_dir", - help="directory for checkpoint files", - type=str, - default="checkpoints", + "--model_dir", help="directory for final model", type=str, default=None + ) + parser.add_argument( + "--checkpoints_dir", help="directory for checkpoint files", type=str, default=None ) parser.add_argument( - "--results_dir", help="directory for results", type=str, default="results" + "--results_dir", help="directory for results", type=str, default=None ) parser.add_argument( - "--downloads_dir", help="directory for downloads", type=str, default="downloads" + "--downloads_dir", help="directory for downloads", type=str, default=None ) # Device and logging From b7a648275ac0f4504f2789e9e771544c8dff2103 Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Thu, 15 Aug 2024 19:06:38 +0200 Subject: [PATCH 13/42] Correct dir_* paths assignment, set all dirs to correspond to work_dir --- mace/tools/arg_parser.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 5aaf048e..556fc719 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -23,17 +23,16 @@ def check_args(args): log_messages.append((f"Both hidden_irreps, num_channels and max_L are specified. Using num_channels and max_L to create hidden irreps: {args.hidden_irreps}.","info")) # Use work_dir for all other directories as well, unless they were specified by the user - if args.work_dir != ".": - if args.log_dir ==None: - args.log_dir = os.path.join(args.work_dir, "logs") - if args.model_dir == None: - args.model_dir = args.work_dir - if args.checkpoints_dir == None : - args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") - if args.results_dir == None: - args.results_dir = os.path.join(args.work_dir, "results") - if args.downloads_dir == None: - args.downloads_dir = os.path.join(args.work_dir, "downloads") + if args.log_dir ==None: + args.log_dir = os.path.join(args.work_dir, "logs") + if args.model_dir == None: + args.model_dir = args.work_dir + if args.checkpoints_dir == None : + args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") + if args.results_dir == None: + args.results_dir = os.path.join(args.work_dir, "results") + if args.downloads_dir == None: + args.downloads_dir = os.path.join(args.work_dir, "downloads") return args, log_messages From 3386b1e5fc015ef52260b98c05cb3873d89ce8c2 Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Thu, 15 Aug 2024 19:28:35 +0200 Subject: [PATCH 14/42] Warn and change batch size value(s) when larger than the number of training/validation data --- mace/cli/run_train.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 94f62d3f..1c970798 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -170,6 +170,15 @@ def run(args: argparse.Namespace) -> None: else: atomic_energies_dict = None + + if len(collections.train) Date: Thu, 15 Aug 2024 22:01:32 +0200 Subject: [PATCH 15/42] Added details about model and optimizer settings --- mace/cli/run_train.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 1c970798..5f96732e 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -696,9 +696,21 @@ def run(args: argparse.Namespace) -> None: group["lr"] = args.lr logging.debug(model) - logging.info(f"Number of parameters: {tools.count_parameters(model)}") - logging.debug(f"Optimizer: {optimizer}") + logging.info(f"Total number of parameters: {tools.count_parameters(model)}") + logging.info(f"Batch size: {args.batch_size}, validation batch size: {args.valid_batch_size}") + logging.info(f"Number of gradient updates: {args.max_num_epochs*len(collections.train)/args.batch_size}") + logging.info(f"Radial cutoff: {args.r_max}, num_radial_basis: {args.num_radial_basis}, num_cutoff_basis: {args.num_cutoff_basis}") + logging.info(f"Polynomial cutoff: {args.num_cutoff_basis}, max_L: {args.max_L}, num_interactions: {args.num_interactions}") + logging.info(f"Correlation: {args.correlation}, distance transform: {args.distance_transform}") + logging.info("") + logging.info("===========OPTIMIZER INFORMATION===========") + logging.info(f"Optimizer for parameter optimization: {args.optimizer.upper()}") + logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") + logging.debug( + f"{'\n '.join([f'{group["name"]}: learning rate: {group["lr"]} and weight decay: {group["weight_decay"]}' for group in optimizer.param_groups])}" + ) + if args.wandb: logging.info("Using Weights and Biases for logging") import wandb From 0dc98b2c04cc82e0af3dedcc49fc8f12885132d6 Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Fri, 16 Aug 2024 11:29:29 +0200 Subject: [PATCH 16/42] Check start_stage_two if larger than max_num_epoch and won't start stage two if true. Rearranged check_args to match structuring of arg parser. --- mace/cli/run_train.py | 7 ------- mace/tools/arg_parser.py | 25 ++++++++++++++++++------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 90bec3c2..e46eab51 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -616,13 +616,6 @@ def run(args: argparse.Namespace) -> None: swas.append(True) if args.start_swa is None: args.start_swa = max(1, args.max_num_epochs // 4 * 3) - else: - if args.start_swa > args.max_num_epochs: - logging.info( - f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}" - ) - args.start_swa = max(1, args.max_num_epochs // 4 * 3) - logging.info(f"Setting start Stage Two to {args.start_swa}") if args.loss == "forces_only": raise ValueError("Can not select Stage Two with forces only loss.") if args.loss == "virials": diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 88f3991b..977c048a 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -15,13 +15,8 @@ def check_args(args): the (potentially) modified args and a list of log messages. """ log_messages = [] - # Check if hidden_irreps, num_channels and max_L are consistent - if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: - args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 - elif args.hidden_irreps is not None and args.num_channels is not None and args.max_L is not None: - args.hidden_irreps = o3.Irreps((args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)).sort().irreps.simplify()) - log_messages.append((f"Both hidden_irreps, num_channels and max_L are specified. Using num_channels and max_L to create hidden irreps: {args.hidden_irreps}.","info")) - + + # Directories # Use work_dir for all other directories as well, unless they were specified by the user if args.log_dir ==None: args.log_dir = os.path.join(args.work_dir, "logs") @@ -33,6 +28,22 @@ def check_args(args): args.results_dir = os.path.join(args.work_dir, "results") if args.downloads_dir == None: args.downloads_dir = os.path.join(args.work_dir, "downloads") + + # Model + # Check if hidden_irreps, num_channels and max_L are consistent + if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: + args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 + elif args.hidden_irreps is not None and args.num_channels is not None and args.max_L is not None: + args.hidden_irreps = o3.Irreps((args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)).sort().irreps.simplify()) + log_messages.append((f"Both hidden_irreps, num_channels and max_L are specified.","info")) + log_messages.append((f"Using num_channels and max_L to create hidden irreps: {args.hidden_irreps}.","warning")) + + # Loss and optimization + # Check Stage Two loss start + if args.start_swa > args.max_num_epochs: + log_messages.append(( f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", "info")) + log_messages.append(( f"Stage Two will not start", "warning")) + args.swa = None return args, log_messages From 21ee011bae42bfef12ed5617d4c7d1984038ccda Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Fri, 16 Aug 2024 12:49:41 +0200 Subject: [PATCH 17/42] Separate test and train error-table --- mace/cli/run_train.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index e46eab51..cd65a8dc 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -812,6 +812,9 @@ def run(args: argparse.Namespace) -> None: ) all_data_loaders[test_name] = test_loader + train_valid_data_loader = {k: v for k, v in all_data_loaders.items() if k in ["train", "valid"]} + test_data_loader = {k: v for k, v in all_data_loaders.items() if k not in ["train", "valid"]} + for swa_eval in swas: epoch = checkpoint_handler.load_latest( state=tools.CheckpointState(model, optimizer, lr_scheduler), @@ -822,13 +825,27 @@ def run(args: argparse.Namespace) -> None: if args.distributed: distributed_model = DDP(model, device_ids=[local_rank]) model_to_evaluate = model if not args.distributed else distributed_model - logging.info(f"Loaded model from epoch {epoch}") + if swa_eval: + logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation") + else: + logging.info(f"Loaded model from epoch {epoch} for evaluation") for param in model.parameters(): param.requires_grad = False - table = create_error_table( + + table_train = create_error_table( + table_type=args.error_table, + all_data_loaders=train_valid_data_loader, + model=model_to_evaluate, + loss_fn=loss_fn, + output_args=output_args, + log_wandb=args.wandb, + device=device, + distributed=args.distributed, + ) + table_test = create_error_table( table_type=args.error_table, - all_data_loaders=all_data_loaders, + all_data_loaders=test_data_loader, model=model_to_evaluate, loss_fn=loss_fn, output_args=output_args, @@ -836,7 +853,9 @@ def run(args: argparse.Namespace) -> None: device=device, distributed=args.distributed, ) - logging.info("\n" + str(table)) + logging.info("Error-table on TRAIN and VALID:\n" + str(table_train)) + logging.info("Error-table on TEST:\n" + str(table_test)) + if rank == 0: # Save entire model From 8ecfa97dc99fca0bfc3a970ecadc2915ffe9a381 Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Fri, 16 Aug 2024 12:58:54 +0200 Subject: [PATCH 18/42] deleted f-string expression backslash --- mace/cli/run_train.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index cd65a8dc..8dc02699 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -703,9 +703,6 @@ def run(args: argparse.Namespace) -> None: logging.info("===========OPTIMIZER INFORMATION===========") logging.info(f"Optimizer for parameter optimization: {args.optimizer.upper()}") logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") - logging.debug( - f"{'\n '.join([f'{group["name"]}: learning rate: {group["lr"]} and weight decay: {group["weight_decay"]}' for group in optimizer.param_groups])}" - ) if args.wandb: logging.info("Using Weights and Biases for logging") From 7bcf55e67da97b305bdadda1e4699d6618ee4031 Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Fri, 16 Aug 2024 13:23:36 +0200 Subject: [PATCH 19/42] Moved batch size checks compared to data size right after data set is loaded --- mace/cli/run_train.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 8dc02699..168b18e9 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -167,18 +167,18 @@ def run(args: argparse.Namespace) -> None: charges_key=args.charges_key, keep_isolated_atoms=args.keep_isolated_atoms, ) + if len(collections.train) Date: Fri, 16 Aug 2024 14:01:41 +0200 Subject: [PATCH 20/42] Revert "Moved batch size checks compared to data size right after data set is loaded" This reverts commit 7bcf55e67da97b305bdadda1e4699d6618ee4031. --- mace/cli/run_train.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 168b18e9..8dc02699 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -167,18 +167,18 @@ def run(args: argparse.Namespace) -> None: charges_key=args.charges_key, keep_isolated_atoms=args.keep_isolated_atoms, ) - if len(collections.train) Date: Fri, 16 Aug 2024 14:01:55 +0200 Subject: [PATCH 21/42] Revert "deleted f-string expression backslash" This reverts commit 8ecfa97dc99fca0bfc3a970ecadc2915ffe9a381. --- mace/cli/run_train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 8dc02699..cd65a8dc 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -703,6 +703,9 @@ def run(args: argparse.Namespace) -> None: logging.info("===========OPTIMIZER INFORMATION===========") logging.info(f"Optimizer for parameter optimization: {args.optimizer.upper()}") logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") + logging.debug( + f"{'\n '.join([f'{group["name"]}: learning rate: {group["lr"]} and weight decay: {group["weight_decay"]}' for group in optimizer.param_groups])}" + ) if args.wandb: logging.info("Using Weights and Biases for logging") From 54c5308f833f89f7b88e21b7fa9c72e9d6981f7e Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Fri, 16 Aug 2024 14:02:05 +0200 Subject: [PATCH 22/42] Revert "Separate test and train error-table" This reverts commit 21ee011bae42bfef12ed5617d4c7d1984038ccda. --- mace/cli/run_train.py | 27 ++++----------------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index cd65a8dc..e46eab51 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -812,9 +812,6 @@ def run(args: argparse.Namespace) -> None: ) all_data_loaders[test_name] = test_loader - train_valid_data_loader = {k: v for k, v in all_data_loaders.items() if k in ["train", "valid"]} - test_data_loader = {k: v for k, v in all_data_loaders.items() if k not in ["train", "valid"]} - for swa_eval in swas: epoch = checkpoint_handler.load_latest( state=tools.CheckpointState(model, optimizer, lr_scheduler), @@ -825,27 +822,13 @@ def run(args: argparse.Namespace) -> None: if args.distributed: distributed_model = DDP(model, device_ids=[local_rank]) model_to_evaluate = model if not args.distributed else distributed_model - if swa_eval: - logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation") - else: - logging.info(f"Loaded model from epoch {epoch} for evaluation") + logging.info(f"Loaded model from epoch {epoch}") for param in model.parameters(): param.requires_grad = False - - table_train = create_error_table( - table_type=args.error_table, - all_data_loaders=train_valid_data_loader, - model=model_to_evaluate, - loss_fn=loss_fn, - output_args=output_args, - log_wandb=args.wandb, - device=device, - distributed=args.distributed, - ) - table_test = create_error_table( + table = create_error_table( table_type=args.error_table, - all_data_loaders=test_data_loader, + all_data_loaders=all_data_loaders, model=model_to_evaluate, loss_fn=loss_fn, output_args=output_args, @@ -853,9 +836,7 @@ def run(args: argparse.Namespace) -> None: device=device, distributed=args.distributed, ) - logging.info("Error-table on TRAIN and VALID:\n" + str(table_train)) - logging.info("Error-table on TEST:\n" + str(table_test)) - + logging.info("\n" + str(table)) if rank == 0: # Save entire model From 78d2fcd3c4c8573ecbaba51bc245ce242434dbfd Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Fri, 16 Aug 2024 17:53:55 +0200 Subject: [PATCH 23/42] deleted f-string expression backslash --- mace/cli/run_train.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index e46eab51..2fe062aa 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -703,9 +703,6 @@ def run(args: argparse.Namespace) -> None: logging.info("===========OPTIMIZER INFORMATION===========") logging.info(f"Optimizer for parameter optimization: {args.optimizer.upper()}") logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") - logging.debug( - f"{'\n '.join([f'{group["name"]}: learning rate: {group["lr"]} and weight decay: {group["weight_decay"]}' for group in optimizer.param_groups])}" - ) if args.wandb: logging.info("Using Weights and Biases for logging") From 7511038528f096865a93fb7f7f7a6d98fc06b464 Mon Sep 17 00:00:00 2001 From: Eszter Varga-Umbrich Date: Fri, 16 Aug 2024 18:27:55 +0200 Subject: [PATCH 24/42] Apply pre-commit hook changes and fix linting issues --- mace/cli/run_train.py | 101 ++++++++++++++++++++++++++++----------- mace/tools/__init__.py | 6 ++- mace/tools/arg_parser.py | 67 ++++++++++++++++++-------- mace/tools/train.py | 49 +++++++++++-------- 4 files changed, 154 insertions(+), 69 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 2fe062aa..7dc6285c 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -142,7 +142,7 @@ def run(args: argparse.Namespace) -> None: args.avg_num_neighbors = statistics["avg_num_neighbors"] args.compute_avg_num_neighbors = False args.E0s = statistics["atomic_energies"] - + logging.info("") logging.info("===========LOADING INPUT DATA===========") # Data preparation @@ -167,18 +167,21 @@ def run(args: argparse.Namespace) -> None: charges_key=args.charges_key, keep_isolated_atoms=args.keep_isolated_atoms, ) - + if len(collections.train) < args.batch_size: + logging.warning( + f"Batch size ({args.batch_size}) is larger than the number of training data ({len(collections.train)})" + ) + args.batch_size = int(len(collections.train) * 0.1) + logging.warning(f"Batch size changed to {args.batch_size}") + if len(collections.train) < len(collections.valid): + logging.warning( + f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({len(collections.valid)})" + ) + args.valid_batch_size = int(len(collections.valid) * 0.1) + logging.warning(f"Validation batch size changed to {args.valid_batch_size}") + else: atomic_energies_dict = None - - if len(collections.train) None: ].item() for z in z_table.zs } - logging.info(f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}") + logging.info( + f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}" + ) else: if args.train_file.endswith(".xyz"): atomic_energies_dict = get_atomic_energies( @@ -244,8 +249,10 @@ def run(args: argparse.Namespace) -> None: atomic_energies: np.ndarray = np.array( [atomic_energies_dict[z] for z in z_table.zs] ) - logging.info(f"Atomic Energies used [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table.zs])}") - + logging.info( + f"Atomic Energies used [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table.zs])}" + ) + if args.train_file.endswith(".xyz"): train_set = [ data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) @@ -370,7 +377,9 @@ def run(args: argparse.Namespace) -> None: else: args.avg_num_neighbors = avg_num_neighbors if args.avg_num_neighbors < 2 or args.avg_num_neighbors > 100: - logging.warning(f"Unusual average number of neighbors: {args.avg_num_neighbors:.1f}") + logging.warning( + f"Unusual average number of neighbors: {args.avg_num_neighbors:.1f}" + ) else: logging.info(f"Average number of neighbors: {args.avg_num_neighbors:.1f}") @@ -392,7 +401,9 @@ def run(args: argparse.Namespace) -> None: "dipoles": compute_dipole, } - logging.info(f"Selected the following values to use and report: {[report for report, value in output_args.items() if value]}") + logging.info( + f"Selected the following values to use and report: {[report for report, value in output_args.items() if value]}" + ) if args.scaling == "no_scaling": args.std = 1.0 @@ -432,8 +443,9 @@ def run(args: argparse.Namespace) -> None: len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - logging.info(f"Hidden irreps: {args.hidden_irreps} (Number of channel: {args.num_channels}, max_L: {args.max_L})") - + logging.info( + f"Hidden irreps: {args.hidden_irreps} (Number of channel: {args.num_channels}, max_L: {args.max_L})" + ) model_config = dict( r_max=args.r_max, @@ -683,7 +695,6 @@ def run(args: argparse.Namespace) -> None: if opt_start_epoch is not None: start_epoch = opt_start_epoch - ema: Optional[ExponentialMovingAverage] = None if args.ema: ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay) @@ -693,17 +704,27 @@ def run(args: argparse.Namespace) -> None: logging.debug(model) logging.info(f"Total number of parameters: {tools.count_parameters(model)}") - logging.info(f"Batch size: {args.batch_size}, validation batch size: {args.valid_batch_size}") - logging.info(f"Number of gradient updates: {args.max_num_epochs*len(collections.train)/args.batch_size}") - logging.info(f"Radial cutoff: {args.r_max}, num_radial_basis: {args.num_radial_basis}, num_cutoff_basis: {args.num_cutoff_basis}") - logging.info(f"Polynomial cutoff: {args.num_cutoff_basis}, max_L: {args.max_L}, num_interactions: {args.num_interactions}") - logging.info(f"Correlation: {args.correlation}, distance transform: {args.distance_transform}") + logging.info( + f"Batch size: {args.batch_size}, validation batch size: {args.valid_batch_size}" + ) + logging.info( + f"Number of gradient updates: {args.max_num_epochs*len(collections.train)/args.batch_size}" + ) + logging.info( + f"Radial cutoff: {args.r_max}, num_radial_basis: {args.num_radial_basis}, num_cutoff_basis: {args.num_cutoff_basis}" + ) + logging.info( + f"Polynomial cutoff: {args.num_cutoff_basis}, max_L: {args.max_L}, num_interactions: {args.num_interactions}" + ) + logging.info( + f"Correlation: {args.correlation}, distance transform: {args.distance_transform}" + ) logging.info("") logging.info("===========OPTIMIZER INFORMATION===========") logging.info(f"Optimizer for parameter optimization: {args.optimizer.upper()}") logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") - + if args.wandb: logging.info("Using Weights and Biases for logging") import wandb @@ -809,6 +830,13 @@ def run(args: argparse.Namespace) -> None: ) all_data_loaders[test_name] = test_loader + train_valid_data_loader = { + k: v for k, v in all_data_loaders.items() if k in ["train", "valid"] + } + test_data_loader = { + k: v for k, v in all_data_loaders.items() if k not in ["train", "valid"] + } + for swa_eval in swas: epoch = checkpoint_handler.load_latest( state=tools.CheckpointState(model, optimizer, lr_scheduler), @@ -819,13 +847,27 @@ def run(args: argparse.Namespace) -> None: if args.distributed: distributed_model = DDP(model, device_ids=[local_rank]) model_to_evaluate = model if not args.distributed else distributed_model - logging.info(f"Loaded model from epoch {epoch}") + if swa_eval: + logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation") + else: + logging.info(f"Loaded model from epoch {epoch} for evaluation") for param in model.parameters(): param.requires_grad = False - table = create_error_table( + + table_train = create_error_table( + table_type=args.error_table, + all_data_loaders=train_valid_data_loader, + model=model_to_evaluate, + loss_fn=loss_fn, + output_args=output_args, + log_wandb=args.wandb, + device=device, + distributed=args.distributed, + ) + table_test = create_error_table( table_type=args.error_table, - all_data_loaders=all_data_loaders, + all_data_loaders=test_data_loader, model=model_to_evaluate, loss_fn=loss_fn, output_args=output_args, @@ -833,7 +875,8 @@ def run(args: argparse.Namespace) -> None: device=device, distributed=args.distributed, ) - logging.info("\n" + str(table)) + logging.info("Error-table on TRAIN and VALID:\n" + str(table_train)) + logging.info("Error-table on TEST:\n" + str(table_test)) if rank == 0: # Save entire model diff --git a/mace/tools/__init__.py b/mace/tools/__init__.py index a1b10cb7..1234cf45 100644 --- a/mace/tools/__init__.py +++ b/mace/tools/__init__.py @@ -1,4 +1,8 @@ -from .arg_parser import build_default_arg_parser, check_args, build_preprocess_arg_parser +from .arg_parser import ( + build_default_arg_parser, + build_preprocess_arg_parser, + check_args, +) from .cg import U_matrix_real from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState from .finetuning_utils import load_foundations diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 977c048a..56f4aeeb 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -7,46 +7,69 @@ import argparse import os from typing import Optional + from e3nn import o3 + def check_args(args): """ Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing the (potentially) modified args and a list of log messages. """ log_messages = [] - + # Directories # Use work_dir for all other directories as well, unless they were specified by the user - if args.log_dir ==None: - args.log_dir = os.path.join(args.work_dir, "logs") - if args.model_dir == None: + if args.log_dir is None: + args.log_dir = os.path.join(args.work_dir, "logs") + if args.model_dir is None: args.model_dir = args.work_dir - if args.checkpoints_dir == None : + if args.checkpoints_dir is None: args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") - if args.results_dir == None: - args.results_dir = os.path.join(args.work_dir, "results") - if args.downloads_dir == None: + if args.results_dir is None: + args.results_dir = os.path.join(args.work_dir, "results") + if args.downloads_dir is None: args.downloads_dir = os.path.join(args.work_dir, "downloads") # Model # Check if hidden_irreps, num_channels and max_L are consistent if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: - args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 - elif args.hidden_irreps is not None and args.num_channels is not None and args.max_L is not None: - args.hidden_irreps = o3.Irreps((args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)).sort().irreps.simplify()) - log_messages.append((f"Both hidden_irreps, num_channels and max_L are specified.","info")) - log_messages.append((f"Using num_channels and max_L to create hidden irreps: {args.hidden_irreps}.","warning")) + args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 + elif ( + args.hidden_irreps is not None + and args.num_channels is not None + and args.max_L is not None + ): + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + log_messages.append( + ("Both hidden_irreps, num_channels and max_L are specified", "info") + ) + log_messages.append( + ( + f"Using num_channels and max_L to create hidden irreps: {args.hidden_irreps}.", + "warning", + ) + ) # Loss and optimization # Check Stage Two loss start if args.start_swa > args.max_num_epochs: - log_messages.append(( f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", "info")) - log_messages.append(( f"Stage Two will not start", "warning")) + log_messages.append( + ( + f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", + "info", + ) + ) + log_messages.append(("Stage Two will not start", "warning")) args.swa = None - + return args, log_messages - + + def build_default_arg_parser() -> argparse.ArgumentParser: try: import configargparse @@ -69,7 +92,10 @@ def build_default_arg_parser() -> argparse.ArgumentParser: # Directories parser.add_argument( - "--work_dir", help="set directory for all files and folders", type=str, default="." + "--work_dir", + help="set directory for all files and folders", + type=str, + default=".", ) parser.add_argument( "--log_dir", help="directory for log files", type=str, default=None @@ -78,7 +104,10 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--model_dir", help="directory for final model", type=str, default=None ) parser.add_argument( - "--checkpoints_dir", help="directory for checkpoint files", type=str, default=None + "--checkpoints_dir", + help="directory for checkpoint files", + type=str, + default=None, ) parser.add_argument( "--results_dir", help="directory for results", type=str, default=None diff --git a/mace/tools/train.py b/mace/tools/train.py index c7c17ff2..42e2151b 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -45,14 +45,17 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): eval_metrics["mode"] = "eval" eval_metrics["epoch"] = epoch logger.log(eval_metrics) - logging_level=logging.INFO - initial_phrase = "Initial loss on validation set:" if epoch is None else f"Epoch {epoch}:" - + logging_level = logging.INFO + initial_phrase = ( + "Initial loss on validation set:" if epoch is None else f"Epoch {epoch}:" + ) + if log_errors == "PerAtomRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 - logging.log(level=logging_level, - msg=f"{initial_phrase} loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + logging.log( + level=logging_level, + msg=f"{initial_phrase} loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A", ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -61,8 +64,9 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3 - logging.log(level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3" + logging.log( + level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3", ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -71,8 +75,9 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 - logging.log(level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV" + logging.log( + level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV", ) elif ( log_errors == "PerAtomMAEstressvirials" @@ -97,32 +102,37 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): elif log_errors == "TotalRMSE": error_e = eval_metrics["rmse_e"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 - logging.log(level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + logging.log( + level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A", ) elif log_errors == "PerAtomMAE": error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 - logging.log(level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" + logging.log( + level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A", ) elif log_errors == "TotalMAE": error_e = eval_metrics["mae_e"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 - logging.log(level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" + logging.log( + level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A", ) elif log_errors == "DipoleRMSE": error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.log(level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye" + logging.log( + level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye", ) elif log_errors == "EnergyDipoleRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.log(level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye" + logging.log( + level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye", ) @@ -160,7 +170,6 @@ def train( if log_wandb: import wandb - if max_grad_norm is not None: logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}") From 651493f9c6ccdd5a8b2e87b5e118b6be1340a316 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Sun, 18 Aug 2024 21:39:35 +0200 Subject: [PATCH 25/42] Revert "Set default eval_interval=1 to print loss at every epoch, rephrase print at epoch None" This reverts commit 1824dde0f0ceb9bf6a55107c11b9e5dcb720707f. --- mace/cli/preprocess_data.py | 2 +- mace/tools/arg_parser.py | 2 +- mace/tools/train.py | 16 +++++++--------- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index 7aa11d94..5c198ec4 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -211,7 +211,7 @@ def run(args: argparse.Namespace): atomic_energies: np.ndarray = np.array( [atomic_energies_dict[z] for z in z_table.zs] ) - logging.info(f"Atomic Energies: {atomic_energies.tolist()}") + logging.info(f"Atomic energies: {atomic_energies.tolist()}") _inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process] avg_num_neighbors, mean, std=pool_compute_stats(_inputs) logging.info(f"Average number of neighbors: {avg_num_neighbors}") diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 56f4aeeb..785297c6 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -621,7 +621,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: default=True, ) parser.add_argument( - "--eval_interval", help="evaluate model every epochs", type=int, default=1 + "--eval_interval", help="evaluate model every epochs", type=int, default=2 ) parser.add_argument( "--keep_checkpoints", diff --git a/mace/tools/train.py b/mace/tools/train.py index 42e2151b..0e4cfdeb 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -45,17 +45,15 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): eval_metrics["mode"] = "eval" eval_metrics["epoch"] = epoch logger.log(eval_metrics) - logging_level = logging.INFO - initial_phrase = ( - "Initial loss on validation set:" if epoch is None else f"Epoch {epoch}:" - ) - + if epoch is None: + logging_level=logging.DEBUG + else: + logging_level=logging.INFO if log_errors == "PerAtomRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 - logging.log( - level=logging_level, - msg=f"{initial_phrase} loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A", + logging.log(level=logging_level, + msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -175,7 +173,7 @@ def train( logging.info("") logging.info("===========TRAINING===========") - logging.info("Started training, reporting errors on validation set") + logging.info("Started training") epoch = start_epoch # # log validation loss before _any_ training From 9b9cb2487f662cf9bcff79057c1b0f753512c96d Mon Sep 17 00:00:00 2001 From: vue1999 Date: Sun, 18 Aug 2024 21:45:51 +0200 Subject: [PATCH 26/42] Only check if start_swa is smaller than max_num_epoch if swa is enabled --- mace/tools/arg_parser.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 785297c6..0b8f5ec3 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -57,15 +57,18 @@ def check_args(args): # Loss and optimization # Check Stage Two loss start - if args.start_swa > args.max_num_epochs: - log_messages.append( - ( - f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", - "info", + if args.swa: + if args.start_swa is None: + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + if args.start_swa > args.max_num_epochs: + log_messages.append( + ( + f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", + "info", + ) ) - ) - log_messages.append(("Stage Two will not start", "warning")) - args.swa = None + log_messages.append(("Stage Two will not start", "warning")) + args.swa = False return args, log_messages From 884adca87cdaaa50a2ec0d2ecff4ff50673a3c53 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Sun, 18 Aug 2024 22:34:47 +0200 Subject: [PATCH 27/42] Small changes on logging informations --- mace/cli/preprocess_data.py | 2 +- mace/cli/run_train.py | 11 ++-------- mace/tools/arg_parser.py | 10 +++++---- mace/tools/train.py | 44 ++++++++++++------------------------- 4 files changed, 23 insertions(+), 44 deletions(-) diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index 5c198ec4..7aa11d94 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -211,7 +211,7 @@ def run(args: argparse.Namespace): atomic_energies: np.ndarray = np.array( [atomic_energies_dict[z] for z in z_table.zs] ) - logging.info(f"Atomic energies: {atomic_energies.tolist()}") + logging.info(f"Atomic Energies: {atomic_energies.tolist()}") _inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process] avg_num_neighbors, mean, std=pool_compute_stats(_inputs) logging.info(f"Average number of neighbors: {avg_num_neighbors}") diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 7dc6285c..3443f7b4 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -76,15 +76,8 @@ def run(args: argparse.Namespace) -> None: tools.set_seeds(args.seed) tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank) logging.info("===========CHECKING SETTINGS===========") - for message, level in input_log_messages: - if level == "debug": - logging.debug(message) - elif level == "warning": - logging.warning(message) - elif level == "error": - logging.error(message) - else: - logging.info(message) + for message, loglevel in input_log_messages: + logging.log(level=loglevel,msg=message) if args.distributed: torch.cuda.set_device(local_rank) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 0b8f5ec3..a2102caa 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -9,6 +9,7 @@ from typing import Optional from e3nn import o3 +import logging def check_args(args): @@ -46,14 +47,15 @@ def check_args(args): .irreps.simplify() ) log_messages.append( - ("Both hidden_irreps, num_channels and max_L are specified", "info") + ("Both hidden_irreps, num_channels and max_L are specified", logging.INFO) ) log_messages.append( ( f"Using num_channels and max_L to create hidden irreps: {args.hidden_irreps}.", - "warning", + logging.WARNING ) ) + # Loss and optimization # Check Stage Two loss start @@ -64,10 +66,10 @@ def check_args(args): log_messages.append( ( f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", - "info", + logging.INFO, ) ) - log_messages.append(("Stage Two will not start", "warning")) + log_messages.append(("Stage Two will not start", logging.WARNING)) args.swa = False return args, log_messages diff --git a/mace/tools/train.py b/mace/tools/train.py index 0e4cfdeb..33604ae6 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -45,15 +45,15 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): eval_metrics["mode"] = "eval" eval_metrics["epoch"] = epoch logger.log(eval_metrics) + logging_level=logging.INFO if epoch is None: - logging_level=logging.DEBUG + inintial_phrase="Initial loss on validation set" else: - logging_level=logging.INFO + inintial_phrase=f"Epoch {epoch}" if log_errors == "PerAtomRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 - logging.log(level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -62,9 +62,7 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3 - logging.log( - level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3", + logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3", ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -73,9 +71,7 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 - logging.log( - level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV", + logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV", ) elif ( log_errors == "PerAtomMAEstressvirials" @@ -84,8 +80,7 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 error_stress = eval_metrics["mae_stress"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_stress={error_stress:.1f} meV / A^3" + logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_stress={error_stress:.1f} meV / A^3" ) elif ( log_errors == "PerAtomMAEstressvirials" @@ -94,43 +89,32 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 error_virials = eval_metrics["mae_virials"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_virials={error_virials:.1f} meV" + logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_virials={error_virials:.1f} meV" ) elif log_errors == "TotalRMSE": error_e = eval_metrics["rmse_e"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 - logging.log( - level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A", + logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A", ) elif log_errors == "PerAtomMAE": error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 - logging.log( - level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A", + logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A", ) elif log_errors == "TotalMAE": error_e = eval_metrics["mae_e"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 - logging.log( - level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A", + logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A", ) elif log_errors == "DipoleRMSE": error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.log( - level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye", + logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye", ) elif log_errors == "EnergyDipoleRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.log( - level=logging_level, - msg=f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye", + logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye", ) @@ -173,7 +157,7 @@ def train( logging.info("") logging.info("===========TRAINING===========") - logging.info("Started training") + logging.info("Started training, reporting errors on validation set") epoch = start_epoch # # log validation loss before _any_ training From ce9d7d316b06dd93586c4ee1c951308cfbdfef38 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Sun, 18 Aug 2024 23:00:12 +0200 Subject: [PATCH 28/42] fix pylint warnings --- mace/cli/run_train.py | 2 +- mace/tools/arg_parser.py | 5 ++--- mace/tools/train.py | 36 +++++++++++++++++++++++------------- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 3443f7b4..1656bb8e 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -77,7 +77,7 @@ def run(args: argparse.Namespace) -> None: tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank) logging.info("===========CHECKING SETTINGS===========") for message, loglevel in input_log_messages: - logging.log(level=loglevel,msg=message) + logging.log(level=loglevel, msg=message) if args.distributed: torch.cuda.set_device(local_rank) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index a2102caa..a6223bb8 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -5,11 +5,11 @@ ########################################################################################### import argparse +import logging import os from typing import Optional from e3nn import o3 -import logging def check_args(args): @@ -52,10 +52,9 @@ def check_args(args): log_messages.append( ( f"Using num_channels and max_L to create hidden irreps: {args.hidden_irreps}.", - logging.WARNING + logging.WARNING, ) ) - # Loss and optimization # Check Stage Two loss start diff --git a/mace/tools/train.py b/mace/tools/train.py index 33604ae6..69874c9d 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -45,15 +45,16 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): eval_metrics["mode"] = "eval" eval_metrics["epoch"] = epoch logger.log(eval_metrics) - logging_level=logging.INFO + logging_level = logging.INFO if epoch is None: - inintial_phrase="Initial loss on validation set" + inintial_phrase = "Initial loss on validation set" else: - inintial_phrase=f"Epoch {epoch}" + inintial_phrase = f"Epoch {epoch}" if log_errors == "PerAtomRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 - logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + logging.info( + f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -62,7 +63,8 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3 - logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3", + logging.info( + f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3", ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -71,7 +73,8 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 - logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV", + logging.info( + f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV", ) elif ( log_errors == "PerAtomMAEstressvirials" @@ -80,7 +83,8 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 error_stress = eval_metrics["mae_stress"] * 1e3 - logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_stress={error_stress:.1f} meV / A^3" + logging.info( + f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_stress={error_stress:.1f} meV / A^3" ) elif ( log_errors == "PerAtomMAEstressvirials" @@ -89,32 +93,38 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 error_virials = eval_metrics["mae_virials"] * 1e3 - logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_virials={error_virials:.1f} meV" + logging.info( + f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_virials={error_virials:.1f} meV" ) elif log_errors == "TotalRMSE": error_e = eval_metrics["rmse_e"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 - logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A", + logging.info( + f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A", ) elif log_errors == "PerAtomMAE": error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 - logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A", + logging.info( + f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A", ) elif log_errors == "TotalMAE": error_e = eval_metrics["mae_e"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 - logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A", + logging.info( + f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A", ) elif log_errors == "DipoleRMSE": error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye", + logging.info( + f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye", ) elif log_errors == "EnergyDipoleRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.info(f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye", + logging.info( + f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye", ) From 7310307c97821841584b13f298dfc2b89fc85fec Mon Sep 17 00:00:00 2001 From: vue1999 Date: Sun, 18 Aug 2024 23:05:50 +0200 Subject: [PATCH 29/42] removed unused variable --- mace/tools/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mace/tools/train.py b/mace/tools/train.py index 69874c9d..98ec2f80 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -45,7 +45,6 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): eval_metrics["mode"] = "eval" eval_metrics["epoch"] = epoch logger.log(eval_metrics) - logging_level = logging.INFO if epoch is None: inintial_phrase = "Initial loss on validation set" else: From 64ca16ac04453c6ff1844f12094ad65973eefbd6 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Mon, 19 Aug 2024 20:55:31 +0200 Subject: [PATCH 30/42] Print random indices used to create valid. Fix previous mistake checking valid_batch_size. --- mace/cli/run_train.py | 6 +++--- mace/data/utils.py | 5 ++++- mace/tools/scripts_utils.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 1656bb8e..00a66e65 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -161,13 +161,13 @@ def run(args: argparse.Namespace) -> None: keep_isolated_atoms=args.keep_isolated_atoms, ) if len(collections.train) < args.batch_size: - logging.warning( + logging.info( f"Batch size ({args.batch_size}) is larger than the number of training data ({len(collections.train)})" ) args.batch_size = int(len(collections.train) * 0.1) logging.warning(f"Batch size changed to {args.batch_size}") - if len(collections.train) < len(collections.valid): - logging.warning( + if len(collections.valid) < args.valid_batch_size: + logging.info( f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({len(collections.valid)})" ) args.valid_batch_size = int(len(collections.valid) * 0.1) diff --git a/mace/data/utils.py b/mace/data/utils.py index 66020d52..4b3bf6df 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -63,7 +63,10 @@ def random_train_valid_split( indices = list(range(size)) rng = np.random.default_rng(seed) rng.shuffle(indices) - + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation" + ) + logging.info(f"Validation set created using indices: {indices[train_size:]}") return ( [items[i] for i in indices[:train_size]], [items[i] for i in indices[train_size:]], diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index a68058eb..85e248c5 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -80,7 +80,7 @@ def get_dataset_from_xyz( all_train_configs, valid_fraction, seed ) logging.info( - f"Using random {100 * valid_fraction:.0f}% of training set for validation [{len(valid_configs)} configurations, {np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" + f"Loaded {len(valid_configs)} validation configurations [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" ) test_configs = [] From e2257a54806fd6805a0b9ec4cd867a82844eff33 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Mon, 19 Aug 2024 21:58:46 +0200 Subject: [PATCH 31/42] Changing batch size won't result in 0, need to discuss how to change it --- mace/cli/run_train.py | 4 ++-- mace/data/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 00a66e65..b521f75d 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -164,13 +164,13 @@ def run(args: argparse.Namespace) -> None: logging.info( f"Batch size ({args.batch_size}) is larger than the number of training data ({len(collections.train)})" ) - args.batch_size = int(len(collections.train) * 0.1) + args.batch_size = max(1, int(len(collections.train) * 0.1)) logging.warning(f"Batch size changed to {args.batch_size}") if len(collections.valid) < args.valid_batch_size: logging.info( f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({len(collections.valid)})" ) - args.valid_batch_size = int(len(collections.valid) * 0.1) + args.batch_size = max(1, int(len(collections.train) * 0.1)) logging.warning(f"Validation batch size changed to {args.valid_batch_size}") else: diff --git a/mace/data/utils.py b/mace/data/utils.py index 4b3bf6df..15bfc7da 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -64,7 +64,7 @@ def random_train_valid_split( rng = np.random.default_rng(seed) rng.shuffle(indices) logging.info( - f"Using random {100 * valid_fraction:.0f}% of training set for validation" + f"Using random {100 * valid_fraction:.0f}% of training set for validation" ) logging.info(f"Validation set created using indices: {indices[train_size:]}") return ( From 78290b2c60ba42dfdeccaa2389540728c7770b82 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Tue, 20 Aug 2024 11:47:05 +0200 Subject: [PATCH 32/42] Revert "Print random indices used to create valid. Fix previous mistake checking valid_batch_size." This reverts commit 64ca16ac04453c6ff1844f12094ad65973eefbd6. --- mace/cli/run_train.py | 6 +++--- mace/data/utils.py | 5 +---- mace/tools/scripts_utils.py | 2 +- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index b521f75d..3fb73636 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -161,13 +161,13 @@ def run(args: argparse.Namespace) -> None: keep_isolated_atoms=args.keep_isolated_atoms, ) if len(collections.train) < args.batch_size: - logging.info( + logging.warning( f"Batch size ({args.batch_size}) is larger than the number of training data ({len(collections.train)})" ) args.batch_size = max(1, int(len(collections.train) * 0.1)) logging.warning(f"Batch size changed to {args.batch_size}") - if len(collections.valid) < args.valid_batch_size: - logging.info( + if len(collections.train) < len(collections.valid): + logging.warning( f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({len(collections.valid)})" ) args.batch_size = max(1, int(len(collections.train) * 0.1)) diff --git a/mace/data/utils.py b/mace/data/utils.py index 15bfc7da..66020d52 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -63,10 +63,7 @@ def random_train_valid_split( indices = list(range(size)) rng = np.random.default_rng(seed) rng.shuffle(indices) - logging.info( - f"Using random {100 * valid_fraction:.0f}% of training set for validation" - ) - logging.info(f"Validation set created using indices: {indices[train_size:]}") + return ( [items[i] for i in indices[:train_size]], [items[i] for i in indices[train_size:]], diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 85e248c5..a68058eb 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -80,7 +80,7 @@ def get_dataset_from_xyz( all_train_configs, valid_fraction, seed ) logging.info( - f"Loaded {len(valid_configs)} validation configurations [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" + f"Using random {100 * valid_fraction:.0f}% of training set for validation [{len(valid_configs)} configurations, {np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" ) test_configs = [] From 4801f8b99d51139189b47831d7eae41a2d512520 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Tue, 20 Aug 2024 11:48:31 +0200 Subject: [PATCH 33/42] Revert "Changing batch size won't result in 0, need to discuss how to change it" This reverts commit e2257a54806fd6805a0b9ec4cd867a82844eff33. --- mace/cli/run_train.py | 4 ++-- mace/data/utils.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 3fb73636..1656bb8e 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -164,13 +164,13 @@ def run(args: argparse.Namespace) -> None: logging.warning( f"Batch size ({args.batch_size}) is larger than the number of training data ({len(collections.train)})" ) - args.batch_size = max(1, int(len(collections.train) * 0.1)) + args.batch_size = int(len(collections.train) * 0.1) logging.warning(f"Batch size changed to {args.batch_size}") if len(collections.train) < len(collections.valid): logging.warning( f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({len(collections.valid)})" ) - args.batch_size = max(1, int(len(collections.train) * 0.1)) + args.valid_batch_size = int(len(collections.valid) * 0.1) logging.warning(f"Validation batch size changed to {args.valid_batch_size}") else: diff --git a/mace/data/utils.py b/mace/data/utils.py index 66020d52..4b3bf6df 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -63,7 +63,10 @@ def random_train_valid_split( indices = list(range(size)) rng = np.random.default_rng(seed) rng.shuffle(indices) - + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation" + ) + logging.info(f"Validation set created using indices: {indices[train_size:]}") return ( [items[i] for i in indices[:train_size]], [items[i] for i in indices[train_size:]], From c84f0f55cdc1408c303f70003aaf56a60d2808f6 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Tue, 20 Aug 2024 11:58:12 +0200 Subject: [PATCH 34/42] Print indices when creating valid, only warn about batch size errors --- mace/cli/run_train.py | 6 +----- mace/tools/scripts_utils.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 1656bb8e..c4424a6b 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -164,14 +164,10 @@ def run(args: argparse.Namespace) -> None: logging.warning( f"Batch size ({args.batch_size}) is larger than the number of training data ({len(collections.train)})" ) - args.batch_size = int(len(collections.train) * 0.1) - logging.warning(f"Batch size changed to {args.batch_size}") - if len(collections.train) < len(collections.valid): + if len(collections.valid) < args.valid_batch_size: logging.warning( f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({len(collections.valid)})" ) - args.valid_batch_size = int(len(collections.valid) * 0.1) - logging.warning(f"Validation batch size changed to {args.valid_batch_size}") else: atomic_energies_dict = None diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index a68058eb..fb2c2f62 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -80,7 +80,7 @@ def get_dataset_from_xyz( all_train_configs, valid_fraction, seed ) logging.info( - f"Using random {100 * valid_fraction:.0f}% of training set for validation [{len(valid_configs)} configurations, {np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" + f"Selected {len(valid_configs)} configurations for validation [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" ) test_configs = [] From 1ffa16a550eedbb5604c8d86d0bfe0060e15274d Mon Sep 17 00:00:00 2001 From: vue1999 Date: Wed, 21 Aug 2024 17:40:09 +0200 Subject: [PATCH 35/42] Moved check_args into arg_parser_tools, and moved the checks for hidden_irreps, max_L, num_channels --- mace/cli/run_train.py | 13 ----- mace/tools/__init__.py | 2 +- mace/tools/arg_parser.py | 62 ------------------------ mace/tools/arg_parser_tools.py | 86 ++++++++++++++++++++++++++++++++++ 4 files changed, 87 insertions(+), 76 deletions(-) create mode 100644 mace/tools/arg_parser_tools.py diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index c4424a6b..dc45620a 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -419,19 +419,6 @@ def run(args: argparse.Namespace) -> None: model_config = model_config_foundation # pylint else: logging.debug("Building model") - if args.num_channels is not None and args.max_L is not None: - assert args.num_channels > 0, "num_channels must be positive integer" - assert args.max_L >= 0, "max_L must be non-negative integer" - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - - assert ( - len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 - ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - logging.info( f"Hidden irreps: {args.hidden_irreps} (Number of channel: {args.num_channels}, max_L: {args.max_L})" ) diff --git a/mace/tools/__init__.py b/mace/tools/__init__.py index 1234cf45..6321d709 100644 --- a/mace/tools/__init__.py +++ b/mace/tools/__init__.py @@ -1,8 +1,8 @@ from .arg_parser import ( build_default_arg_parser, build_preprocess_arg_parser, - check_args, ) +from .arg_parser_tools import check_args from .cg import U_matrix_real from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState from .finetuning_utils import load_foundations diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index a6223bb8..a108e2f4 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -5,73 +5,11 @@ ########################################################################################### import argparse -import logging import os from typing import Optional -from e3nn import o3 -def check_args(args): - """ - Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing - the (potentially) modified args and a list of log messages. - """ - log_messages = [] - - # Directories - # Use work_dir for all other directories as well, unless they were specified by the user - if args.log_dir is None: - args.log_dir = os.path.join(args.work_dir, "logs") - if args.model_dir is None: - args.model_dir = args.work_dir - if args.checkpoints_dir is None: - args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") - if args.results_dir is None: - args.results_dir = os.path.join(args.work_dir, "results") - if args.downloads_dir is None: - args.downloads_dir = os.path.join(args.work_dir, "downloads") - - # Model - # Check if hidden_irreps, num_channels and max_L are consistent - if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: - args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 - elif ( - args.hidden_irreps is not None - and args.num_channels is not None - and args.max_L is not None - ): - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - log_messages.append( - ("Both hidden_irreps, num_channels and max_L are specified", logging.INFO) - ) - log_messages.append( - ( - f"Using num_channels and max_L to create hidden irreps: {args.hidden_irreps}.", - logging.WARNING, - ) - ) - - # Loss and optimization - # Check Stage Two loss start - if args.swa: - if args.start_swa is None: - args.start_swa = max(1, args.max_num_epochs // 4 * 3) - if args.start_swa > args.max_num_epochs: - log_messages.append( - ( - f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", - logging.INFO, - ) - ) - log_messages.append(("Stage Two will not start", logging.WARNING)) - args.swa = False - - return args, log_messages def build_default_arg_parser() -> argparse.ArgumentParser: diff --git a/mace/tools/arg_parser_tools.py b/mace/tools/arg_parser_tools.py new file mode 100644 index 00000000..e0fac87d --- /dev/null +++ b/mace/tools/arg_parser_tools.py @@ -0,0 +1,86 @@ +import os +from e3nn import o3 +import logging + +def check_args(args): + """ + Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing + the (potentially) modified args and a list of log messages. + """ + log_messages = [] + + # Directories + # Use work_dir for all other directories as well, unless they were specified by the user + if args.log_dir is None: + args.log_dir = os.path.join(args.work_dir, "logs") + if args.model_dir is None: + args.model_dir = args.work_dir + if args.checkpoints_dir is None: + args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") + if args.results_dir is None: + args.results_dir = os.path.join(args.work_dir, "results") + if args.downloads_dir is None: + args.downloads_dir = os.path.join(args.work_dir, "downloads") + + # Model + # Check if hidden_irreps, num_channels and max_L are consistent + if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: + args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 + elif ( + args.hidden_irreps is not None + and args.num_channels is not None + and args.max_L is not None + ): + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + log_messages.append( + ("All of hidden_irreps, num_channels and max_L are specified", logging.WARNING) + ) + log_messages.append( + ( + f"Using num_channels and max_L to create hidden irreps: {args.hidden_irreps}.", + logging.WARNING, + ) + ) + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + elif args.num_channels is not None and args.max_L is not None: + assert args.num_channels > 0, "num_channels must be positive integer" + assert args.max_L >= 0, "max_L must be non-negative integer" + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + elif args.hidden_irreps is not None: + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + + args.num_channels=list({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)})[0] + args.max_L=o3.Irreps(args.hidden_irreps).lmax + + + # Loss and optimization + # Check Stage Two loss start + if args.swa: + if args.start_swa is None: + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + if args.start_swa > args.max_num_epochs: + log_messages.append( + ( + f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", + logging.INFO, + ) + ) + log_messages.append(("Stage Two will not start", logging.WARNING)) + args.swa = False + + return args, log_messages \ No newline at end of file From 2e20719e468353d67e83a1fdebe600b0ae627031 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Wed, 21 Aug 2024 18:39:52 +0200 Subject: [PATCH 36/42] Rearranged model and optimisation information prints --- mace/cli/run_train.py | 64 +++++++++++++++------------------- mace/data/utils.py | 2 +- mace/tools/__init__.py | 5 +-- mace/tools/arg_parser.py | 3 -- mace/tools/arg_parser_tools.py | 37 +++++++++++++------- 5 files changed, 55 insertions(+), 56 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index dc45620a..9d0db42f 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -351,7 +351,6 @@ def run(args: argparse.Namespace) -> None: else: # Unweighted Energy and Forces loss by default loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) - logging.info(loss_fn) if args.compute_avg_num_neighbors: avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) @@ -391,7 +390,7 @@ def run(args: argparse.Namespace) -> None: } logging.info( - f"Selected the following values to use and report: {[report for report, value in output_args.items() if value]}" + f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" ) if args.scaling == "no_scaling": @@ -403,7 +402,7 @@ def run(args: argparse.Namespace) -> None: ) # Build model if args.foundation_model is not None and args.model in ["MACE", "ScaleShiftMACE"]: - logging.debug("Building model") + logging.info("Loading foundation model") model_config_foundation = extract_config_mace_model(model_foundation) model_config_foundation["atomic_numbers"] = z_table.zs model_config_foundation["num_elements"] = len(z_table) @@ -418,11 +417,19 @@ def run(args: argparse.Namespace) -> None: args.model = "FoundationMACE" model_config = model_config_foundation # pylint else: - logging.debug("Building model") + logging.info("Building model") logging.info( - f"Hidden irreps: {args.hidden_irreps} (Number of channel: {args.num_channels}, max_L: {args.max_L})" + f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})" + ) + logging.info( + f"{args.num_interactions} layers with correlation: {args.correlation} and spherical harmonics up to: {args.max_ell}" + ) + logging.info( + f"Radial cutoff: {args.r_max} Å, {args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" + ) + logging.info( + f"Distance transform for radial basis functions: {args.distance_transform}" ) - model_config = dict( r_max=args.r_max, num_bessel=args.num_radial_basis, @@ -534,6 +541,18 @@ def run(args: argparse.Namespace) -> None: ) model.to(device) + logging.debug(model) + logging.info(f"Total number of parameters: {tools.count_parameters(model)}") + logging.info("") + logging.info("===========OPTIMIZER INFORMATION===========") + logging.info(f"Optimizer for parameter optimization: {args.optimizer.upper()}") + logging.info(f"Batch size: {args.batch_size}") + logging.info( + f"Number of gradient updates: {args.max_num_epochs*len(collections.train)/args.batch_size}" + ) + logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") + logging.info(loss_fn) + # Optimizer decay_interactions = {} no_decay_interactions = {} @@ -604,6 +623,9 @@ def run(args: argparse.Namespace) -> None: swas.append(True) if args.start_swa is None: args.start_swa = max(1, args.max_num_epochs // 4 * 3) + logging.info( + f"Stage Two will start after {args.start_swa} epochs with loss function:" + ) if args.loss == "forces_only": raise ValueError("Can not select Stage Two with forces only loss.") if args.loss == "virials": @@ -624,17 +646,12 @@ def run(args: argparse.Namespace) -> None: forces_weight=args.swa_forces_weight, dipole_weight=args.swa_dipole_weight, ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}" - ) else: loss_fn_energy = modules.WeightedEnergyForcesLoss( energy_weight=args.swa_energy_weight, forces_weight=args.swa_forces_weight, ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}" - ) + logging.info(loss_fn_energy) swa = tools.SWAContainer( model=AveragedModel(model), scheduler=SWALR( @@ -678,29 +695,6 @@ def run(args: argparse.Namespace) -> None: for group in optimizer.param_groups: group["lr"] = args.lr - logging.debug(model) - logging.info(f"Total number of parameters: {tools.count_parameters(model)}") - logging.info( - f"Batch size: {args.batch_size}, validation batch size: {args.valid_batch_size}" - ) - logging.info( - f"Number of gradient updates: {args.max_num_epochs*len(collections.train)/args.batch_size}" - ) - logging.info( - f"Radial cutoff: {args.r_max}, num_radial_basis: {args.num_radial_basis}, num_cutoff_basis: {args.num_cutoff_basis}" - ) - logging.info( - f"Polynomial cutoff: {args.num_cutoff_basis}, max_L: {args.max_L}, num_interactions: {args.num_interactions}" - ) - logging.info( - f"Correlation: {args.correlation}, distance transform: {args.distance_transform}" - ) - - logging.info("") - logging.info("===========OPTIMIZER INFORMATION===========") - logging.info(f"Optimizer for parameter optimization: {args.optimizer.upper()}") - logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") - if args.wandb: logging.info("Using Weights and Biases for logging") import wandb diff --git a/mace/data/utils.py b/mace/data/utils.py index 4b3bf6df..15bfc7da 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -64,7 +64,7 @@ def random_train_valid_split( rng = np.random.default_rng(seed) rng.shuffle(indices) logging.info( - f"Using random {100 * valid_fraction:.0f}% of training set for validation" + f"Using random {100 * valid_fraction:.0f}% of training set for validation" ) logging.info(f"Validation set created using indices: {indices[train_size:]}") return ( diff --git a/mace/tools/__init__.py b/mace/tools/__init__.py index 6321d709..54c59455 100644 --- a/mace/tools/__init__.py +++ b/mace/tools/__init__.py @@ -1,7 +1,4 @@ -from .arg_parser import ( - build_default_arg_parser, - build_preprocess_arg_parser, -) +from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser from .arg_parser_tools import check_args from .cg import U_matrix_real from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index a108e2f4..f91d580c 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -9,9 +9,6 @@ from typing import Optional - - - def build_default_arg_parser() -> argparse.ArgumentParser: try: import configargparse diff --git a/mace/tools/arg_parser_tools.py b/mace/tools/arg_parser_tools.py index e0fac87d..52f35e87 100644 --- a/mace/tools/arg_parser_tools.py +++ b/mace/tools/arg_parser_tools.py @@ -1,6 +1,8 @@ +import logging import os + from e3nn import o3 -import logging + def check_args(args): """ @@ -37,11 +39,14 @@ def check_args(args): .irreps.simplify() ) log_messages.append( - ("All of hidden_irreps, num_channels and max_L are specified", logging.WARNING) + ( + "All of hidden_irreps, num_channels and max_L are specified", + logging.WARNING, + ) ) log_messages.append( ( - f"Using num_channels and max_L to create hidden irreps: {args.hidden_irreps}.", + f"Using num_channels and max_L to create hidden_irreps: {args.hidden_irreps}.", logging.WARNING, ) ) @@ -53,9 +58,9 @@ def check_args(args): assert args.max_L >= 0, "max_L must be non-negative integer" args.hidden_irreps = o3.Irreps( (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) + .sort() + .irreps.simplify() + ) assert ( len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" @@ -64,9 +69,10 @@ def check_args(args): len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - args.num_channels=list({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)})[0] - args.max_L=o3.Irreps(args.hidden_irreps).lmax - + args.num_channels = list( + {irrep.mul for irrep in o3.Irreps(args.hidden_irreps)} + )[0] + args.max_L = o3.Irreps(args.hidden_irreps).lmax # Loss and optimization # Check Stage Two loss start @@ -76,11 +82,16 @@ def check_args(args): if args.start_swa > args.max_num_epochs: log_messages.append( ( - f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", - logging.INFO, + f"start_stage_two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", + logging.WARNING, + ) + ) + log_messages.append( + ( + "Stage Two will not start, as start_stage_two > max_num_epochs", + logging.WARNING, ) ) - log_messages.append(("Stage Two will not start", logging.WARNING)) args.swa = False - return args, log_messages \ No newline at end of file + return args, log_messages From 1196fc0dc0c92fd0457561933317a6e08c93f5b3 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Wed, 21 Aug 2024 19:05:02 +0200 Subject: [PATCH 37/42] Small phrashing changes and changing eval_interval=1, while fixing them at 2 for the tests --- mace/cli/run_train.py | 8 +++++--- mace/tools/arg_parser.py | 2 +- mace/tools/train.py | 2 +- tests/test_calculator.py | 5 +++++ tests/test_run_train.py | 1 + 5 files changed, 13 insertions(+), 5 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 9d0db42f..1bcc1415 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -545,10 +545,12 @@ def run(args: argparse.Namespace) -> None: logging.info(f"Total number of parameters: {tools.count_parameters(model)}") logging.info("") logging.info("===========OPTIMIZER INFORMATION===========") - logging.info(f"Optimizer for parameter optimization: {args.optimizer.upper()}") + logging.info(f"Using {args.optimizer.upper()} as parameter optimizer") logging.info(f"Batch size: {args.batch_size}") + if args.ema: + logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}") logging.info( - f"Number of gradient updates: {args.max_num_epochs*len(collections.train)/args.batch_size}" + f"Number of gradient updates: {int(args.max_num_epochs*len(collections.train)/args.batch_size)}" ) logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") logging.info(loss_fn) @@ -820,7 +822,7 @@ def run(args: argparse.Namespace) -> None: if swa_eval: logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation") else: - logging.info(f"Loaded model from epoch {epoch} for evaluation") + logging.info(f"Loaded Stage one model from epoch {epoch} for evaluation") for param in model.parameters(): param.requires_grad = False diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index f91d580c..1eb033be 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -560,7 +560,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: default=True, ) parser.add_argument( - "--eval_interval", help="evaluate model every epochs", type=int, default=2 + "--eval_interval", help="evaluate model every epochs", type=int, default=1 ) parser.add_argument( "--keep_checkpoints", diff --git a/mace/tools/train.py b/mace/tools/train.py index 98ec2f80..7acc1b6f 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -46,7 +46,7 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): eval_metrics["epoch"] = epoch logger.log(eval_metrics) if epoch is None: - inintial_phrase = "Initial loss on validation set" + inintial_phrase = "Initial metrics on validation set" else: inintial_phrase = f"Epoch {epoch}" if log_errors == "PerAtomRMSE": diff --git a/tests/test_calculator.py b/tests/test_calculator.py index bc8f5862..73019b4a 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -75,6 +75,7 @@ def trained_model_fixture(tmp_path_factory, fitting_configs): "energy_key": "REF_energy", "forces_key": "REF_forces", "stress_key": "REF_stress", + "eval_interval": 2, } tmp_path = tmp_path_factory.mktemp("run_") @@ -137,6 +138,7 @@ def trained_model_equivariant_fixture(tmp_path_factory, fitting_configs): "energy_key": "REF_energy", "forces_key": "REF_forces", "stress_key": "REF_stress", + "eval_interval": 2, } tmp_path = tmp_path_factory.mktemp("run_") @@ -200,6 +202,7 @@ def trained_dipole_fixture(tmp_path_factory, fitting_configs): "stress_key": "", "dipole_key": "REF_dipole", "error_table": "DipoleRMSE", + "eval_interval": 2, } tmp_path = tmp_path_factory.mktemp("run_") @@ -265,6 +268,7 @@ def trained_energy_dipole_fixture(tmp_path_factory, fitting_configs): "stress_key": "", "dipole_key": "REF_dipole", "error_table": "EnergyDipoleRMSE", + "eval_interval": 2, } tmp_path = tmp_path_factory.mktemp("run_") @@ -332,6 +336,7 @@ def trained_committee_fixture(tmp_path_factory, fitting_configs): "energy_key": "REF_energy", "forces_key": "REF_forces", "stress_key": "REF_stress", + "eval_interval": 2, } tmp_path = tmp_path_factory.mktemp(f"run{seed}_") diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 11befd2a..59f7c595 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -66,6 +66,7 @@ def fixture_fitting_configs(): "energy_key": "REF_energy", "forces_key": "REF_forces", "stress_key": "REF_stress", + "eval_interval": 2, } From 139773add0321ceb0e3cc63b5a61b24168df55ee Mon Sep 17 00:00:00 2001 From: vue1999 Date: Thu, 22 Aug 2024 22:11:27 +0200 Subject: [PATCH 38/42] Rephrasing loading data section, saving random indices into a file when it's more than 10 for creating valid --- mace/cli/run_train.py | 6 ++++-- mace/data/utils.py | 20 +++++++++++++++----- mace/tools/arg_parser_tools.py | 16 ++++++++++++++++ mace/tools/scripts_utils.py | 18 ++++++++++-------- 4 files changed, 45 insertions(+), 15 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 1bcc1415..b4f3c39c 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -146,6 +146,7 @@ def run(args: argparse.Namespace) -> None: ), "valid_file if given must be same format as train_file" config_type_weights = get_config_type_weights(args.config_type_weights) collections, atomic_energies_dict = get_dataset_from_xyz( + work_dir=args.work_dir, train_path=args.train_file, valid_path=args.valid_file, valid_fraction=args.valid_fraction, @@ -161,13 +162,14 @@ def run(args: argparse.Namespace) -> None: keep_isolated_atoms=args.keep_isolated_atoms, ) if len(collections.train) < args.batch_size: - logging.warning( + logging.error( f"Batch size ({args.batch_size}) is larger than the number of training data ({len(collections.train)})" ) if len(collections.valid) < args.valid_batch_size: logging.warning( f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({len(collections.valid)})" ) + args.valid_batch_size = len(collections.valid) else: atomic_energies_dict = None @@ -239,7 +241,7 @@ def run(args: argparse.Namespace) -> None: [atomic_energies_dict[z] for z in z_table.zs] ) logging.info( - f"Atomic Energies used [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table.zs])}" + f"Atomic Energies used (z: eV): {{{', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table.zs])}}}" ) if args.train_file.endswith(".xyz"): diff --git a/mace/data/utils.py b/mace/data/utils.py index 15bfc7da..7812743d 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -53,7 +53,7 @@ class Configuration: def random_train_valid_split( - items: Sequence, valid_fraction: float, seed: int + items: Sequence, valid_fraction: float, seed: int, work_dir: str ) -> Tuple[List, List]: assert 0.0 < valid_fraction < 1.0 @@ -63,10 +63,20 @@ def random_train_valid_split( indices = list(range(size)) rng = np.random.default_rng(seed) rng.shuffle(indices) - logging.info( - f"Using random {100 * valid_fraction:.0f}% of training set for validation" - ) - logging.info(f"Validation set created using indices: {indices[train_size:]}") + if len(indices[train_size:]) < 10: + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation with following indices: {indices[train_size:]}" + ) + else: + # Save indices to file + with open(work_dir + f"/valid_indices_{seed}.txt", "w") as f: + for index in indices[train_size:]: + f.write(f"{index}\n") + + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {work_dir}/valid_indices_{seed}.txt" + ) + return ( [items[i] for i in indices[:train_size]], [items[i] for i in indices[train_size:]], diff --git a/mace/tools/arg_parser_tools.py b/mace/tools/arg_parser_tools.py index 52f35e87..da64806a 100644 --- a/mace/tools/arg_parser_tools.py +++ b/mace/tools/arg_parser_tools.py @@ -73,6 +73,22 @@ def check_args(args): {irrep.mul for irrep in o3.Irreps(args.hidden_irreps)} )[0] args.max_L = o3.Irreps(args.hidden_irreps).lmax + elif args.max_L is not None and args.num_channels is None: + assert args.max_L >= 0, "max_L must be non-negative integer" + args.num_channels = 128 + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + elif args.max_L is None and args.num_channels is not None: + assert args.num_channels > 0, "num_channels must be positive integer" + args.max_L = 1 + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) # Loss and optimization # Check Stage Two loss start diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index fb2c2f62..be05412b 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -29,6 +29,7 @@ class SubsetCollection: def get_dataset_from_xyz( + work_dir: str, train_path: str, valid_path: str, valid_fraction: float, @@ -57,7 +58,7 @@ def get_dataset_from_xyz( keep_isolated_atoms=keep_isolated_atoms, ) logging.info( - f"Loaded {len(all_train_configs)} training configurations [{np.sum([1 if config.energy else 0 for config in all_train_configs])} energy, {np.sum([config.forces.size for config in all_train_configs])} forces] from '{train_path}'" + f"Training set [{len(all_train_configs)} configs, {np.sum([1 if config.energy else 0 for config in all_train_configs])} energy, {np.sum([config.forces.size for config in all_train_configs])} forces] loaded from '{train_path}'" ) if valid_path is not None: _, valid_configs = data.load_from_xyz( @@ -72,15 +73,15 @@ def get_dataset_from_xyz( extract_atomic_energies=False, ) logging.info( - f"Loaded {len(valid_configs)} validation configurations [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces] from '{valid_path}'" + f"Validation set [{len(valid_configs)} configs, {np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces] loaded from '{valid_path}'" ) train_configs = all_train_configs else: train_configs, valid_configs = data.random_train_valid_split( - all_train_configs, valid_fraction, seed + all_train_configs, valid_fraction, seed, work_dir ) logging.info( - f"Selected {len(valid_configs)} configurations for validation [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" + f"Validaton set contains {len(valid_configs)} configurations [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" ) test_configs = [] @@ -99,11 +100,12 @@ def get_dataset_from_xyz( # create list of tuples (config_type, list(Atoms)) test_configs = data.test_config_types(all_test_configs) logging.info( - f"Loaded {len(all_test_configs)} test configurations from '{test_path}':" - ) - logging.info( - f"{'; '.join([f'{name}: {len(test_configs)} configs, {np.sum([1 if config.energy else 0 for config in test_configs])} energy, {np.sum([config.forces.size for config in test_configs])} forces' for name, test_configs in test_configs])}" + f"Test set ({len(all_test_configs)} configs) loaded from '{test_path}':" ) + for name, tmp_configs in test_configs: + logging.info( + f"{name}: {len(tmp_configs)} configs, {np.sum([1 if config.energy else 0 for config in tmp_configs])} energy, {np.sum([config.forces.size for config in tmp_configs])} forces" + ) return ( SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs), From a83573128bdc62076ec1898dbebb4922fe1ee34e Mon Sep 17 00:00:00 2001 From: vue1999 Date: Thu, 22 Aug 2024 22:22:19 +0200 Subject: [PATCH 39/42] fixing pylint errors --- mace/cli/preprocess_data.py | 1 + mace/data/utils.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index 7aa11d94..de34b1d4 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -155,6 +155,7 @@ def run(args: argparse.Namespace): # Data preparation collections, atomic_energies_dict = get_dataset_from_xyz( + work_dir=args.work_dir, train_path=args.train_file, valid_path=args.valid_file, valid_fraction=args.valid_fraction, diff --git a/mace/data/utils.py b/mace/data/utils.py index 7812743d..78e3e76f 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -69,7 +69,7 @@ def random_train_valid_split( ) else: # Save indices to file - with open(work_dir + f"/valid_indices_{seed}.txt", "w") as f: + with open(work_dir + f"/valid_indices_{seed}.txt", "w", encoding="utf-8") as f: for index in indices[train_size:]: f.write(f"{index}\n") From 1fbb37f39af5935d922436524f24c714bc6f59a5 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Mon, 26 Aug 2024 20:03:34 +0100 Subject: [PATCH 40/42] Change model details output and also print them when using foundation model --- mace/cli/run_train.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index b4f3c39c..89687af7 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -404,11 +404,14 @@ def run(args: argparse.Namespace) -> None: ) # Build model if args.foundation_model is not None and args.model in ["MACE", "ScaleShiftMACE"]: - logging.info("Loading foundation model") + logging.info("Loading FOUNDATION model") model_config_foundation = extract_config_mace_model(model_foundation) model_config_foundation["atomic_numbers"] = z_table.zs model_config_foundation["num_elements"] = len(z_table) args.max_L = model_config_foundation["hidden_irreps"].lmax + args.num_channels = list( + {irrep.mul for irrep in o3.Irreps(model_config_foundation["hidden_irreps"])} + )[0] model_config_foundation["atomic_inter_shift"] = ( model_foundation.scale_shift.shift.item() ) @@ -418,16 +421,31 @@ def run(args: argparse.Namespace) -> None: model_config_foundation["atomic_energies"] = atomic_energies args.model = "FoundationMACE" model_config = model_config_foundation # pylint + logging.info( + f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({model_config_foundation['hidden_irreps']})" + ) + logging.info( + f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}" + ) + logging.info( + f"Radial cutoff: {model_config_foundation['r_max']} Å (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} Å)" + ) + logging.info( + f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}" + ) else: logging.info("Building model") logging.info( f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})" ) logging.info( - f"{args.num_interactions} layers with correlation: {args.correlation} and spherical harmonics up to: {args.max_ell}" + f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}" + ) + logging.info( + f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" ) logging.info( - f"Radial cutoff: {args.r_max} Å, {args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" + f"Radial cutoff: {args.r_max} Å (total receptive field for each atom: {args.r_max * args.num_interactions} Å)" ) logging.info( f"Distance transform for radial basis functions: {args.distance_transform}" From 55ebcdbf725f6b036dbdb2b1b70149735a22b206 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Mon, 26 Aug 2024 23:08:13 +0100 Subject: [PATCH 41/42] phrasing changes --- mace/cli/run_train.py | 2 +- mace/tools/train.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 89687af7..f98c7a04 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -75,7 +75,7 @@ def run(args: argparse.Namespace) -> None: # Setup tools.set_seeds(args.seed) tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank) - logging.info("===========CHECKING SETTINGS===========") + logging.info("===========VERIFYING SETTINGS===========") for message, loglevel in input_log_messages: logging.log(level=loglevel, msg=message) diff --git a/mace/tools/train.py b/mace/tools/train.py index 7acc1b6f..2f39bed2 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -46,7 +46,7 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): eval_metrics["epoch"] = epoch logger.log(eval_metrics) if epoch is None: - inintial_phrase = "Initial metrics on validation set" + inintial_phrase = "Initial" else: inintial_phrase = f"Epoch {epoch}" if log_errors == "PerAtomRMSE": @@ -167,6 +167,7 @@ def train( logging.info("") logging.info("===========TRAINING===========") logging.info("Started training, reporting errors on validation set") + logging.info("Loss metrics on validation set") epoch = start_epoch # # log validation loss before _any_ training From 5781abac03ca58856cadea69ed429156466bfb5f Mon Sep 17 00:00:00 2001 From: vue1999 Date: Mon, 26 Aug 2024 23:47:19 +0100 Subject: [PATCH 42/42] Training loss and error table formatting --- mace/tools/scripts_utils.py | 74 ++++++++++++++++++------------------- mace/tools/train.py | 20 +++++----- 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index be05412b..27455944 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -539,18 +539,18 @@ def create_error_table( table.add_row( [ name, - f"{metrics['rmse_e'] * 1000:.1f}", - f"{metrics['rmse_f'] * 1000:.1f}", - f"{metrics['rel_rmse_f']:.2f}", + f"{metrics['rmse_e'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", ] ) elif table_type == "PerAtomRMSE": table.add_row( [ name, - f"{metrics['rmse_e_per_atom'] * 1000:.1f}", - f"{metrics['rmse_f'] * 1000:.1f}", - f"{metrics['rel_rmse_f']:.2f}", + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", ] ) elif ( @@ -560,10 +560,10 @@ def create_error_table( table.add_row( [ name, - f"{metrics['rmse_e_per_atom'] * 1000:.1f}", - f"{metrics['rmse_f'] * 1000:.1f}", - f"{metrics['rel_rmse_f']:.2f}", - f"{metrics['rmse_stress'] * 1000:.1f}", + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_stress'] * 1000:8.1f}", ] ) elif ( @@ -573,10 +573,10 @@ def create_error_table( table.add_row( [ name, - f"{metrics['rmse_e_per_atom'] * 1000:.1f}", - f"{metrics['rmse_f'] * 1000:.1f}", - f"{metrics['rel_rmse_f']:.2f}", - f"{metrics['rmse_virials'] * 1000:.1f}", + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_virials'] * 1000:8.1f}", ] ) elif ( @@ -586,10 +586,10 @@ def create_error_table( table.add_row( [ name, - f"{metrics['mae_e_per_atom'] * 1000:.1f}", - f"{metrics['mae_f'] * 1000:.1f}", - f"{metrics['rel_mae_f']:.2f}", - f"{metrics['mae_stress'] * 1000:.1f}", + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_stress'] * 1000:8.1f}", ] ) elif ( @@ -599,55 +599,55 @@ def create_error_table( table.add_row( [ name, - f"{metrics['mae_e_per_atom'] * 1000:.1f}", - f"{metrics['mae_f'] * 1000:.1f}", - f"{metrics['rel_mae_f']:.2f}", - f"{metrics['mae_virials'] * 1000:.1f}", + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_virials'] * 1000:8.1f}", ] ) elif table_type == "TotalMAE": table.add_row( [ name, - f"{metrics['mae_e'] * 1000:.1f}", - f"{metrics['mae_f'] * 1000:.1f}", - f"{metrics['rel_mae_f']:.2f}", + f"{metrics['mae_e'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", ] ) elif table_type == "PerAtomMAE": table.add_row( [ name, - f"{metrics['mae_e_per_atom'] * 1000:.1f}", - f"{metrics['mae_f'] * 1000:.1f}", - f"{metrics['rel_mae_f']:.2f}", + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", ] ) elif table_type == "DipoleRMSE": table.add_row( [ name, - f"{metrics['rmse_mu_per_atom'] * 1000:.2f}", - f"{metrics['rel_rmse_mu']:.1f}", + f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_rmse_mu']:8.1f}", ] ) elif table_type == "DipoleMAE": table.add_row( [ name, - f"{metrics['mae_mu_per_atom'] * 1000:.2f}", - f"{metrics['rel_mae_mu']:.1f}", + f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_mae_mu']:8.1f}", ] ) elif table_type == "EnergyDipoleRMSE": table.add_row( [ name, - f"{metrics['rmse_e_per_atom'] * 1000:.1f}", - f"{metrics['rmse_f'] * 1000:.1f}", - f"{metrics['rel_rmse_f']:.1f}", - f"{metrics['rmse_mu_per_atom'] * 1000:.1f}", - f"{metrics['rel_rmse_mu']:.1f}", + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.1f}", + f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", + f"{metrics['rel_rmse_mu']:8.1f}", ] ) return table diff --git a/mace/tools/train.py b/mace/tools/train.py index 2f39bed2..b38bce16 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -53,7 +53,7 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -63,7 +63,7 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_f = eval_metrics["rmse_f"] * 1e3 error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3", + f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_stress_per_atom={error_stress:8.1f} meV / A^3", ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -73,7 +73,7 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_f = eval_metrics["rmse_f"] * 1e3 error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV", + f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_virials_per_atom={error_virials:8.1f} meV", ) elif ( log_errors == "PerAtomMAEstressvirials" @@ -83,7 +83,7 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_f = eval_metrics["mae_f"] * 1e3 error_stress = eval_metrics["mae_stress"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_stress={error_stress:.1f} meV / A^3" + f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_stress={error_stress:8.1f} meV / A^3" ) elif ( log_errors == "PerAtomMAEstressvirials" @@ -93,37 +93,37 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_f = eval_metrics["mae_f"] * 1e3 error_virials = eval_metrics["mae_virials"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_virials={error_virials:.1f} meV" + f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_virials={error_virials:8.1f} meV" ) elif log_errors == "TotalRMSE": error_e = eval_metrics["rmse_e"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A", + f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A", ) elif log_errors == "PerAtomMAE": error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A", + f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", ) elif log_errors == "TotalMAE": error_e = eval_metrics["mae_e"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A", + f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", ) elif log_errors == "DipoleRMSE": error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye", + f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", ) elif log_errors == "EnergyDipoleRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye", + f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", )