
Deepspeed inference #255

Merged · 5 commits · Sep 16, 2023
127 changes: 97 additions & 30 deletions open_flamingo/eval/eval_model.py
@@ -1,14 +1,15 @@
import abc
import argparse
from typing import List
from torch.nn.parallel import DistributedDataParallel as DDP
from PIL import Image

from utils import get_autocast, get_cast_dtype
import torch
from contextlib import suppress

class BaseEvalModel(abc.ABC):
"""Base class encapsulating functionality needed to evaluate a model."""

def __init__(self, args: List[str]):
def __init__(self, model_args: List[str]):
"""Initialize model.

Args:
@@ -17,23 +18,107 @@ def __init__(self, args: List[str]):
is non-empty.
"""

def init_distributed(self):
"""Wrap model as DDP."""
self.model = DDP(self.model, device_ids=[self.device])
def __init__(self, model_args, init_on_device=False):
assert "lm_path" in model_args, "All models require the lm_path argument"
self.device = (
model_args["device"]
if ("device" in model_args and (type(model_args["device"]) != int or model_args["device"] >= 0))
else "cpu"
)
self.precision = model_args.get("precision", "fp32")
self.lm_name = model_args["lm_path"].split("/")[-1]
self.autocast = get_autocast(self.precision)
self.cast_dtype = get_cast_dtype(self.precision)
if init_on_device:
# for deepspeed, must init on device, or likely CPU OOM
import deepspeed
self.init_ctx = deepspeed.OnDevice(dtype=self.cast_dtype, device=self.device)
else:
self.init_ctx = suppress()

def _check_init(self):
"""Finish model initialization."""
assert hasattr(self, "model"), "Model has not been initialized"
self.model.eval()
self.model.to(self.device, dtype=self.cast_dtype)
assert hasattr(self, "tokenizer"), "Tokenizer has not been initialized"
self.tokenizer.padding_side = "left"

def init_distributed(self, world_size=None, use_deepspeed=False):
"""Wrap model as DDP or deepspeed."""
if use_deepspeed:
assert "amp" not in self.precision, "Deepspeed does not support amp"
import deepspeed
self.ds_engine = deepspeed.init_inference(
self.model,
mp_size=world_size,
dtype=self.cast_dtype,
checkpoint=None,
replace_with_kernel_inject=True,
)
self.model = self.ds_engine.module
self.autocast = get_autocast(None)
else:
self.model = DDP(self.model, device_ids=[self.device])

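For orientation, a minimal sketch of how a concrete subclass might be driven through these new hooks; `MyEvalModel`, the checkpoint path, and the world size are placeholders for illustration, not part of this PR:

```python
# Sketch only: MyEvalModel stands in for a concrete BaseEvalModel subclass.
model_args = {
    "lm_path": "path/to/language-model",  # placeholder checkpoint
    "device": 0,                          # device index; an integer < 0 falls back to "cpu"
    "precision": "fp16",
}

# init_on_device=True wraps construction in deepspeed.OnDevice so weights are
# materialized directly on the target device instead of host RAM.
eval_model = MyEvalModel(model_args, init_on_device=True)

# Wrap for inference: DeepSpeed kernel injection here, or plain DDP when
# use_deepspeed=False.
eval_model.init_distributed(world_size=1, use_deepspeed=True)
```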
def set_device(self, device):
"""Set device for model."""
self.device = device
self.model = self.model.to(device)
torch.cuda.set_device(device)
self.device = torch.device("cuda", device)
self.model = self.model.to(device, dtype=self.cast_dtype)

def __call__(
self,
lang_x: torch.Tensor,
vision_x: torch.Tensor,
attention_mask: torch.Tensor,
past_key_values: torch.Tensor = None,
use_cache: bool = False,
):
"""
Calls the forward function of the model.
Special logic to handle the case if past_key_values is not None:
then lang_x is assumed to contain the tokens to be generated
*excluding* the tokens already in past_key_values.
We then repeatedly call forward, updating the past_key_values.
"""

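The docstring above describes incremental decoding with a key-value cache; a hedged sketch of the loop an implementing subclass might run (the function and argument handling are illustrative, not the PR's own code):

```python
def incremental_forward(model, lang_x, vision_x, attention_mask, past_key_values):
    """Sketch: lang_x holds only the tokens *not* yet in past_key_values.
    Step one token at a time, extending the cache as we go."""
    past_len = attention_mask.shape[1] - lang_x.shape[1]
    outputs = None
    for t in range(lang_x.shape[1]):
        outputs = model(
            vision_x=vision_x,
            lang_x=lang_x[:, t : t + 1],
            attention_mask=attention_mask[:, : past_len + t + 1],
            past_key_values=past_key_values,
            use_cache=True,
        )
        # Carry the extended cache into the next step so earlier tokens
        # are never recomputed.
        past_key_values = outputs.past_key_values
    return outputs, past_key_values
```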
def prepare_text(
self,
batch: List[List[str]],
padding="longest",
truncation=True,
max_length=2000,
add_special_tokens=True,
):
"""
Prepare text for model.

Args:
batch: list of text strings
padding: whether to pad the text
truncation: whether to truncate the text
max_length: maximum length of the text

Returns:
input_ids: tensor of shape (B, T)
attention_mask: tensor of shape (B, T)
"""

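A typical implementation would delegate to the model's Hugging Face tokenizer; a minimal sketch under that assumption (not the PR's own code):

```python
def prepare_text(self, batch, padding="longest", truncation=True,
                 max_length=2000, add_special_tokens=True):
    # Sketch: assumes self.tokenizer is a Hugging Face tokenizer whose
    # padding_side was already set to "left" in _check_init above.
    encodings = self.tokenizer(
        batch,
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        add_special_tokens=add_special_tokens,
        return_tensors="pt",
    )
    input_ids = encodings["input_ids"].to(self.device, non_blocking=True)
    attention_mask = encodings["attention_mask"].to(self.device, non_blocking=True)
    return input_ids, attention_mask
```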
def prepare_images(self, batch: List[List[Image.Image]]):
"""
Prepare images for model.
Args:
batch: list of lists of PIL images
Returns:
tensor of shape (B, T, *, C, H, W)
"""

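Similarly, a hedged sketch of prepare_images, assuming a `self.image_processor` that maps one PIL image to a (C, H, W) tensor (that attribute is an assumption about the subclass, not something this PR guarantees):

```python
def prepare_images(self, batch):
    # Sketch: B examples, each with up to T images; unused slots stay zero,
    # giving the (B, T, 1, C, H, W) layout described above.
    max_images = max(len(example) for example in batch)
    batch_tensor = None
    for b, example in enumerate(batch):
        for t, image in enumerate(example):
            preprocessed = self.image_processor(image)  # (C, H, W), assumed
            if batch_tensor is None:
                batch_tensor = torch.zeros(
                    (len(batch), max_images, 1) + preprocessed.shape,
                    dtype=preprocessed.dtype,
                )
            batch_tensor[b, t, 0] = preprocessed
    return batch_tensor.to(self.device, dtype=self.cast_dtype, non_blocking=True)
```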
def get_outputs(
self,
batch_text: List[str],
batch_images: List[List[Image.Image]],
min_generation_length: int,
max_generation_length: int,
num_beams: int,
length_penalty: float,
**decode_kwargs,
) -> List[str]:
"""Get outputs for a batch of images and text.

@@ -42,29 +127,11 @@ def get_outputs(
of any images to be included.
batch_images: images to provide to model. Should be a list of lists,
where each list contains the images for a single example.
max_generation_length: maximum length of the generated caption.
Defaults to 10.
num_beams: number of beams to use for beam search. Defaults to 3.
length_penalty: length penalty for beam search. Defaults to -2.0.

Returns:
List of decoded output strings.
"""

def vqa_prompt(self, question, answer=None) -> str:
"""Get the prompt to use for VQA evaluation. If the answer is not provided, it should be left blank to be generated by the model.

Returns:
The prompt to use for VQA.
"""

def caption_prompt(self, caption=None) -> str:
"""Get the prompt to use for caption evaluation. If the caption is not provided, it should be left blank to be generated by the model.

Returns:
The prompt to use for captioning.
"""

def get_rank_classifications(
self,
batch_text: List[str],
50 changes: 31 additions & 19 deletions open_flamingo/eval/evaluate.py
@@ -382,28 +382,40 @@
action="store_true",
help="Use horovod for distributed training.",
)
parser.add_argument(
"--local_rank",
default=0,
type=int,
help="Rank of distributed process (default: 0). Usually overwritten by world_info_from_env()",
)
parser.add_argument(
"--no-set-device-rank",
default=False,
action="store_true",
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
)

parser.add_argument(
"--deepspeed",
default=False,
action="store_true",
help="Whether to use deepspeed for distributed inference.",
)

def main():
args, leftovers = parser.parse_known_args()
module = importlib.import_module(f"open_flamingo.eval.models.{args.model}")

model_args = {
leftovers[i].lstrip("-"): leftovers[i + 1] for i in range(0, len(leftovers), 2)
}
eval_model = module.EvalModel(model_args)

# set up distributed evaluation
args.local_rank, args.rank, args.world_size = world_info_from_env()
device_id = init_distributed_device(args)
eval_model.set_device(device_id)
eval_model.init_distributed()
model_args = {
leftovers[i].lstrip("-"): leftovers[i + 1] for i in range(0, len(leftovers), 2)
}
model_args['device'] = device_id
eval_model = module.EvalModel(model_args, init_on_device=args.deepspeed)
eval_model.init_distributed(
local_rank=args.local_rank, world_size=args.world_size, use_deepspeed=args.deepspeed
)

if args.model != "open_flamingo" and args.shots != [0]:
raise ValueError("Only 0 shot eval is supported for non-open_flamingo models")
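To make the flow above concrete: with a hypothetical launch such as `python evaluate.py --model open_flamingo --deepspeed --lm_path <path> --precision fp16`, the unrecognized `--key value` pairs from `parse_known_args` become `model_args` as sketched below (the paths and values are placeholders):

```python
# Illustrative only: how parse_known_args leftovers become model kwargs.
leftovers = ["--lm_path", "path/to/lm", "--precision", "fp16"]
model_args = {
    leftovers[i].lstrip("-"): leftovers[i + 1] for i in range(0, len(leftovers), 2)
}
# -> {"lm_path": "path/to/lm", "precision": "fp16"}
# main() then adds the resolved device and forwards the --deepspeed flag:
#   model_args["device"] = device_id
#   eval_model = module.EvalModel(model_args, init_on_device=args.deepspeed)
```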
@@ -618,7 +630,7 @@ def main():
num_shots=shot,
seed=seed,
dataset_name="textvqa",
max_generation_length=10,
max_new_tokens=10,
cached_features=cached_features,
)
if args.rank == 0:
@@ -729,8 +741,8 @@ def evaluate_captioning(
args: argparse.Namespace,
eval_model: BaseEvalModel,
seed: int = 42,
min_generation_length: int = 0,
max_generation_length: int = 20,
min_new_tokens: int = 0,
max_new_tokens: int = 20,
num_beams: int = 3,
length_penalty: float = 0.0,
num_shots: int = 8,
@@ -743,7 +755,7 @@ def evaluate_captioning(
args (argparse.Namespace): arguments
eval_model (BaseEvalModel): model to evaluate
seed (int, optional): seed for random number generator. Defaults to 42.
max_generation_length (int, optional): maximum length of the generated caption. Defaults to 20.
max_new_tokens (int, optional): maximum length of the generated caption. Defaults to 20.
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
num_shots (int, optional): number of in-context samples to use. Defaults to 8.
@@ -843,8 +855,8 @@ def evaluate_captioning(
outputs = eval_model.get_outputs(
batch_images=batch_images,
batch_text=batch_text,
min_generation_length=min_generation_length,
max_generation_length=max_generation_length,
min_new_tokens=min_new_tokens,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
length_penalty=length_penalty,
)
@@ -900,8 +912,8 @@ def evaluate_vqa(
args: argparse.Namespace,
eval_model: BaseEvalModel,
seed: int = 42,
min_generation_length: int = 0,
max_generation_length: int = 5,
min_new_tokens: int = 0,
max_new_tokens: int = 5,
num_beams: int = 3,
length_penalty: float = 0.0,
num_shots: int = 8,
@@ -915,7 +927,7 @@ def evaluate_vqa(
args (argparse.Namespace): arguments
eval_model (BaseEvalModel): model to evaluate
seed (int, optional): random seed. Defaults to 42.
max_generation_length (int, optional): max generation length. Defaults to 5.
max_new_tokens (int, optional): max generation length. Defaults to 5.
num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
num_shots (int, optional): number of shots to use. Defaults to 8.
@@ -1036,8 +1048,8 @@ def evaluate_vqa(
outputs = eval_model.get_outputs(
batch_images=batch_images,
batch_text=batch_text,
min_generation_length=min_generation_length,
max_generation_length=max_generation_length,
min_new_tokens=min_new_tokens,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
length_penalty=length_penalty,
)