Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modernize MosaicBERT #440

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/benchmarks/bert/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ def main(cfg: DictConfig,
load_path=cfg.get('load_path', None),
load_weights_only=cfg.get('load_weights_only', False),
python_log_level=cfg.get('python_log_level', None),
)
autoresume=cfg.get('autoresume', None),
fsdp_config=cfg.get('fsdp_config', None),
compile_config=cfg.get('compile_config', None))

print('Logging config...')
log_config(cfg)
Expand Down
10 changes: 5 additions & 5 deletions examples/benchmarks/bert/requirements-cpu.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
einops==0.5.0
torch==1.13.1
mosaicml[nlp,wandb]>=0.14.0,<0.15
mosaicml-streaming==0.4.1
omegaconf==2.2.3
transformers==4.28.1
torch==2.1.1
composer[nlp,wandb]>=0.17.0,<0.18
mosaicml-streaming<=0.7
omegaconf==2.3.0
transformers==4.35.2
14 changes: 8 additions & 6 deletions examples/benchmarks/bert/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
einops==0.5.0
torch==1.13.1
mosaicml[nlp,wandb]>=0.14.0,<0.15
mosaicml-streaming==0.4.1
omegaconf==2.2.3
transformers==4.28.1
torch==2.1.1
composer[nlp,wandb]>=0.17.0,<0.18
mosaicml-streaming<=0.7
omegaconf==2.3.0
transformers==4.35.2
# need a newer version of FA2
flash_attn>=2.4.2
# need a newer version of triton
triton==2.0.0.dev20221103
#triton==2.0.0.dev20221103
96 changes: 74 additions & 22 deletions examples/benchmarks/bert/src/bert_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,26 @@
SequenceClassifierOutput)
from transformers.models.bert.modeling_bert import BertPreTrainedModel

IMPL_USE_FLASH2 = False
jacobfulano marked this conversation as resolved.
Show resolved Hide resolved
try:
import flash_attn_triton as flash_attn_triton
flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func
import importlib

from flash_attn import flash_attn_qkvpacked_func
installed_version = importlib.metadata.version('flash_attn')
if installed_version < '2.4.2':
raise ImportError('newer version of flash_attn required (>= 2.4.2)')
IMPL_USE_FLASH2 = True
except ImportError as e:
flash_attn_qkvpacked_func = None
warnings.warn(
f'Failed to import flash_attn. Will try to import triton implementation: {e}',
stacklevel=2)
try:
import flash_attn_triton as flash_attn_triton
flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func
except ImportError as e:
flash_attn_qkvpacked_func = None
warnings.warn(f'Failed to import flash_attn_triton as a fallback: {e}',
stacklevel=2)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -183,7 +198,8 @@ def __init__(self, config):

def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
max_seqlen_in_batch: int, indices: torch.Tensor,
attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
attn_mask: torch.Tensor, bias: torch.Tensor,
slopes: torch.Tensor) -> torch.Tensor:
"""Perform self-attention.
If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
Expand All @@ -201,6 +217,7 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
indices: (total_nnz,)
attn_mask: (batch, max_seqlen_in_batch)
bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
slopes: (heads) or (batch, heads)
Returns:
attention: (total_nnz, dim)
Expand All @@ -213,7 +230,8 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
'b s (t h d) -> b s t h d',
t=3,
h=self.num_attention_heads)
if self.p_dropout or flash_attn_qkvpacked_func is None:
if (not IMPL_USE_FLASH2 and
self.p_dropout) or flash_attn_qkvpacked_func is None:
# if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
Expand All @@ -226,19 +244,41 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
3) # b s h d
else:
# Triton implementation only supports 0 attention dropout
convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
if convert_dtype:
# Triton implementation only supports fp16 and bf16
orig_dtype = qkv.dtype
qkv = qkv.to(torch.float16)
bias_dtype = bias.dtype
bias = bias.to(torch.float16)
attention = flash_attn_qkvpacked_func(qkv, bias)
attention = attention.to(orig_dtype)
bias = bias.to(bias_dtype)
if IMPL_USE_FLASH2:
assert 1 <= len(slopes.shape) <= 2, f'{slopes=}'
assert slopes.shape[
-1] == self.num_attention_heads, f'{slopes=}'

# Triton implementation only supports 0 attention dropout
Skylion007 marked this conversation as resolved.
Show resolved Hide resolved
convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
if convert_dtype:
# Triton implementation only supports fp16 and bf16
orig_dtype = qkv.dtype
qkv = qkv.to(torch.float16)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this to be in torch.float16?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do not, this code was here before though.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How should we select between bfloat16 and float16 though?

bias_dtype = bias.dtype
bias = bias.to(torch.float16)

attention = flash_attn_qkvpacked_func(
qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
attention = attention.to(orig_dtype)
bias = bias.to(bias_dtype)
else:
attention = flash_attn_qkvpacked_func(
qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
else:
attention = flash_attn_qkvpacked_func(qkv, bias)
# Triton implementation only supports 0 attention dropout
convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
if convert_dtype:
# Triton implementation only supports fp16 and bf16
orig_dtype = qkv.dtype
qkv = qkv.to(torch.float16)
bias_dtype = bias.dtype
bias = bias.to(torch.float16)
attention = flash_attn_qkvpacked_func(qkv, bias)
attention = attention.to(orig_dtype)
bias = bias.to(bias_dtype)
else:
attention = flash_attn_qkvpacked_func(qkv, bias)

# attn_mask is 1 for attend and 0 for don't
attention = bert_padding_module.unpad_input_only(
Expand Down Expand Up @@ -291,6 +331,7 @@ def forward(
indices: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
slopes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass for scaled self-attention without padding.
Expand All @@ -303,9 +344,11 @@ def forward(
indices: None or (total_nnz,)
attn_mask: None or (batch, max_seqlen_in_batch)
bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
slopes: None or (batch, heads) or (heads,)
"""
assert (bias is None) == (slopes is None), f'{bias=}, {slopes=}'
self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
attn_mask, bias)
attn_mask, bias, slopes)
if subset_idx is not None:
return self.output(
bert_padding_module.index_first_axis(self_output, subset_idx),
Expand Down Expand Up @@ -379,6 +422,7 @@ def forward(
indices: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
slopes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass for a BERT layer, including both attention and MLP.
Expand All @@ -391,9 +435,12 @@ def forward(
indices: None or (total_nnz,)
attn_mask: None or (batch, max_seqlen_in_batch)
bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
slopes: None or (batch, heads) or (heads,)
"""
assert (bias is None) == (slopes is None), f'{bias=}, {slopes=}'
attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
subset_idx, indices, attn_mask, bias)
subset_idx, indices, attn_mask, bias,
slopes)
layer_output = self.mlp(attention_output)
return layer_output

Expand Down Expand Up @@ -463,6 +510,7 @@ def get_slopes_power_of_2(n_heads: int) -> List[float]:
relative_position = relative_position.unsqueeze(0).expand(
n_heads, -1, -1)
slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
self.slopes = slopes
alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
# [1, n_heads, max_token_length, max_token_length]
alibi = alibi.unsqueeze(0)
Expand Down Expand Up @@ -504,6 +552,7 @@ def forward(
elif self.alibi.device != hidden_states.device:
# Device catch-up
self.alibi = self.alibi.to(hidden_states.device)
self.slopes = self.slopes.to(hidden_states.device)
alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
alibi_attn_mask = attn_bias + alibi_bias
Expand All @@ -517,7 +566,8 @@ def forward(
None,
indices,
attn_mask=attention_mask,
bias=alibi_attn_mask)
bias=alibi_attn_mask,
slopes=self.slopes)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
# Pad inputs and mask. It will insert back zero-padded tokens.
Expand All @@ -536,7 +586,8 @@ def forward(
None,
indices,
attn_mask=attention_mask,
bias=alibi_attn_mask)
bias=alibi_attn_mask,
slopes=self.slopes)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
Expand All @@ -547,7 +598,8 @@ def forward(
subset_idx=subset_idx,
indices=indices,
attn_mask=attention_mask,
bias=alibi_attn_mask)
bias=alibi_attn_mask,
slopes=self.slopes)

if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
Expand Down
3 changes: 1 addition & 2 deletions examples/benchmarks/bert/src/mosaic_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ def create_mosaic_bert_mlm(pretrained_model_name: str = 'bert-base-uncased',
pretrained_model_name)

metrics = [
LanguageCrossEntropy(ignore_index=-100,
vocab_size=model.config.vocab_size),
LanguageCrossEntropy(ignore_index=-100),
jacobfulano marked this conversation as resolved.
Show resolved Hide resolved
MaskedAccuracy(ignore_index=-100)
]

Expand Down
8 changes: 0 additions & 8 deletions examples/benchmarks/bert/src/text_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,6 @@ class StreamingTextDataset(StreamingDataset):
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
`False``.
keep_raw (bool): Whether to keep or delete the decompressed form (or only form)
of shards after all their samples have been yielded this epoch. If ``False``, keep iff
remote is local or no remote and no compression. Defaults to ``True``.
samples_per_epoch (int, optional): Provide this field iff you are weighting sub-datasets
proportionally. Defaults to ``None``.
predownload (int, optional): Target number of samples ahead to download the shards of while
Expand Down Expand Up @@ -99,7 +96,6 @@ def __init__(self,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
keep_raw: bool = True,
samples_per_epoch: Optional[int] = None,
predownload: int = 100_000,
partition_algo: str = 'orig',
Expand Down Expand Up @@ -140,7 +136,6 @@ def __init__(self,
download_timeout=download_timeout,
validate_hash=validate_hash,
keep_zip=keep_zip,
keep_raw=keep_raw,
samples_per_epoch=samples_per_epoch,
predownload=predownload,
partition_algo=partition_algo,
Expand Down Expand Up @@ -266,8 +261,6 @@ def build_text_dataloader(
cfg.dataset.get('validate_hash', None),
keep_zip=stream.get('keep_zip', None) or
cfg.dataset.get('keep_zip', False),
keep_raw=stream.get('keep_raw', None) or
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just noting that this is correct and that keep_raw is no longer a flag in mosaicml-streaming (see Streaming docs)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you check that the defaults here match the defaults currently set in llm foundry?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The defaults in llm foundry are a bit different. Should we update this function whole-hog?

From llmfoundry text_data.py

def __init__(self,
                 tokenizer: PreTrainedTokenizerBase,
                 max_seq_len: int,
                 streams: Optional[Sequence[Stream]] = None,
                 remote: Optional[str] = None,
                 local: Optional[str] = None,
                 split: Optional[str] = None,
                 download_retry: int = 2,
                 download_timeout: float = 60,
                 validate_hash: Optional[str] = None,
                 keep_zip: bool = False,
                 epoch_size: Optional[Union[int, str]] = None,
                 predownload: Optional[int] = None,
                 cache_limit: Optional[Union[int, str]] = None,
                 partition_algo: str = 'relaxed',
                 num_canonical_nodes: Optional[int] = None,
                 batch_size: Optional[int] = None,
                 shuffle: bool = False,
                 shuffle_algo: str = 'py1e',
                 shuffle_seed: int = 9176,
                 shuffle_block_size: Optional[int] = None,
                 sampling_method: str = 'balanced',
                 sampling_granularity: int = 1,
                 batching_method: str = 'random',
                 **kwargs: Any):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's still text data, this should be good!

cfg.dataset.get('keep_raw', True),
))

# build dataset potentially with streams
Expand All @@ -282,7 +275,6 @@ def build_text_dataloader(
download_timeout=cfg.dataset.get('download_timeout', 60),
validate_hash=cfg.dataset.get('validate_hash', None),
keep_zip=cfg.dataset.get('keep_zip', False),
keep_raw=cfg.dataset.get('keep_raw', True),
samples_per_epoch=cfg.dataset.get('samples_per_epoch', None),
predownload=cfg.dataset.get('predownload', 100_000),
partition_algo=cfg.dataset.get('partition_algo', 'orig'),
Expand Down
Loading