From 613457a1cb426ddd601b1b7ee44430be0d8f5ff7 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Mon, 27 Nov 2023 15:51:30 -0800 Subject: [PATCH 1/4] Bump composer version to min 0.17.1 (#762) --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index afdfce8d48..9bf2ef2cb0 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ ] install_requires = [ - 'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17,<0.18', + 'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.1,<0.18', 'accelerate>=0.20,<0.21', # for HF inference `device_map` 'transformers>=4.34.1,<4.35', 'mosaicml-streaming>=0.7.1,<0.8', @@ -84,11 +84,11 @@ ] extra_deps['databricks'] = [ - 'mosaicml[databricks]>=0.17,<0.18', + 'mosaicml[databricks]>=0.17.1,<0.18', ] extra_deps['tensorboard'] = [ - 'mosaicml[tensorboard]>=0.17,<0.18', + 'mosaicml[tensorboard]>=0.17.1,<0.18', ] extra_deps['gpu'] = [ From 34d04ea689a4e4b08af7e9e911338fd4bc2983c1 Mon Sep 17 00:00:00 2001 From: bandish-shah <86627118+bandish-shah@users.noreply.github.com> Date: Mon, 27 Nov 2023 20:52:14 -0800 Subject: [PATCH 2/4] Update Docker image release logic so that we can release new images to prod from workflow_dispatch (#763) --- .github/workflows/docker.yaml | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 13a835356c..f6dac79fe5 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -69,19 +69,17 @@ jobs: GIT_SHA=$(echo ${{ github.sha }} | cut -c1-7) echo "IMAGE_TAG=${GIT_SHA}" >> ${GITHUB_ENV} - if [ "${{ github.event_name }}" == "push" ]; then - echo "Triggered by push event." - PROD_REPO="mosaicml/llm-foundry" - IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest" - IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache" - elif [ "${{ github.event_name }}" == "pull_request" ]; then + if [ "${{ github.event_name }}" == "pull_request" ]; then echo "Triggered by pull_request event." 
STAGING_REPO="mosaicml/ci-staging" IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA}" IMAGE_CACHE="${STAGING_REPO}:${{matrix.name}}-buildcache" else - echo "Triggered by unknown event: ${{ github.event_name }}" - exit 1 + # Triggered by push or workflow_dispatch event + echo "Triggered by ${{ github.event_name }} event, releasing to prod" + PROD_REPO="mosaicml/llm-foundry" + IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest" + IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache" fi echo "IMAGE_TAG=${IMAGE_TAG}" >> ${GITHUB_ENV} From 4f399bf5895b52490c1e43cd0f7d1492724bfa47 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Tue, 28 Nov 2023 12:54:45 -0800 Subject: [PATCH 3/4] Fix tiktoken wrapper (#761) --- llmfoundry/tokenizers/tiktoken.py | 169 ++++++++++++++---------------- tests/test_tiktoken.py | 62 ++++++----- 2 files changed, 111 insertions(+), 120 deletions(-) diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py index 8e258cce74..6110f565df 100644 --- a/llmfoundry/tokenizers/tiktoken.py +++ b/llmfoundry/tokenizers/tiktoken.py @@ -1,8 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 - -import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from functools import lru_cache +from typing import Any, Dict, List, Optional, Tuple import torch from transformers import PreTrainedTokenizer @@ -10,6 +9,38 @@ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.""" +# Taken from +# https://github.com/huggingface/transformers/blob/8aca43bdb3cb9a5020f6d57589d85679dc873b1c/src/transformers/models/gpt2/tokenization_gpt2.py#L62-L84 +@lru_cache() +def bytes_to_unicode(): + """Returns list of utf-8 byte and a mapping to unicode strings. + + We specifically avoids mapping to whitespace/control characters the bpe code + barfs on. + + The reversible bpe codes work on unicode strings. This means you need a + large # of unicode characters in your vocab if you want to avoid UNKs. When + you're at something like a 10B token dataset you end up needing around 5K + for decent coverage. This is a significant percentage of your normal, say, + 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and + unicode strings. + """ + bs = (list(range(ord('!'), + ord('~') + 1)) + list(range(ord('¡'), + ord('¬') + 1)) + + list(range(ord('®'), + ord('ÿ') + 1))) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + class TiktokenTokenizerWrapper(PreTrainedTokenizer): """A thin wrapper around tiktoken to make it compatible with Hugging Face. 
@@ -93,6 +124,28 @@ def pickle_Encoding(enc: Encoding): self.add_eos_token = add_eos_token self.use_default_system_prompt = use_default_system_prompt + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + + self.decoder: Dict[int, str] = {} + for i in range(self.encoding.n_vocab): + try: + self.encoding.decode_single_token_bytes(i) + except KeyError: + continue + # Taken from + # https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee + decoding = ''.join([ + bytes_to_unicode()[ord(char)] for char in + self.encoding.decode_single_token_bytes(i).decode('latin-1') + ]) + self.decoder[i] = decoding + + self.encoder: Dict[str, int] = {} + for i in range(self.encoding.n_vocab): + if i in self.decoder: + self.encoder[self.decoder[i]] = i + super().__init__(model_name=model_name, encoding_name=encoding_name, add_bos_token=add_bos_token, @@ -135,122 +188,54 @@ def default_chat_template(self): return template def get_vocab(self) -> Dict[str, int]: - """Returns vocab as a dict. - - Note: This function does not work properly due to difference in assumptions between tiktoken and Hugging Face tokenizers. - Most uses do not need to use get_vocab, so this is not a priority to fix. - """ - warnings.warn( - 'get_vocab does not work properly with TiktokenTokenizerWrapper. Please do not rely on it being perfectly correct.' - + - ' It will be called once init just to get the size of the vocab inside the base class.' - ) - - vocab = {} - for i in range(self.vocab_size): - try: - # need to try this first, so that we get a proper KeyError, - # otherwise it crashes in the rust code - _ = self.encoding.decode_single_token_bytes(i) - vocab[self.encoding.decode([i])] = i - except KeyError: - pass - + """Returns vocab as a dict.""" # As far as I can tell, we don't require get_vocab to completely work, # but when using additional_special_tokens, Hugging Face determines the next # token index to add with len(self.get_vocab()) so we need the _size_ of this dictionary to be correct. + vocab_clone = self.encoder.copy() extra_id_index = 0 candidate_extra_id = f'' indices_to_fill_in = {i for i in range(self.vocab_size)} - set( - vocab.values()) + vocab_clone.values()) # Add enough indices to make get_vocab() the right length for index_to_add in indices_to_fill_in: # Make sure we don't overwrite a token that already exists - while candidate_extra_id in vocab: + while candidate_extra_id in vocab_clone: extra_id_index += 1 candidate_extra_id = f'' # Get an index to add and add the item - vocab[candidate_extra_id] = index_to_add + vocab_clone[candidate_extra_id] = index_to_add - return vocab + return vocab_clone - def _tokenize(self, text: str) -> List[int]: - """Returns a tokenized string. - - Note: We have slightly redefined the expected contract between this method and - the _convert_token_to_id method. Normally, this method turns a string, into a list of strings, - and then the _convert_token_to_id method turns that list of strings into a list of integers. - However, not all vocab indices can be decoded into a string, so instead we just return the integers - from this function, and have adjusted the _convert_token_to_id method to handle integers as well as strings. - The only use of _tokenize that I could find was in this way, so this _should_ be safe. 
- """ + def _tokenize(self, text: str) -> List[str]: + """Returns a tokenized string.""" if not isinstance(text, str): raise ValueError( f'Expected a string input to _tokenize but got {type(text)}.') - tokens = [t for t in self.encoding.encode(text, allowed_special='all')] + tokens = [ + self.decoder[t] + for t in self.encoding.encode(text, allowed_special='all') + ] return tokens - def _convert_token_to_id(self, token: Union[int, str]) -> int: - """Converts a token (str) into an id using the vocab.""" - if isinstance(token, int): - return token + def _convert_token_to_id(self, token: str) -> Optional[int]: + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) - return self.encoding.encode(token, allowed_special='all')[0] - - def _convert_id_to_token(self, index: int) -> str: - """Converts an index (integer) into a token (str) using the vocab.""" - return self.encoding.decode([index]) + def _convert_id_to_token(self, index: int) -> Optional[str]: + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) def convert_tokens_to_string(self, tokens: List[str]) -> str: """Converts a sequence of tokens (string) in a single string.""" - return ''.join(tokens) - - def convert_ids_to_tokens( - self, - ids: Union[int, List[int]], - skip_special_tokens: bool = False) -> Union[str, List[str]]: - """Converts a single index or a sequence of indices into a token or a. - - sequence of tokens, using the vocabulary and added tokens. - - Args: - ids (`int` or `List[int]`): - The token id (or token ids) to convert to tokens. - skip_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not to remove special tokens in the decoding. - - Returns: - `str` or `List[str]`: The decoded token(s). - """ - if isinstance(ids, int): - if ids in self.added_tokens_decoder: - return str(self.added_tokens_decoder[ids]) - - return self._convert_id_to_token(ids) - - # current_stream will collect multiple tokens, and then separately add items - # for each added token. This is done so that decode works properly with token ids - # that cannot be represented naively in utf-8. 
- tokens = [] - current_stream = [] - for index in ids: - if skip_special_tokens and index in self.all_special_ids: - continue - - if index in self.added_tokens_decoder: - tokens.append(self.encoding.decode(current_stream)) - current_stream = [] - tokens.append(str(self.added_tokens_decoder[index])) - else: - current_stream.append(index) - - if len(current_stream) > 0: - tokens.append(self.encoding.decode(current_stream)) - return tokens + text = ''.join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8') + return text def build_inputs_with_special_tokens( self, diff --git a/tests/test_tiktoken.py b/tests/test_tiktoken.py index 5bd10c82d3..fe3db41d50 100644 --- a/tests/test_tiktoken.py +++ b/tests/test_tiktoken.py @@ -7,7 +7,8 @@ import pytest import transformers -from llmfoundry import TiktokenTokenizerWrapper +from llmfoundry.tokenizers.tiktoken import (TiktokenTokenizerWrapper, + bytes_to_unicode) from tests.horrible_strings import HORRIBLE_STRINGS from tests.test_hf_conversion_script import check_hf_tokenizer_equivalence @@ -29,18 +30,12 @@ TEST_STRINGS += HORRIBLE_STRINGS -MODEL_OR_ENCODING_NAME_TO_NON_UTF8_TOKENS = { - 'gpt-4': 77, - 'gpt-3.5-turbo': 77, - 'text-davinci-003': 14, - 'cl100k_base': 77, -} - MODEL_ENCODING_NAME_PARAMETRIZATION = [ ('gpt-4', None), ('gpt-3.5-turbo', None), ('text-davinci-003', None), (None, 'cl100k_base'), + ('gpt2', None), ] MULTI_TURN_CHAT_ML = [[{ @@ -120,6 +115,31 @@ def test_tiktoken_simple(model_name: Optional[str], assert reloaded_wrapped_output == wrapped_output +@pytest.mark.parametrize('model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION) +def test_tiktoken_tokenize_with_ids(model_name: Optional[str], + encoding_name: Optional[str], + tmp_path: pathlib.Path): + wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer = get_tokenizers_for_testing( + model_name, encoding_name, tmp_path) + + for string in TEST_STRINGS: + wrapped_output = wrapped_tokenizer.tokenize(string) + original_output = original_tokenizer.encode(string, + allowed_special='all') + reloaded_wrapped_output = reloaded_wrapped_tokenizer.tokenize(string) + + assert all([isinstance(t, str) for t in wrapped_output]) + assert len(wrapped_output) == len(original_output) + assert wrapped_output == reloaded_wrapped_output + + redone_token_ids = wrapped_tokenizer.convert_tokens_to_ids( + wrapped_output) + assert redone_token_ids == original_output + assert wrapped_tokenizer.convert_ids_to_tokens( + redone_token_ids) == wrapped_output + + @pytest.mark.parametrize('model_name,encoding_name', MODEL_ENCODING_NAME_PARAMETRIZATION) def test_tiktoken_roundtrip(model_name: Optional[str], @@ -201,31 +221,17 @@ def test_tiktoken_vocab(model_name: Optional[str], encoding_name: Optional[str], reloaded_wrapped_vocab = reloaded_wrapped_tokenizer.get_vocab() assert wrapped_vocab == reloaded_wrapped_vocab - didnt_match = [] for key, value in wrapped_vocab.items(): # Skip checking the extra ids we pad the vocab with if key.startswith(''): continue - if original_tokenizer.encode(key, allowed_special='all') == [value]: - continue - else: - didnt_match.append( - (key, original_tokenizer.encode(key, - allowed_special='all'), value)) - - # Decode is lossy because some bytes are not representable in utf-8 - # see https://github.com/openai/tiktoken/blob/39f29cecdb6fc38d9a3434e5dd15e4de58cf3c80/tiktoken/core.py#L245-L247 - # This means that the str: int vocab mapping doesn't work. Would have to look more into how other HF tokenizers handle this. 
- model_or_encoding_name = model_name or encoding_name - if model_or_encoding_name is not None: - expected_didnt_match = MODEL_OR_ENCODING_NAME_TO_NON_UTF8_TOKENS.get( - model_or_encoding_name) - assert len(didnt_match) == expected_didnt_match - else: - raise NotImplementedError( - 'Add the new tokenizer and how many tokens in the vocab are not utf8 representable.' - ) + expected_decoding = ''.join([ + bytes_to_unicode()[ord(char)] + for char in original_tokenizer.decode_single_token_bytes( + value).decode('latin-1') + ]) + assert expected_decoding == key @pytest.mark.parametrize('model_name,encoding_name', From 5f21855cb35987ec73fe8b3a515f6ae3db903d56 Mon Sep 17 00:00:00 2001 From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> Date: Tue, 28 Nov 2023 16:18:18 -0800 Subject: [PATCH 4/4] enable param group configuration in llm-foundry (#760) * enable param group configuration in llm-foundry * add doc string * add debug logs * add test, fix bug * spell check; mark test gpu * updt to use RegEx search * Apply suggestions from code review Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * updt with dakinggg pr comments --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/optim/lion8b.py | 22 ++++--- llmfoundry/utils/builders.py | 118 +++++++++++++++++++++++++++++++++-- tests/test_builders.py | 89 +++++++++++++++++++++++++- 3 files changed, 211 insertions(+), 18 deletions(-) diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index 2c2e6e2d35..9d1d1dda71 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -1,7 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, Iterable, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import torch @@ -58,15 +58,17 @@ class DecoupledLionW_8bit(torch.optim.Optimizer): device, or b) step() is executed on a non-CUDA parameter. """ - def __init__(self, - params: Iterable[torch.Tensor], - lr: float = 1e-3, - betas: Tuple[float, float] = (0.9, 0.99), - weight_decay: float = 0, - quantize: bool = True, - compress_state_dict: bool = False, - error_correction: bool = False, - _fused: bool = True): # XXX this flag is mostly for testing... + def __init__( + self, + params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0, + quantize: bool = True, + compress_state_dict: bool = False, + error_correction: bool = False, + _fused: bool = True, # XXX this flag is mostly for testing... 
+ ): if lr < 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index c31917efc6..14196c3ef9 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -1,10 +1,13 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import functools import logging import os +import re import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from collections import OrderedDict +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from composer import algorithms @@ -155,18 +158,121 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]) -> Algorithm: raise ValueError(f'Not sure how to build algorithm: {name}') +def _extract_param_groups( + model: torch.nn.Module, + optimizer_config: Dict[str, Any], +) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: + """Extracts parameter groups defined in the optimizer config. + + The optimizer_config defines the optimizer args. It can additionally have key + `disable_grad` which is a string or list of strings. If a string matches a + parameter name, then that parameter will have `requires_grad=False`. This is + useful for freezing parameters. It can additionally have a key + `param_groups` which is a list of dicts. In this dict, key `param_str_match` + defines a string; if a parameter name contains this string, then it will be + in this parameter group. This is useful for grouping parameters together. + The dict can also contain any other key that is a valid optimizer arg. + Note: to handle name overlap conflicts, params are assigned to parameter + groups and added to `param_groups` in the order that `param_str_match` appear + in `param_groups`. + + Usage + To disable gradient for all parameters that contain the string "norm" or "bias": + ``` + optimizer_config: { + "name": "decoupled_lionw", + "lr": 1e-3, + "weight_decay": 1e-2, + "betas": [0.9, 0.999], + "eps": 1e-8, + "disable_grad": ["norm", "bias"] + } + ``` + + To create and modify the optimizer parameters for all parameters that contain + the string "norm" and "bias" separately: + ``` + optimizer_config: { + "name": "decoupled_lionw", + "lr": 1e-3, + "weight_decay": 1e-2, + "betas": [0.9, 0.999], + "eps": 1e-8, + "param_groups": [ + { + "param_str_match": "norm", + "lr": 1e-4, + "weight_decay": 0.0, + }, + { + "param_str_match": "bias", + "lr": 5e-4, + "weight_decay": 0.0, + }, + ], + } + ``` + + Args: + model (torch.nn.Module): model to extract parameters from + optimizer_config (Dict[str, Any]): optimizer config + + Returns: + Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of + torch.Tensor's or dict's. Specifies what Tensors should be optimized + and their param groupings. 
+ """ + if 'disable_grad' in optimizer_config.keys(): + str_matches = optimizer_config.pop('disable_grad') + if isinstance(str_matches, str): + str_matches = [str_matches] + for str_match in str_matches: + for n, p in model.named_parameters(): + if re.search(str_match, n): + p.requires_grad = False + log.debug(f'Setting `{n}.requires_grad = False`.') + + param_groups_config = optimizer_config.pop('param_groups', None) + if param_groups_config is not None: + params = [] + param_dict = OrderedDict((n, p) for n, p in model.named_parameters()) + + log.debug(f'Default optimizer settings: {optimizer_config}.') + for param_group_config in param_groups_config: + str_match = param_group_config.pop('param_str_match') + filter_fn = functools.partial(re.search, str_match) + param_names = [n for n in param_dict.keys() if filter_fn(n)] + group_params = {'params': [param_dict.pop(n) for n in param_names]} + group_params.update(param_group_config) + + log.debug( + f'Creating optimizer param_group with parameters: {param_names} ' +\ + f'(extracted using {str_match=}). The param_group optimizer ' +\ + f'setting overrides are: {param_group_config}.') + + params.append(group_params) + + params.insert(0, {'params': param_dict.values()}) + return params + + return model.parameters() + + def build_optimizer(model: torch.nn.Module, name: str, optimizer_config: Dict[str, Any]) -> Optimizer: + + params = _extract_param_groups(model, optimizer_config) + if name == 'decoupled_adamw': - return DecoupledAdamW(model.parameters(), **optimizer_config) + return DecoupledAdamW(params, **optimizer_config) elif name == 'decoupled_lionw': - return DecoupledLionW(model.parameters(), **optimizer_config) + return DecoupledLionW(params, **optimizer_config) elif name == 'clip_lion': - return DecoupledClipLion(model.parameters(), **optimizer_config) + return DecoupledClipLion(params, **optimizer_config) elif name == 'adalr_lion': - return DecoupledAdaLRLion(model.parameters(), **optimizer_config) + return DecoupledAdaLRLion(params, **optimizer_config) elif name == 'decoupled_lionw_8b': - return DecoupledLionW_8bit(model.parameters(), **optimizer_config) + return DecoupledLionW_8bit(params, **optimizer_config) else: raise ValueError(f'Not sure how to build optimizer: {name}') diff --git a/tests/test_builders.py b/tests/test_builders.py index 237e27b52b..7ac179720e 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -1,17 +1,22 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import re import unittest.mock as mock -from typing import Union +from copy import deepcopy +from typing import Any, Dict, Union import pytest +import torch +import torch.nn as nn from composer.callbacks import Generate from omegaconf import OmegaConf as om from transformers import PreTrainedTokenizerBase from llmfoundry.callbacks import HuggingFaceCheckpointer from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper -from llmfoundry.utils.builders import build_callback, build_tokenizer +from llmfoundry.utils.builders import (build_callback, build_optimizer, + build_tokenizer) @pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [ @@ -110,3 +115,83 @@ def test_build_hf_checkpointer_callback(): assert isinstance(kwargs['mlflow_logging_config'], dict) assert isinstance(kwargs['mlflow_logging_config']['metadata'], dict) assert kwargs['mlflow_logging_config'] == mlflow_logging_config_dict + + +class _DummyModule(nn.Module): + + def __init__(self, device: str = 'cpu', dtype: torch.dtype = 
torch.float32): + super().__init__() + self.linear0 = nn.Linear(4, 3, device=device, dtype=dtype) + self.norm0 = nn.LayerNorm(3, device=device, dtype=dtype) + self.linear1 = nn.Linear(3, 5, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore + return self.linear1(self.norm0(self.linear0(x))) + + +@pytest.mark.parametrize('name, optimizer_config', [ + ('decoupled_adamw', {}), + ('decoupled_lionw', {}), + ('clip_lion', {}), + ('adalr_lion', {}), + pytest.param('decoupled_lionw_8b', {}, marks=pytest.mark.gpu), +]) +@pytest.mark.parametrize('opt_additional_config', [ + { + 'disable_grad': 'norm' + }, + { + 'disable_grad': ['norm', 'bias'] + }, + { + 'param_groups': [{ + 'param_str_match': 'norm', + 'lr': 1e-9, + 'weight_decay': 0.0, + },] + }, + { + 'param_groups': [{ + 'param_str_match': 'no.*.bias', + 'lr': 1e-9, + 'weight_decay': 0.0, + },] + }, + { + 'param_groups': [{ + 'param_str_match': 'norm', + 'lr': 1e-4, + 'weight_decay': 0.0, + },], + 'disable_grad': ['bias'], + }, +]) +def test_build_optimizer(name: str, optimizer_config: Dict[str, Any], + opt_additional_config: Dict[str, Any]): + model = _DummyModule() + optimizer_config.update(deepcopy(opt_additional_config)) + optimizer = build_optimizer(model, name, optimizer_config) + + if 'disable_grad' in opt_additional_config.keys(): + disable_grad = opt_additional_config['disable_grad'] + if isinstance(disable_grad, str): + disable_grad = [disable_grad] + for n, p in model.named_parameters(): + for k in disable_grad: + if re.search(k, n): + assert not p.requires_grad + + if 'param_groups' in opt_additional_config.keys(): + for param_group_config, param_group in zip( + opt_additional_config['param_groups'], + optimizer.param_groups[1:]): + param_group_config = deepcopy(param_group_config) + param_str_match = param_group_config.pop('param_str_match') + + for k, v in param_group_config.items(): + assert param_group[k] == v + + param_ids = [id(p) for p in param_group['params']] + for n, p in model.named_parameters(): + if re.search(param_str_match, n): + assert id(p) in param_ids
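
A minimal usage sketch for the new optimizer configuration handled by `_extract_param_groups` in PATCH 4/4 (the `disable_grad` and `param_groups` keys). This is not part of the patches above; the model, regexes, and hyperparameter values are illustrative assumptions only.

```python
import torch.nn as nn

from llmfoundry.utils.builders import build_optimizer


class TinyModel(nn.Module):
    """Hypothetical module whose parameter names the regexes below can match."""

    def __init__(self):
        super().__init__()
        self.linear0 = nn.Linear(4, 3)
        self.norm0 = nn.LayerNorm(3)
        self.linear1 = nn.Linear(3, 5)

    def forward(self, x):
        return self.linear1(self.norm0(self.linear0(x)))


model = TinyModel()

optimizer_config = {
    'lr': 1e-3,
    'betas': [0.9, 0.95],
    'eps': 1e-8,
    'weight_decay': 1e-2,
    # Freeze every parameter whose name matches this regex.
    'disable_grad': 'bias',
    # Put LayerNorm parameters in their own group with overridden settings.
    'param_groups': [{
        'param_str_match': 'norm',
        'lr': 1e-4,
        'weight_decay': 0.0,
    }],
}

optimizer = build_optimizer(model, 'decoupled_adamw', optimizer_config)

# Group 0 holds the unmatched parameters; group 1 holds the 'norm' matches.
print([len(g['params']) for g in optimizer.param_groups])  # [4, 2]

# Parameters matching 'disable_grad' were frozen in place.
assert all(not p.requires_grad
           for n, p in model.named_parameters() if 'bias' in n)
```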