diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml new file mode 100644 index 00000000..f6614d48 --- /dev/null +++ b/.github/workflows/run-tests.yaml @@ -0,0 +1,39 @@ + +name: Run Tests +on: [push] +jobs: + Run-Mistral-Tests: + runs-on: self-hosted + steps: + - run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event." + - run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by GitHub!" + - run: echo "🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}." + - name: Check out repository code + uses: actions/checkout@v2 + - run: echo "💡 The ${{ github.repository }} repository has been cloned to the runner." + - run: echo "🖥️ The workflow is now ready to test your code on the runner." + - name: Setup + run: | + cp -r /home/stanzabuild/mistral/wandb . + wandb offline + - name: Tests for arguments (single node/single GPU) + if: always() + run: | + cd tests + CUDA_VISIBLE_DEVICES=0 pytest test_args.py + - name: Tests for checkpoints (single node/single GPU) + if: always() + run: | + cd tests + CUDA_VISIBLE_DEVICES=0 pytest test_checkpoint.py + - name: Tests for upcasting (single node/single GPU) + if: always() + run: | + cd tests + CUDA_VISIBLE_DEVICES=0 pytest test_fp.py + - name: Tests for random seed (single node/single GPU) + if: always() + run: | + cd tests + CUDA_VISIBLE_DEVICES=0 pytest test_seed.py + - run: echo "🍏 This job's status is ${{ job.status }}." diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0b6dd730..f7815dbd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ fail_fast: true repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.2.0 + rev: v4.0.1 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -17,17 +17,22 @@ repos: - id: check-added-large-files - repo: https://github.com/psf/black - rev: 20.8b1 + rev: 21.8b0 hooks: - id: black - repo: https://github.com/timothycrosley/isort - rev: 5.6.4 + rev: 5.9.3 hooks: - id: isort - repo: https://gitlab.com/pycqa/flake8 - rev: 3.8.4 + rev: 3.9.2 hooks: - id: flake8 additional_dependencies: [flake8-isort] + +- repo: https://github.com/pre-commit/mirrors-mypy + rev: 'v0.910' + hooks: + - id: mypy diff --git a/README.md b/README.md index 62d5cc23..dfd11967 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-green?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) A framework for transparent and accessible large-scale language model training, built with [Hugging Face 🤗](https://huggingface.co/) . Includes tools -and helpful scripts for incorporating new pre-training datasets, various schemes for single node and distributed training - including on +and helpful scripts for incorporating new pre-training datasets, various schemes for single node and distributed training - including on cloud providers like GCP, and importantly, scripts for evaluation. Visit our [Read the Docs](https://nlp.stanford.edu/mistral) for the full documentation. @@ -143,8 +143,8 @@ We have also stored over 600 checkpoints for each model, subject to the followin - Every 100 Steps, from 2000 - 20,000 Steps. - Every 1000 Steps, from 20,000 - 400,000 Steps. -This comes out to _610 checkpoints per run, taking up ~22TB for all 10 models_ (making it pretty expensive to host!) 
If you are interested in acquiring -these additional checkpoints, please [file an issue](https://github.com/stanford-crfm/mistral/issues) or contact Laurel (lorr1) and Sidd (skaramcheti) +This comes out to _610 checkpoints per run, taking up ~22TB for all 10 models_ (making it pretty expensive to host!) If you are interested in acquiring +these additional checkpoints, please [file an issue](https://github.com/stanford-crfm/mistral/issues) or contact Laurel (lorr1) and Sidd (skaramcheti) at their @cs.stanford.edu email addresses, and we'll be happy to figure out a cost-effective solution to sharing them. GPT-2 Medium @@ -201,11 +201,18 @@ GPT-2 Small ## Issues -To ask questions, report issues, or request features, please use the [GitHub Issue Tracker](https://github.com/stanford-crfm/mistral/issues). +To ask questions, report issues, or request features, please use the [GitHub Issue Tracker](https://github.com/stanford-crfm/mistral/issues). Before creating a new issue, please make sure to search for existing issues that may solve your problem. --- +## Differences between Mistral and Hugging Face + +Please visit the [following page](https://nlp.stanford.edu/mistral/hugging_face_differences.html) that outlines the +differences between the two codebases. + +--- + ## Contributing Please see the [following page](https://nlp.stanford.edu/mistral/contributing.html) for information on contributing. diff --git a/conf/train_schema.py b/conf/train_schema.py index 803349f4..8ac2fecc 100644 --- a/conf/train_schema.py +++ b/conf/train_schema.py @@ -21,7 +21,7 @@ def get_schema() -> Dict[str, Any]: - """ Get the Cerberus schema for the Quinine config used in train.py. """ + """Get the Cerberus schema for the Quinine config used in train.py.""" # Schema for Dataset data_schema = { diff --git a/docs/README.md b/docs/README.md index 1c7fc151..53c5c3ce 100644 --- a/docs/README.md +++ b/docs/README.md @@ -4,6 +4,7 @@ If you don't already have Sphinx set up install it with `pip`. ```bash pip install sphinx +pip install sphinx-rtd-theme ``` The documentation has been built with version 4.0.2. diff --git a/docs/conf.py b/docs/conf.py index cb71474c..f062e419 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,16 +12,18 @@ # Problems with imports? Could try `export PYTHONPATH=$PYTHONPATH:`pwd`` from root project dir... import os import sys -sys.path.insert(0, os.path.abspath('..')) # Source code dir relative to this file + + +sys.path.insert(0, os.path.abspath("..")) # Source code dir relative to this file # -- Project information ----------------------------------------------------- -project = 'Mistral' -author = 'Project Mercury' -copyright = '2021 The Board of Trustees of The Leland Stanford Junior University' +project = "Mistral" +author = "Project Mercury" +copyright = "2021 The Board of Trustees of The Leland Stanford Junior University" # The full version, including alpha/beta/rc tags -release = '0.1.0' +release = "0.1.0" # -- General configuration --------------------------------------------------- @@ -29,10 +31,10 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', # Core Sphinx library for auto html doc generation from docstrings - 'sphinx.ext.autosummary', # Create neat summary tables for modules/classes/methods etc - 'sphinx.ext.intersphinx', # Link to other project's documentation (see mapping below) - 'sphinx.ext.viewcode' # Add a link to the Python source code for classes, functions etc. 
+ "sphinx.ext.autodoc", # Core Sphinx library for auto html doc generation from docstrings + "sphinx.ext.autosummary", # Create neat summary tables for modules/classes/methods etc + "sphinx.ext.intersphinx", # Link to other project's documentation (see mapping below) + "sphinx.ext.viewcode", # Add a link to the Python source code for classes, functions etc. ] # Mappings for sphinx.ext.intersphinx. Projects have to have Sphinx-generated doc! (.inv file) @@ -46,20 +48,20 @@ autodoc_inherit_docstrings = True # If no docstring, inherit from base class set_type_checking_flag = True # Enable 'expensive' imports for sphinx_autodoc_typehints nbsphinx_allow_errors = True # Continue through Jupyter errors -#autodoc_typehints = "description" # Sphinx-native method. Not as good as sphinx_autodoc_typehints -add_module_names = False # Remove namespaces from class/method signatures +# autodoc_typehints = "description" # Sphinx-native method. Not as good as sphinx_autodoc_typehints +add_module_names = False # Remove namespaces from class/method signatures # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # Exclusions # To exclude a module, use autodoc_mock_imports. Note this may increase build time, a lot. # (Also, when installing on readthedocs.org, we omit installing Tensorflow and # Tensorflow Probability so mock them here instead.) -#autodoc_mock_imports = [ - # 'tensorflow', - # 'tensorflow_probability', -#] +# autodoc_mock_imports = [ +# 'tensorflow', +# 'tensorflow_probability', +# ] # To exclude a class, function, method or attribute, use autodoc-skip-member. (Note this can also # be used in reverse, ie. to re-include a particular member that has been excluded.) # 'Private' and 'special' members (_ and __) are excluded using the Jinja2 templates; from the main @@ -88,17 +90,18 @@ on_rtd = os.environ.get("READTHEDOCS", None) == "True" if not on_rtd: # only import and set the theme if we're building docs locally import sphinx_rtd_theme + html_theme = "sphinx_rtd_theme" html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] -html_css_files = ["readthedocs-custom.css"] # Override some CSS settings +html_css_files = ["readthedocs-custom.css"] # Override some CSS settings # Pydata theme -#html_theme = "pydata_sphinx_theme" -#html_logo = "_static/logo-company.png" -#html_theme_options = { "show_prev_next": False} -#html_css_files = ['pydata-custom.css'] +# html_theme = "pydata_sphinx_theme" +# html_logo = "_static/logo-company.png" +# html_theme_options = { "show_prev_next": False} +# html_css_files = ['pydata-custom.css'] # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] diff --git a/docs/hugging_face_differences.rst b/docs/hugging_face_differences.rst new file mode 100644 index 00000000..57abe1ce --- /dev/null +++ b/docs/hugging_face_differences.rst @@ -0,0 +1,36 @@ +Differences between Mistral and Hugging Face +=============== + +Mistral is not a replacement for Hugging Face. Rather, we extend the current functionalities in Hugging Face +by fixing stability issues with GPT training, adding evaluation scripts and supporting distributed training +with the DeepSpeed optimization library. 
+
+
+**Stability**
+
+When training GPT-2 Small models with Hugging Face, some of the models crashed due to numerical instability.
+We fixed this issue by rearranging the order of operations in the scaled dot-product attention computation
+and upcasting to FP32. We also scaled down the weights by dividing by the layer number to prevent overflow.
+
+
+**Evaluation**
+
+We added online evaluation so we can get PPL on arbitrary datasets while training.
+
+
+**Parallelism**
+
+We noticed that integrating parallelism (e.g. tensor model-parallelism and pipelining) breaks the current
+Hugging Face APIs.
+
+
+**Distributed Training**
+
+We provide ready-to-use scripts and configuration files to run distributed training with DeepSpeed,
+Google Cloud Platform and Kubernetes.
+
+
+**Future**
+
+We are working closely with folks from Hugging Face. We plan to integrate Mistral into the Hugging Face library
+in the future.
diff --git a/docs/index.rst b/docs/index.rst
index 0792384b..ffcfb7ef 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -27,6 +27,7 @@ Contributing
    API reference <_autosummary/src>
+   Differences between Mistral and Hugging Face <hugging_face_differences>

 Mistral - Large Scale Language Modeling Made Easy
 =====================================================
diff --git a/environments/environment-cpu.yaml b/environments/environment-cpu.yaml
index 964a57c8..496d2cde 100644
--- a/environments/environment-cpu.yaml
+++ b/environments/environment-cpu.yaml
@@ -174,6 +174,7 @@ dependencies:
     - pycodestyle==2.6.0
     - pyflakes==2.2.0
     - pylatex==1.4.1
+    - pytest==6.2.5
     - pytz==2021.1
     - pyyaml==5.4
     - quinine==0.3.0
diff --git a/environments/environment-gpu.yaml b/environments/environment-gpu.yaml
index fac558f3..6a8fb5da 100644
--- a/environments/environment-gpu.yaml
+++ b/environments/environment-gpu.yaml
@@ -183,6 +183,7 @@ dependencies:
     - pycodestyle==2.6.0
     - pyflakes==2.2.0
     - pylatex==1.4.1
+    - pytest==6.2.5
     - pytz==2021.1
     - pyyaml==5.4
     - quinine==0.3.0
diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 00000000..e92ddf60
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,31 @@
+[mypy]
+disable_error_code=override
+
+# do not follow imports (except for ones found in typeshed)
+ignore_missing_imports = True
+#Ignore errors for third parties
+ignore_errors = True
+follow_imports = silent
+
+# treat Optional per PEP 484
+strict_optional = False
+
+warn_unused_configs = True
+warn_redundant_casts = True
+# ensure all execution paths are returning
+warn_no_return= True
+warn_unreachable = True
+allow_redefinition = True
+
+show_error_codes = True
+check_untyped_defs = True
+
+
+files=
+    src,
+    tests,
+    train.py
+python_version = 3.6
+
+[mypy-src.*]
+ignore_errors = False
diff --git a/src/args/training_args.py b/src/args/training_args.py
index 48e0dc8d..a6e55f93 100644
--- a/src/args/training_args.py
+++ b/src/args/training_args.py
@@ -26,7 +26,7 @@ def get_training_arguments(
     nodes: int = 1,
     gpus_per_node: int = 8,
 ) -> TrainingArguments:
-    """ Initialize Training Arguments from Quinfig and Runtime-Defined Variables.
""" + """Initialize Training Arguments from Quinfig and Runtime-Defined Variables.""" # `quinfig_args` already contains some default training arguments --> we'll be overwriting/adding to the Dict # =>> a `Munch` is a subclass of Dictionary that supports attribute style access diff --git a/src/core/callbacks.py b/src/core/callbacks.py index 30fdc0cf..620d47a1 100644 --- a/src/core/callbacks.py +++ b/src/core/callbacks.py @@ -8,7 +8,7 @@ import os import time from bisect import bisect_left -from typing import Dict, List +from typing import Dict, List, Optional import jsonlines import torch @@ -41,7 +41,7 @@ def rewrite_logs(d: Dict[str, float]) -> Dict[str, float]: class CustomWandbCallback(WandbCallback): - """ Custom Weights and Biases Callback used by Mistral for logging information from the Huggingface Trainer. """ + """Custom Weights and Biases Callback used by Mistral for logging information from the Huggingface Trainer.""" def __init__( self, @@ -70,14 +70,15 @@ def __init__( self.group, self.resume, self.resume_run_id, self.wandb_dir = group, resume, resume_run_id, wandb_dir # Timers - self.within_time, self.between_time = None, None + self.within_time: Optional[float] = None + self.between_time: Optional[float] = None def _append_jsonl(self, data) -> None: with jsonlines.open(self.json_file, mode="a") as writer: writer.write(data) def _log_memory(self, state, prefix="train_info"): - """ Simple method to log memory usage at the end of every training batch. """ + """Simple method to log memory usage at the end of every training batch.""" if state.is_world_process_zero and torch.cuda.is_available(): memory_usage = { f"{prefix}/memory_allocated": torch.cuda.memory_allocated() / 2 ** 20, @@ -254,7 +255,7 @@ def on_train_begin( eval_dataloader=None, **kwargs, ): - """ Calls wandb.init, we add additional arguments to that call using this method. """ + """Calls wandb.init, we add additional arguments to that call using this method.""" # Pass in additional keyword arguments to the wandb.init call as kwargs super().on_train_begin( @@ -325,7 +326,7 @@ def on_log( class CustomCheckpointCallback(TrainerCallback): - """ Custom Checkpoint Callback used by Mistral for Saving Checkpoints at different frequencies. """ + """Custom Checkpoint Callback used by Mistral for Saving Checkpoints at different frequencies.""" def __init__(self, frequencies: List[List[int]]): super(CustomCheckpointCallback, self).__init__() @@ -337,7 +338,7 @@ def __init__(self, frequencies: List[List[int]]): assert all(i < j for i, j in zip(self.until, self.until[1:])), "Frequency `until_step` not increasing!" def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - """ Borrow Checkpoint Logic from `DefaultFlowCallback` to decide when to checkpoint. 
""" + """Borrow Checkpoint Logic from `DefaultFlowCallback` to decide when to checkpoint.""" # Save (note we explicitly save checkpoint-0 in `train.py`, so no need to do it here) c = state.global_step diff --git a/src/core/trainer.py b/src/core/trainer.py index 2a83c416..aed2384d 100644 --- a/src/core/trainer.py +++ b/src/core/trainer.py @@ -6,14 +6,13 @@ import collections import logging import time -from typing import Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np import torch from torch.nn.parallel import DistributedDataParallel from torch.utils.data.dataset import Dataset from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import RandomSampler from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments from transformers.data.data_collator import DataCollator from transformers.file_utils import is_datasets_available @@ -41,6 +40,9 @@ class OnlineBenchmarkTrainer(Trainer): Overrides `evaluate` to trigger eval on each online dataset. """ + control: Any + _globalstep_last_logged: int + def __init__( self, model: AutoModelForCausalLM, @@ -161,7 +163,7 @@ def evaluate( return metrics def single_dataset_eval(self, dataset_name: str, dataset: Dataset, metric_key_prefix: str) -> Dict[str, float]: - """ Run Perplexity Evaluation on a Single Dataset. """ + """Run Perplexity Evaluation on a Single Dataset.""" custom_metric_key_prefix = f"{metric_key_prefix}_{dataset_name}" if dataset is not None and not isinstance(dataset, collections.abc.Sized): raise ValueError("eval_dataset must implement __len__") @@ -223,7 +225,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: else: if self.args.world_size <= 1: - return RandomSampler(self.train_dataset) + return DistributedSampler(self.train_dataset, num_replicas=1, rank=0, seed=self.args.seed) elif ( self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL] and not self.args.dataloader_drop_last @@ -240,7 +242,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: else: # @Mercury =>> Critical Change :: Pass seed to Distributed Sampler to randomize Data Order! return DistributedSampler( - self.train_dataset, + self.train_dataset, # type: ignore num_replicas=self.args.world_size, rank=self.args.process_index, seed=self.args.seed, diff --git a/src/corpora/auto.py b/src/corpora/auto.py index 7e30c760..44e4a0cd 100644 --- a/src/corpora/auto.py +++ b/src/corpora/auto.py @@ -7,7 +7,7 @@ import logging from copy import deepcopy from pathlib import Path -from typing import Dict, List +from typing import Dict, Iterable, List import datasets from transformers import BatchEncoding, PreTrainedTokenizer @@ -30,7 +30,7 @@ def get_auto_dataset( stride: int = -1, ignore_train: bool = False, ) -> datasets.DatasetDict: - """ Run basic tokenization and grouping to turn a Hugging Face Dataset (via `datasets`) into a torch.Dataset. 
""" + """Run basic tokenization and grouping to turn a Hugging Face Dataset (via `datasets`) into a torch.Dataset.""" # Sanity check on input args stride = seq_len if stride < 0 else stride @@ -84,9 +84,9 @@ def tokenize(examples: Dict[str, List[str]]) -> BatchEncoding: ) # Finally, actually run chunking (collapse multiple sequences into a giant document to read `seq_len` chunks from) - def group(examples: Dict[str, List[int]]) -> Dict[str, List[int]]: + def group(examples: Dict[str, Iterable[List[int]]]) -> Dict[str, List[List[int]]]: # Concatenate all the Texts - concatenated = {k: sum(examples[k], []) for k in examples.keys()} + concatenated: Dict[str, List[int]] = {k: sum(examples[k], []) for k in examples.keys()} total_length = len(concatenated[list(examples.keys())[0]]) # Drop the "very last" bit of the dataset that doesn't fit into block size... diff --git a/src/models/auto_clm.py b/src/models/auto_clm.py index c6e008a8..6c55f10c 100644 --- a/src/models/auto_clm.py +++ b/src/models/auto_clm.py @@ -56,7 +56,7 @@ def get_auto_clm_tokenizer( upcast_attn: bool = True, initial_weights: str = None, ) -> Tuple[AutoModelForCausalLM, PreTrainedTokenizer]: - """ Download/Load AutoConfig and Instantiate Corresponding Model and Tokenizer. """ + """Download/Load AutoConfig and Instantiate Corresponding Model and Tokenizer.""" # Create Configuration if "gpt2" in model_id and model_configs: diff --git a/src/models/mistral_gpt2.py b/src/models/mistral_gpt2.py index d5803693..79343c2e 100644 --- a/src/models/mistral_gpt2.py +++ b/src/models/mistral_gpt2.py @@ -7,6 +7,7 @@ Reference: https://github.com/huggingface/transformers/blob/master/src/transformers/models/gpt2/modeling_gpt2.py """ import logging +from typing import Tuple import torch import torch.nn as nn @@ -169,10 +170,10 @@ def forward( output_shape = input_shape + (hidden_states.size(-1),) - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None + presents: Tuple = () if use_cache else None + all_self_attentions: Tuple = () if output_attentions else None + all_cross_attentions: Tuple = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states: Tuple = () if output_hidden_states else None for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -211,7 +212,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch.utils.checkpoint.checkpoint( # type:ignore[attr-defined] create_custom_forward(block), hidden_states, None, @@ -312,7 +313,7 @@ def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions= if self.reorder_attn: # Preallocate Scaled Dot-Product Tensor - w = torch.empty( + w = torch.empty( # type: ignore bsz * num_heads, seq_len, seq_len, @@ -373,6 +374,12 @@ def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions= w = nn.Softmax(dim=-1)(w) + # verify upcasting is happening + if self.upcast_attn: + if w.dtype != torch.float32: + overwatch.critical("Upcasting Error. w does not have dtype torch.float32") + raise RuntimeError("Upcasting Error. w does not have dtype torch.float32") + # @MERCURY =>> Downcast (if necessary) back to V dtype (fp16 if mixed-precision)! # Note: This is a No-Op if Upcasting is disabled... 
        w = w.type(v.dtype)
@@ -383,7 +390,7 @@ def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=
         if head_mask is not None:
             w = w * head_mask
-        outputs = (torch.matmul(w, v),)
+        outputs: Tuple = (torch.matmul(w, v),)
         if output_attentions:
             outputs += (w,)
         return outputs
diff --git a/src/util/paths.py b/src/util/paths.py
index 8fec0eba..729e9f3e 100644
--- a/src/util/paths.py
+++ b/src/util/paths.py
@@ -42,6 +42,6 @@ def create_paths(run_id: str, model: str, run_dir: str, cache_dir: str) -> Dict[
 def set_permissions(paths: Dict[str, Path]) -> None:
-    """ Recursively call `os.chmod(775) recursively for the given paths. """
+    """Recursively `chmod -R 775` the given paths."""
     for p in paths:
         os.system(f"chmod -R 775 {paths[p]} >/dev/null 2>&1")
diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 00000000..c697d347
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,15 @@
+# Run Tests
+
+Set the `MISTRAL_TEST_DIR` environment variable to a working directory that can store the Hugging Face cache and checkpoints created by the tests:
+
+```bash
+export MISTRAL_TEST_DIR=/path/to/mistral-test-working-dir
+```
+
+Then run the following from the repository root to run the tests in single node/single GPU mode:
+
+```bash
+export CUDA_VISIBLE_DEVICES=0
+cd tests
+pytest
+```
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29b..2abe36c9 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1,144 @@
+import inspect
+import os
+import re
+import shutil
+import subprocess
+import sys
+import traceback
+from unittest.mock import patch
+
+import psutil
+
+from src.core.trainer import OnlineBenchmarkTrainer
+from train import train
+
+
+MISTRAL_TEST_DIR = os.getenv("MISTRAL_TEST_DIR")
+
+# standard utils
+
+
+def to_cl_args(args_dict):
+    """
+    Create a list of cl args from a dictionary.
+    """
+    args_list = []
+    for k, v in args_dict.items():
+        args_list.append(f"--{k}")
+        args_list.append(v)
+    return args_list
+
+
+# deepspeed utils
+
+
+def launched_by_deepspeed():
+    """
+    Determine if this process has been launched by deepspeed.
+    """
+    parent = psutil.Process(os.getppid())
+    return "deepspeed.launcher.launch" in parent.cmdline()
+
+
+DEEPSPEED_MODE = launched_by_deepspeed()
+
+
+def am_first_deepspeed_child():
+    """
+    Check if this is the first deepspeed child.
+    """
+    if DEEPSPEED_MODE:
+        parent = psutil.Process(os.getppid())
+        children = parent.children()
+        return os.getpid() == children[0].pid if children else False
+    else:
+        return False
+
+
+def deepspeed_launch_info():
+    """
+    Get info about number of nodes/gpus used by deepspeed.
+    """
+    grandparent = psutil.Process(os.getppid()).parent()
+    num_nodes = grandparent.cmdline()[grandparent.cmdline().index("--num_nodes") + 1]
+    num_gpus = grandparent.cmdline()[grandparent.cmdline().index("--num_gpus") + 1]
+    return {"nodes": int(num_nodes), "gpus": int(num_gpus)}
+
+
+def deepspeedify(cl_args_dict):
+    """
+    Alter standard test args to have deepspeed info.
+    """
+    info = deepspeed_launch_info()
+    cl_args_dict["nproc_per_node"] = str(info["gpus"])
+    cl_args_dict["nnodes"] = str(info["nodes"])
+    # cl_args_dict["training_arguments.deepspeed"] = "conf/deepspeed/z2-small-conf.json"
+    cl_args_dict["training_arguments.deepspeed"] = "conf/deepspeed/z1-conf.json"
+
+
+def run_train_process(cl_args_dict, runs_dir, run_id, use_deepspeed=DEEPSPEED_MODE) -> OnlineBenchmarkTrainer:
+    """
+    Run training with given cl args and run dir.
+ """ + # clear training dir + cl_args_dict["artifacts.run_dir"] = runs_dir + cl_args_dict["run_id"] = run_id + if use_deepspeed: + deepspeedify(cl_args_dict) + cl_args = [""] + to_cl_args(cl_args_dict) + run_id_dir = f"{runs_dir}/{run_id}" + if not use_deepspeed or am_first_deepspeed_child(): + print(f"Removing {run_id_dir}...") + shutil.rmtree(run_id_dir) if os.path.exists(run_id_dir) else None + # log cl args used + print(f"Using following command line args for training: {cl_args}") + with patch.object(sys, "argv", cl_args): + # run main training process + trainer = train() + return trainer + + +def get_test_functions(): + """ + Return all test functions in this module. + """ + all_test_functions = [ + (name, obj) + for name, obj in inspect.getmembers(sys.modules["__main__"]) + if (inspect.isfunction(obj) and name.startswith("test") and obj.__module__ == "__main__") + ] + return all_test_functions + + +def run_tests(): + """ + Run each function, catch and report AssertionError's + """ + if DEEPSPEED_MODE and not am_first_deepspeed_child(): + return + test_functions = get_test_functions() + passing_tests = [] + failing_tests = [] + assertion_errors = [] + print("Running tests:") + for (name, test_function) in test_functions: + print("") + print(name) + try: + test_function() + passing_tests.append(name) + except AssertionError as e: + failing_tests.append(name) + assertion_errors.append((e, traceback.format_exc())) + print("") + print("Test report:") + print(f"{len(passing_tests)} passed, {len(failing_tests)} failed") + print("") + print("Failing tests:") + for test, error in zip(failing_tests, assertion_errors): + print("") + print(f"{test}") + print(error[1]) + print(error[0]) + if len(failing_tests) > 0: + sys.exit(1) diff --git a/tests/conf/datasets/wikitext103.yaml b/tests/conf/datasets/wikitext103.yaml new file mode 100644 index 00000000..487493b0 --- /dev/null +++ b/tests/conf/datasets/wikitext103.yaml @@ -0,0 +1,13 @@ +# wikitext103.yaml +# Configuration for WikiText-103 Dataset (https://huggingface.co/datasets/wikitext). +--- +dataset: + id: wikitext + name: wikitext-103-raw-v1 + validation_ratio: null + + # Number of Preprocessing Workers + num_proc: 4 + + # Number of Evaluation Preprocessing Workers + eval_num_proc: 4 diff --git a/tests/conf/datasets/wikitext2.yaml b/tests/conf/datasets/wikitext2.yaml new file mode 100644 index 00000000..58dfb9e2 --- /dev/null +++ b/tests/conf/datasets/wikitext2.yaml @@ -0,0 +1,13 @@ +# wikitext2.yaml +# Configuration for WikiText-2 Dataset (https://huggingface.co/datasets/wikitext). 
+--- +dataset: + id: wikitext + name: wikitext-2-raw-v1 + validation_ratio: null + + # Number of Preprocessing Workers + num_proc: 4 + + # Number of Evaluation Preprocessing Workers + eval_num_proc: 4 diff --git a/tests/conf/deepspeed/z1-conf.json b/tests/conf/deepspeed/z1-conf.json new file mode 100644 index 00000000..c1be93ff --- /dev/null +++ b/tests/conf/deepspeed/z1-conf.json @@ -0,0 +1,34 @@ +{ + "optimizer": { + "type": "AdamW", + "params": { + "lr": 0.0006, + "betas": [ + 0.9, + 0.95 + ], + "eps": 1e-8, + "weight_decay": 0.1 + } + }, + + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "total_num_steps": 400000, + "warmup_max_lr": 0.0006, + "warmup_num_steps": 4000 + } + }, + + "zero_optimization": { + "stage": 1, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "overlap_comm": true, + "contiguous_gradients": true, + "cpu_offload": false + } +} diff --git a/tests/conf/models/gpt2-micro.json b/tests/conf/models/gpt2-micro.json new file mode 100644 index 00000000..d844f2c5 --- /dev/null +++ b/tests/conf/models/gpt2-micro.json @@ -0,0 +1,35 @@ +{ + "activation_function": "gelu_new", + "architectures": [ + "MistralGPT2LMHeadModel" + ], + "attn_pdrop": 0.0, + "bos_token_id": 50256, + "embd_pdrop": 0.0, + "eos_token_id": 50256, + "gradient_checkpointing": false, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "gpt2", + "n_ctx": 256, + "n_embd": 768, + "n_head": 2, + "n_inner": null, + "n_layer": 2, + "n_positions": 256, + "resid_pdrop": 0.0, + "summary_activation": null, + "summary_first_dropout": 0.0, + "summary_proj_to_labels": true, + "summary_type": "cls_index", + "summary_use_proj": true, + "task_specific_params": { + "text-generation": { + "do_sample": true, + "max_length": 50 + } + }, + "transformers_version": "4.5.0", + "use_cache": false, + "vocab_size": 50257 +} diff --git a/tests/conf/models/gpt2-micro.yaml b/tests/conf/models/gpt2-micro.yaml new file mode 100644 index 00000000..50eb9408 --- /dev/null +++ b/tests/conf/models/gpt2-micro.yaml @@ -0,0 +1,28 @@ +# gpt2-micro-config.yaml +# Configuration for the GPT-2 Micro Model. +--- +model: + id: "gpt2-small" + + # Boolean whether to use Gradient Checkpointing to save GPU Memory at the expense of runtime + gradient_checkpointing: false + + # Add Gradient Checkpointing Every `gc_checkpoint_every` Transformer blocks + # > Checkpoints = (# layers / `gc_checkpoint_every`) Blocks + gc_checkpoint_every: -1 + + # Boolean whether to use the pre-existing Hugging Face AutoTokenizer (or train a new one from scratch) + pretrained_tokenizer: true + + # Sequence Length + seq_len: 256 + + # Stability -- Upcasting and Scaled Dot-Product Reordering + reorder_attn: true + upcast_attn: true + + # Initialize Weights from File + initial_weights: null + + # Configure Model From File + config_path: conf/models/gpt2-micro.json diff --git a/tests/conf/models/gpt2-small.yaml b/tests/conf/models/gpt2-small.yaml new file mode 100644 index 00000000..d7c8024c --- /dev/null +++ b/tests/conf/models/gpt2-small.yaml @@ -0,0 +1,25 @@ +# gpt2-small-config.yaml +# Configuration for the GPT-2 Small Model. 
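+# Note: unlike gpt2-micro.yaml, no `config_path` is set here, so the model config presumably comes from the pretrained Hugging Face GPT-2 configuration rather than a local JSON file.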
+--- +model: + id: "gpt2-small" + + # Boolean whether to use Gradient Checkpointing to save GPU Memory at the expense of runtime + gradient_checkpointing: false + + # Add Gradient Checkpointing Every `gc_checkpoint_every` Transformer blocks + # > Checkpoints = (# layers / `gc_checkpoint_every`) Blocks + gc_checkpoint_every: -1 + + # Boolean whether to use the pre-existing Hugging Face AutoTokenizer (or train a new one from scratch) + pretrained_tokenizer: true + + # Sequence Length + seq_len: 512 + + # Stability -- Upcasting and Scaled Dot-Product Reordering + reorder_attn: true + upcast_attn: true + + # Initialize Weights from File + initial_weights: null diff --git a/tests/conf/train-diff.yaml b/tests/conf/train-diff.yaml new file mode 100644 index 00000000..ff918970 --- /dev/null +++ b/tests/conf/train-diff.yaml @@ -0,0 +1,61 @@ +# hello-world.yaml +# Full Mistral GPT-2 Small Training Config, currently working with the OpenWebText Dataset, GPT-2 Small Architecture, +# and full batch size (512). Runs with DeepSpeed ZeRO-2, with a per-device BSZ of 16. +# +# Inheritance and core paths can all be overridden from the command line or by re-writing these files. +--- +# Inherit Dataset, Tokenization, Model, and Training Details +inherit: + - datasets/wikitext103.yaml + - models/gpt2-micro.yaml + - trainers/gpt2-small-diff.yaml + +# Run ID -- make sure to override! +run_id: null + +# Weights & Biases +wandb: hello-world +group: gpt2-small + +# Artifacts & Caching +artifacts: + cache_dir: /nlp/scr/jebolton/mistral-hello-world/artifacts + run_dir: /nlp/scr/jebolton/mistral-hello-world/runs + +# Save Effective Batch Size for Easy Handling ==> Main Code asserts infra + training_config results in this! +effective_bsz: 16 + +# Resume from Checkpoint +resume: false +resume_checkpoint: null + +# List of frequencies at which to save checkpoints, provided as a list of two-element tuples: +# - Frequency (`freq`) at which to save checkpoints (# steps) +# - Bound (`until`) on global step for given frequency (checkpoint every `freq` steps until global step = `until`) +checkpoint_frequency: + - [10, 100] + - [50, 2000] + - [100, 20000] + - [1000, 400000] + +# `torch.distributed` Default Infra Parameters -- to be overwritten by call to `torch.distributed.launch` +local_rank: -1 +nnodes: -1 +nproc_per_node: -1 + +# DeepSpeed Default Infra Parameters -- to be overwritten by call to `DeepSpeed` +num_gpus: -1 +num_nodes: -1 +world_size: -1 + +# Logging Parameters -- 10 = DEBUG, 20 = INFO, 30 = WARNING, 40 = ERROR, 50 = CRITICAL +log_level: 20 + +# Random Seed +seed: 40 + +online_eval: + stride: 256 + +run_training: false +run_final_eval: false diff --git a/tests/conf/train.yaml b/tests/conf/train.yaml new file mode 100644 index 00000000..ca62cba9 --- /dev/null +++ b/tests/conf/train.yaml @@ -0,0 +1,59 @@ +# hello-world.yaml +# Full Mistral GPT-2 Small Training Config, currently working with the OpenWebText Dataset, GPT-2 Small Architecture, +# and full batch size (512). Runs with DeepSpeed ZeRO-2, with a per-device BSZ of 16. +# +# Inheritance and core paths can all be overridden from the command line or by re-writing these files. +--- +# Inherit Dataset, Tokenization, Model, and Training Details +inherit: + - datasets/wikitext2.yaml + - models/gpt2-micro.yaml + - trainers/gpt2-small.yaml + +# Run ID -- make sure to override! 
+run_id: null + +# Weights & Biases +wandb: hello-world +group: gpt2-small + +# Artifacts & Caching +artifacts: + cache_dir: + run_dir: + +# Save Effective Batch Size for Easy Handling ==> Main Code asserts infra + training_config results in this! +effective_bsz: 16 + +# Resume from Checkpoint +resume: false +resume_checkpoint: null + +# List of frequencies at which to save checkpoints, provided as a list of two-element tuples: +# - Frequency (`freq`) at which to save checkpoints (# steps) +# - Bound (`until`) on global step for given frequency (checkpoint every `freq` steps until global step = `until`) +checkpoint_frequency: + - [2, 18] + - [10, 100] + - [50, 2000] + - [100, 20000] + - [1000, 400000] + +# `torch.distributed` Default Infra Parameters -- to be overwritten by call to `torch.distributed.launch` +local_rank: -1 +nnodes: -1 +nproc_per_node: -1 + +# DeepSpeed Default Infra Parameters -- to be overwritten by call to `DeepSpeed` +num_gpus: -1 +num_nodes: -1 +world_size: -1 + +# Logging Parameters -- 10 = DEBUG, 20 = INFO, 30 = WARNING, 40 = ERROR, 50 = CRITICAL +log_level: 20 + +# Random Seed +seed: 21 + +online_eval: + stride: 256 diff --git a/tests/conf/trainers/gpt2-small-diff.yaml b/tests/conf/trainers/gpt2-small-diff.yaml new file mode 100644 index 00000000..43d6bbfa --- /dev/null +++ b/tests/conf/trainers/gpt2-small-diff.yaml @@ -0,0 +1,67 @@ +# gpt2-small.yaml +# Trainer config for Full GPT-2 Small, with the full fixed batch size of 512 (with gradient accumulation). +# This contract exactly follows that of HF.TrainingArguments so we can pass as a simple **kwargs -- make sure this +# continues to stay valid! +# Reference: https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments +--- +training_arguments: + # Overwrite from Top-Level Config + output_dir: null + + # Generally sticks to order from HF.TrainingArguments() Docs, skipping over sane defaults/implicitly set args... + do_train: true + evaluation_strategy: steps + + # Set these based on GPU RAM/your available hardware + per_device_train_batch_size: 8 + per_device_eval_batch_size: 16 + + # We set this dynamically based on DDP Computation [steps = effective_batch / (per_gpu_batch * gpus * nodes)] + gradient_accumulation_steps: null + + # For Online Evaluation, only keep around the Losses + prediction_loss_only: true + + # Learning Rate & Optimization Parameters, assumes AdamW + learning_rate: 0.0006 + weight_decay: 0.2 + adam_beta1: 0.7 + adam_beta2: 0.3 + adam_epsilon: 1.0e-8 + + # Gradient Norm + max_grad_norm: 2.0 + + # Maximum Training Steps (Overrides epochs!) + max_steps: 100000 + + # LR Scheduling Parameters -- Warmup Steps should be 1% of total steps (Could use ratio) + lr_scheduler_type: linear # Cosine not supported if we want to use DeepSpeed Optimizers (gets overwritten!) + warmup_steps: 4000 + + # Logging Parameters -- Logging Directory (Tensorboard - is this necessary?) should be Overwritten at Runtime! + run_name: null + logging_dir: null + logging_first_step: true + logging_steps: 50 + + # Saving and Evaluation Steps + eval_steps: 1000 + save_steps: 1000 + + # Resume Behavior --> ignore "full determinism" on resume (saves time for debugging) + ignore_data_skip: false + + # Seeds -- Should be Overwritten at Runtime! 
+ seed: null + + ### Optimization -- Precision, DeepSpeed, and FairScale Parameters -- all off for `simple` config + fp16: true + sharded_ddp: null + deepspeed: null + + # Dataloader Parallelism + dataloader_num_workers: 4 + + # Should be overwritten from the Top-Level Config or CLI! + local_rank: null diff --git a/tests/conf/trainers/gpt2-small.yaml b/tests/conf/trainers/gpt2-small.yaml new file mode 100644 index 00000000..b282a3e3 --- /dev/null +++ b/tests/conf/trainers/gpt2-small.yaml @@ -0,0 +1,67 @@ +# gpt2-small.yaml +# Trainer config for Full GPT-2 Small, with the full fixed batch size of 512 (with gradient accumulation). +# This contract exactly follows that of HF.TrainingArguments so we can pass as a simple **kwargs -- make sure this +# continues to stay valid! +# Reference: https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments +--- +training_arguments: + # Overwrite from Top-Level Config + output_dir: null + + # Generally sticks to order from HF.TrainingArguments() Docs, skipping over sane defaults/implicitly set args... + do_train: true + evaluation_strategy: steps + + # Set these based on GPU RAM/your available hardware + per_device_train_batch_size: 8 + per_device_eval_batch_size: 16 + + # We set this dynamically based on DDP Computation [steps = effective_batch / (per_gpu_batch * gpus * nodes)] + gradient_accumulation_steps: null + + # For Online Evaluation, only keep around the Losses + prediction_loss_only: true + + # Learning Rate & Optimization Parameters, assumes AdamW + learning_rate: 0.0006 + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_epsilon: 1.0e-8 + + # Gradient Norm + max_grad_norm: 1.0 + + # Maximum Training Steps (Overrides epochs!) + max_steps: 400000 + + # LR Scheduling Parameters -- Warmup Steps should be 1% of total steps (Could use ratio) + lr_scheduler_type: linear # Cosine not supported if we want to use DeepSpeed Optimizers (gets overwritten!) + warmup_steps: 4000 + + # Logging Parameters -- Logging Directory (Tensorboard - is this necessary?) should be Overwritten at Runtime! + run_name: null + logging_dir: null + logging_first_step: true + logging_steps: 50 + + # Saving and Evaluation Steps + eval_steps: 1000 + save_steps: 1000 + + # Resume Behavior --> ignore "full determinism" on resume (saves time for debugging) + ignore_data_skip: false + + # Seeds -- Should be Overwritten at Runtime! + seed: null + + ### Optimization -- Precision, DeepSpeed, and FairScale Parameters -- all off for `simple` config + fp16: true + sharded_ddp: null + deepspeed: null + + # Dataloader Parallelism + dataloader_num_workers: 4 + + # Should be overwritten from the Top-Level Config or CLI! 
+ local_rank: null diff --git a/tests/test_args.py b/tests/test_args.py new file mode 100644 index 00000000..f7ab0e76 --- /dev/null +++ b/tests/test_args.py @@ -0,0 +1,51 @@ +from tests import MISTRAL_TEST_DIR, run_tests, run_train_process + + +# paths +CACHE_DIR = f"{MISTRAL_TEST_DIR}/artifacts" +RUNS_DIR = f"{MISTRAL_TEST_DIR}/runs" + +TRAIN_ARGS = { + "nnodes": "1", + "nproc_per_node": "1", + "config": "conf/train.yaml", + "training_arguments.fp16": "true", + "training_arguments.per_device_train_batch_size": "1", + "artifacts.cache_dir": CACHE_DIR, + "log_level": "50", + "run_training": "false", + "run_final_eval": "false", +} + +trainer_w_train = run_train_process(cl_args_dict=TRAIN_ARGS, runs_dir=RUNS_DIR, run_id="train_args_test") + +TRAIN_ARGS_DIFF = { + "nnodes": "1", + "nproc_per_node": "1", + "config": "conf/train-diff.yaml", + "training_arguments.fp16": "true", + "training_arguments.per_device_train_batch_size": "1", + "artifacts.cache_dir": CACHE_DIR, + "log_level": "50", + "run_training": "false", + "run_final_eval": "false", +} + +trainer_w_train_diff = run_train_process( + cl_args_dict=TRAIN_ARGS_DIFF, runs_dir=RUNS_DIR, run_id="train_args_diff_test" +) + + +def test_train_args() -> None: + assert trainer_w_train.args.weight_decay == 0.1 + assert trainer_w_train.args.adam_beta1 == 0.9 + assert trainer_w_train.args.adam_beta2 == 0.95 + assert trainer_w_train.args.max_grad_norm == 1.0 + assert trainer_w_train_diff.args.weight_decay == 0.2 + assert trainer_w_train_diff.args.adam_beta1 == 0.7 + assert trainer_w_train_diff.args.adam_beta2 == 0.3 + assert trainer_w_train_diff.args.max_grad_norm == 2.0 + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py new file mode 100644 index 00000000..e58ddd57 --- /dev/null +++ b/tests/test_checkpoint.py @@ -0,0 +1,107 @@ +import os + +import torch + +from tests import MISTRAL_TEST_DIR, run_tests, run_train_process + + +# common paths and resources for tests + +# paths +CACHE_DIR = f"{MISTRAL_TEST_DIR}/artifacts" +RUNS_DIR = f"{MISTRAL_TEST_DIR}/runs" +RUN_ID = "checkpoint_test" +RUN_ID_DIR = f"{RUNS_DIR}/{RUN_ID}" +LAST_CHECKPOINT = "checkpoint-18" + +# run training processes for tests +TRAIN_ARGS = { + "nnodes": "1", + "nproc_per_node": "1", + "config": "conf/train.yaml", + "training_arguments.fp16": "true", + "training_arguments.max_steps": "19", + "training_arguments.per_device_train_batch_size": "1", + "artifacts.cache_dir": CACHE_DIR, + "log_level": "20", + "effective_bsz": "16", + "run_final_eval": "false", +} + +trainer_after_training = run_train_process(cl_args_dict=TRAIN_ARGS, runs_dir=RUNS_DIR, run_id=RUN_ID) + +RESTART_ARGS = { + "nnodes": "1", + "nproc_per_node": "1", + "config": "conf/train.yaml", + "training_arguments.fp16": "true", + "training_arguments.max_steps": "1", + "training_arguments.per_device_train_batch_size": "1", + "resume": "True", + "resume_checkpoint": f"{RUN_ID_DIR}/{LAST_CHECKPOINT}", + "artifacts.cache_dir": CACHE_DIR, + "log_level": "20", + "effective_bsz": "16", + "run_final_eval": "false", +} + +trainer_after_restart = run_train_process(cl_args_dict=RESTART_ARGS, runs_dir=RUNS_DIR, run_id=RUN_ID + "-restart") + + +def test_checkpoint_weights() -> None: + """ + Test weights of a checkpointed model match the true weights. 
+ """ + model = trainer_after_training.model + loaded_model = trainer_after_restart.model + loaded_model.to(torch.device("cuda")) + assert model.state_dict().keys() == loaded_model.state_dict().keys() + for key in model.state_dict().keys(): + assert torch.equal(model.state_dict()[key], loaded_model.state_dict()[key]) + + +def test_checkpoint_forward_pass() -> None: + """ + Test that loaded model correctly calculate forward pass + """ + model = trainer_after_training.model + loaded_model = trainer_after_restart.model + loaded_model.to(torch.device("cuda")) + train_dataloader = trainer_after_training.get_train_dataloader() + inputs = next(iter(train_dataloader)) + inputs = trainer_after_training._prepare_inputs(inputs) + assert model.state_dict().keys() == loaded_model.state_dict().keys() + for key in model.state_dict().keys(): + assert torch.equal(model.state_dict()[key], loaded_model.state_dict()[key]) + # run forward with loaded model + loaded_model.eval() + outputs_loaded = loaded_model(**inputs) + # run forward with original model + model.eval() + outputs = model(**inputs) + assert torch.equal(outputs["logits"], outputs_loaded["logits"]), ( + f"original: {outputs['logits']} dtype: {outputs['logits'].dtype}, loaded: {outputs_loaded['logits']} dtype:" + f" {outputs['logits'].dtype}" + ) + + +def test_checkpoint_frequency() -> None: + """ + Test checkpointing happening at expected frequency + """ + assert not os.path.exists(f"{RUN_ID_DIR}/checkpoint-1") + assert os.path.exists(f"{RUN_ID_DIR}/checkpoint-2") + assert not os.path.exists(f"{RUN_ID_DIR}/checkpoint-3") + + +def test_restart_batch_order() -> None: + """ + Test batch order is consistent when restarting + """ + original_indices = list(iter(trainer_after_training.get_train_dataloader().sampler)) + after_restart_indices = list(iter(trainer_after_restart.get_train_dataloader().sampler)) + assert original_indices == after_restart_indices + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_fp.py b/tests/test_fp.py new file mode 100644 index 00000000..5dddedb7 --- /dev/null +++ b/tests/test_fp.py @@ -0,0 +1,35 @@ +from tests import MISTRAL_TEST_DIR, run_tests, run_train_process + + +# common paths and resources for tests + +# paths +CACHE_DIR = f"{MISTRAL_TEST_DIR}/artifacts" +RUNS_DIR = f"{MISTRAL_TEST_DIR}/runs" +RUN_ID = "upcasting_test" +RUN_ID_DIR = f"{RUNS_DIR}/{RUN_ID}" +LAST_CHECKPOINT = "checkpoint-18" + +# run training processes for tests +TRAIN_ARGS = { + "nnodes": "1", + "nproc_per_node": "1", + "config": "conf/train.yaml", + "training_arguments.fp16": "true", + "training_arguments.max_steps": "4", + "artifacts.cache_dir": CACHE_DIR, + "run_training": "true", + "run_final_eval": "false", + "log_level": "50", +} + + +def test_upcasting() -> None: + """ + Run training with upcasting + """ + run_train_process(cl_args_dict=TRAIN_ARGS, runs_dir=RUNS_DIR, run_id=RUN_ID) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_seed.py b/tests/test_seed.py new file mode 100644 index 00000000..baf7187f --- /dev/null +++ b/tests/test_seed.py @@ -0,0 +1,167 @@ +import torch + +from tests import MISTRAL_TEST_DIR, run_tests, run_train_process + + +# paths +CACHE_DIR = f"{MISTRAL_TEST_DIR}/artifacts" +RUNS_DIR = f"{MISTRAL_TEST_DIR}/runs" +RUN_ID = "train_args_test" +RUN_ID_DIR = f"{RUNS_DIR}/{RUN_ID}" + +# set up different trainers to see initialization differences +TRAIN_ARGS_SEED_7 = { + "nnodes": "1", + "nproc_per_node": "1", + "config": "conf/train.yaml", + "training_arguments.fp16": "true", + 
"training_arguments.per_device_train_batch_size": "1", + "artifacts.cache_dir": CACHE_DIR, + "seed": "7", + "log_level": "50", + "run_training": "false", + "run_final_eval": "false", +} + +trainer_seed_7 = run_train_process(cl_args_dict=TRAIN_ARGS_SEED_7, runs_dir=RUNS_DIR, run_id="trainer_seed_7") + +TRAIN_ARGS_SEED_10 = dict(TRAIN_ARGS_SEED_7) +TRAIN_ARGS_SEED_10["seed"] = "10" +trainer_seed_10 = run_train_process(cl_args_dict=TRAIN_ARGS_SEED_10, runs_dir=RUNS_DIR, run_id="trainer_seed_10") + + +def is_randomized(key): + """ + Helper to determine if the key in the state_dict() is a set of parameters that is randomly initialized. + Some weights are not randomly initalized and won't be afffected by seed, particularly layer norm + weights and biases, and bias terms in general. + """ + # regexes for components that are not randomized + print(key) + if key.endswith("bias") or "ln" in key: + return False + else: + return True + + +def test_weight_initializations() -> None: + assert trainer_seed_7.model.state_dict().keys() == trainer_seed_10.model.state_dict().keys() + for key in trainer_seed_7.model.state_dict().keys(): + if is_randomized(key): + assert not torch.equal( + trainer_seed_7.model.state_dict()[key], trainer_seed_10.model.state_dict()[key] + ), f"weights are the same for {key}" + + +def test_data_order() -> None: + seed_7_dataloader = trainer_seed_7.get_train_dataloader() + seed_10_dataloader = trainer_seed_10.get_train_dataloader() + seed_7_indices, seed_10_indices = list(iter(seed_7_dataloader.sampler)), list(iter(seed_10_dataloader.sampler)) + expected_indices = [ + 7485, + 6448, + 8289, + 7940, + 6492, + 4866, + 1722, + 1303, + 3568, + 7713, + 4597, + 3294, + 7178, + 2517, + 8770, + 8208, + 90, + 4594, + 4487, + 5002, + 2784, + 4846, + 6457, + 4210, + 1510, + 2230, + 8074, + 1846, + 753, + 3613, + 3354, + 8174, + 6577, + 6422, + 2463, + 670, + 8784, + 8659, + 2515, + 647, + 6654, + 5255, + 8623, + 7172, + 679, + 4060, + 4177, + 2159, + 7638, + 3163, + 468, + 2689, + 5817, + 8100, + 5736, + 8081, + 3993, + 7968, + 3549, + 7995, + 596, + 370, + 6044, + 1640, + 1693, + 7685, + 3544, + 5806, + 1887, + 692, + 5526, + 4601, + 3042, + 8700, + 222, + 1601, + 4908, + 5576, + 4823, + 7853, + 6892, + 5932, + 7890, + 2599, + 6431, + 2136, + 8601, + 964, + 2214, + 3320, + 1593, + 5543, + 5599, + 1694, + 3991, + 3595, + 4128, + 5573, + 4720, + 4600, + ] + actual_indices = [seed_10_indices.index(seed_7_indices[i]) for i in range(0, 100)] + assert expected_indices == actual_indices, actual_indices + + +if __name__ == "__main__": + run_tests() diff --git a/train.py b/train.py index ca10c34c..c95e1c7c 100644 --- a/train.py +++ b/train.py @@ -40,7 +40,7 @@ from src.util import create_paths, set_permissions -def train() -> None: +def train() -> OnlineBenchmarkTrainer: # Parse Quinfig (via Quinine Argparse Binding) print("[*] Mercury :: Launching =>>> \N{rocket} \N{see-no-evil monkey} \N{rocket}") print('\t=>> "This wind, it is not an ending..." (Robert Jordan - A Memory of Light)') @@ -199,6 +199,9 @@ def train() -> None: metrics = trainer.evaluate() print(metrics) + # return trainer as record of training process + return trainer + if __name__ == "__main__": train()