Pr eval lora #515

Merged (11 commits) on Aug 15, 2023
2 changes: 1 addition & 1 deletion TUTORIAL.md
@@ -342,7 +342,7 @@ lora:
```
- In the current release, these features have Beta support.
- For efficiency, the MPT model concatenates the `Q`, `K`, and `V` matrices in each attention block into a single `Wqkv` matrix that is three times wider. Currently, LoRA supports a low-rank approximation to this `Wqkv` matrix (see the sketch after this list).
- Known issue: PEFT / LoRA do not directly work with FSDP.
- When evaluating with separately saved PEFT / LoRA weights, set `pretrained_lora_id_or_path` in the `model` config (find an example [here](scripts/eval/yamls/hf_lora_eval.yml#L19)).
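
For illustration, here is a minimal, hypothetical `peft` sketch of targeting the fused `Wqkv` matrix with LoRA; the `r`/`lora_alpha`/`lora_dropout` values are placeholders, not recommendations:

```python
# A minimal sketch (not from this repo) of applying LoRA to MPT's fused
# Wqkv projection via peft; the hyperparameter values are illustrative only.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained('mosaicml/mpt-7b',
                                            trust_remote_code=True)
lora_config = LoraConfig(r=16,
                         lora_alpha=32,
                         target_modules=['Wqkv'],  # the fused Q/K/V projection
                         lora_dropout=0.05)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # only the low-rank adapters train
```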

### Can I quantize these models and/or run on CPU?
- The LLM Foundry codebase does not directly include examples of quantization or limited-resource inference, but you can check out [GGML](https://github.com/ggerganov/ggml) (the same library that powers llama.cpp), which has built support for efficiently running MPT models on CPU! You _can_ load your model in 8-bit precision for inference using the [bitsandbytes library](https://github.com/TimDettmers/bitsandbytes) and Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index) via `model = AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit=True, device_map="auto", trust_remote_code=True)`, although we have not extensively benchmarked the performance (see the Hugging Face [quantization documentation](https://huggingface.co/docs/transformers/main/main_classes/quantization) for more detail). A short sketch follows below.
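
A minimal sketch of the 8-bit loading path described above (bitsandbytes and accelerate must be installed; the model name is a placeholder):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'mosaicml/mpt-7b',       # any causal LM checkpoint; placeholder here
    load_in_8bit=True,       # quantize weights to int8 via bitsandbytes
    device_map='auto',       # let accelerate place layers across devices
    trust_remote_code=True,  # required for MPT's custom modeling code
)
```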
55 changes: 52 additions & 3 deletions scripts/eval/eval.py
@@ -15,15 +15,60 @@
from composer.utils import dist, get_device, reproducibility
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from transformers import PreTrainedTokenizerBase
from transformers import (AutoModelForCausalLM, PreTrainedTokenizerBase,
                          T5ForConditionalGeneration)

from llmfoundry.callbacks import ModelGauntlet
from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
from llmfoundry.models.mpt import MPTForCausalLM
from llmfoundry.utils.builders import (build_icl_evaluators, build_logger,
                                       build_tokenizer)
from llmfoundry.utils.config_utils import process_init_device


def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
                    num_retries: int) -> Optional[ComposerModel]:
    """Load a pretrained base model, attach separately saved PEFT/LoRA
    weights, and wrap the result in the matching Composer model class."""
    try:
        from peft import PeftModel
    except ImportError as e:
        raise ImportError(
            f'Error importing from peft. Run `pip install -e .[gpu,peft]`. \n {e}'
        )

    # Map LLM Foundry model names to the underlying pretrained model classes.
    model_registry = {
        'mpt_causal_lm': MPTForCausalLM,
        'hf_causal_lm': AutoModelForCausalLM,
        'hf_prefix_lm': AutoModelForCausalLM,
        'hf_t5': T5ForConditionalGeneration,
    }

    retries = 0
    while retries < num_retries:
        try:
            trust_remote_code = model_cfg.get('trust_remote_code', True)
            use_auth_token = model_cfg.get('use_auth_token', False)
            model = model_registry[model_cfg.name].from_pretrained(
                model_cfg.pretrained_model_name_or_path,
                trust_remote_code=trust_remote_code,
                use_auth_token=use_auth_token,
            )

            # Attach the LoRA adapter weights to the base model.
            peft_model = PeftModel.from_pretrained(
                model, model_cfg.pretrained_lora_id_or_path)

            composer_model = COMPOSER_MODEL_REGISTRY[model_cfg.name](
                peft_model, tokenizer)
            return composer_model
        except Exception as e:
            retries += 1
            if retries >= num_retries:
                raise e
            else:
                print(
                    f'Got exception {str(e)} while loading model {model_cfg.name}. '
                    f'{num_retries - retries} retries remaining')


def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
               fsdp_config: Optional[Dict],
               num_retries: int) -> Optional[ComposerModel]:
@@ -76,8 +121,12 @@ def evaluate_model(model_cfg: DictConfig, cfg: DictConfig, run_name: str,
        fsdp_config, resolve=True) if fsdp_config is not None else None
    assert isinstance(fsdp_config, Dict) or fsdp_config is None

    composer_model = load_model(model_cfg.model, tokenizer, fsdp_config,
                                cfg.get('num_retries', 3))
    if hasattr(model_cfg.model, 'pretrained_lora_id_or_path'):
        composer_model = load_peft_model(model_cfg.model, tokenizer,
                                         cfg.get('num_retries', 3))
    else:
        composer_model = load_model(model_cfg.model, tokenizer, fsdp_config,
                                    cfg.get('num_retries', 3))

    if model_gauntlet_df is None and model_gauntlet is not None:
        model_gauntlet_df = pd.DataFrame(
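
As a rough standalone sketch of what the new `load_peft_model` path does for an `hf_causal_lm` entry (model and adapter names taken from the example YAML in this PR; the retry loop and Composer wrapping are elided):

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-125m')
model = PeftModel.from_pretrained(base, 'nathan0/lora-gpt-neo-125m-alpaca')
# eval.py then wraps this in COMPOSER_MODEL_REGISTRY['hf_causal_lm'](model, tokenizer)
```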
48 changes: 48 additions & 0 deletions scripts/eval/yamls/hf_lora_eval.yml
@@ -0,0 +1,48 @@
max_seq_len: 2048
seed: 1
precision: amp_fp16

# If you are using one model, put it here:
model_name_or_path: EleutherAI/gpt-neo-125m
# If you are using a separately saved LoRA weight, put it here:
lora_id_or_path: nathan0/lora-gpt-neo-125m-alpaca
# Otherwise, write a block for each model you want to test in the `models` section

models:
-
  model_name: ${model_name_or_path}
  model:
    name: hf_causal_lm
    pretrained_model_name_or_path: ${model_name_or_path}
    init_device: cpu
    pretrained: true
    pretrained_lora_id_or_path: ${lora_id_or_path}
  tokenizer:
    name: ${model_name_or_path}
    kwargs:
      model_max_length: ${max_seq_len}
# If you are evaluating more than one model, list them all as YAML blocks without variable interpolation
# -
#   model_name: mosaicml/mpt-7b
#   model:
#     name: hf_causal_lm
#     pretrained_model_name_or_path: mosaicml/mpt-7b
#     init_device: cpu
#     pretrained: true
#     config_overrides:
#       max_seq_len: ${max_seq_len}
#   tokenizer:
#     name: mosaicml/mpt-7b
#     kwargs:
#       model_max_length: ${max_seq_len}


device_eval_batch_size: 4

# FSDP config for model sharding
fsdp_config:
  sharding_strategy: FULL_SHARD
  mixed_precision: FULL

icl_tasks: 'eval/yamls/tasks_light.yaml'
model_gauntlet: 'eval/yamls/model_gauntlet.yaml'
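
As a sanity check, the `${...}` interpolations above resolve with plain `omegaconf`; a minimal sketch, assuming the path is relative to the repo root:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load('scripts/eval/yamls/hf_lora_eval.yml')
model_block = OmegaConf.to_container(cfg.models[0], resolve=True)
print(model_block['model']['pretrained_lora_id_or_path'])
# -> nathan0/lora-gpt-neo-125m-alpaca
```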
2 changes: 1 addition & 1 deletion setup.py
@@ -93,7 +93,7 @@
    'scipy>=1.10.0,<=1.11.0',  # bitsandbytes dependency; TODO: eliminate when incorporated into bitsandbytes
    # TODO: pin peft when it stabilizes.
    # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
    'peft@git+https://github.com/huggingface/peft.git',
    'peft==0.4.0',
]

extra_deps['all'] = set(dep for deps in extra_deps.values() for dep in deps)