Merge branch 'batch_code_eval' of github.com:mosaicml/composer into batch_code_eval
josejg committed Feb 14, 2024
2 parents bea48f6 + c4ba100 commit 7fbabd9
Showing 23 changed files with 353 additions and 378 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/daily.yaml
@@ -88,6 +88,8 @@ jobs:
code-eval-apikey: ${{ secrets.CODE_EVAL_APIKEY }}
gcs-key: ${{ secrets.GCS_KEY }}
gcs-secret: ${{ secrets.GCS_SECRET }}
+ azure-account-name: ${{ secrets.AZURE_ACCOUNT_NAME }}
+ azure-account-access-key: ${{ secrets.AZURE_ACCOUNT_ACCESS_KEY }}
coverage:
uses: ./.github/workflows/coverage.yaml
name: Coverage Results
6 changes: 6 additions & 0 deletions .github/workflows/pytest-cpu.yaml
@@ -45,6 +45,10 @@ on:
required: false
gcs-secret:
required: false
+ azure-account-name:
+ required: false
+ azure-account-access-key:
+ required: false
jobs:
pytest-cpu:
timeout-minutes: 30
@@ -75,6 +79,8 @@ jobs:
export CODE_EVAL_APIKEY='${{ secrets.code-eval-apikey }}'
export GCS_KEY='${{ secrets.gcs-key }}'
export GCS_SECRET='${{ secrets.gcs-secret }}'
+ export AZURE_ACCOUNT_NAME='${{ secrets.azure-account-name }}'
+ export AZURE_ACCOUNT_ACCESS_KEY='${{ secrets.azure-account-access-key }}'
export S3_BUCKET='${{ inputs.pytest-s3-bucket }}'
export COMMON_ARGS="-v --durations=20 -m '${{ inputs.pytest-markers }}' --s3_bucket '$S3_BUCKET' \
-o tmp_path_retention_policy=none"
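The workflows only thread the new Azure credentials through to the test environment; how the tests consume them is not part of this diff. A minimal sketch, assuming the object-store tests read the exported variables directly (the helper below is illustrative, not taken from the repository):

```python
import os


def azure_connection_string() -> str:
    """Assemble an Azure connection string from the variables exported above.

    Illustrative only: composer's test helpers may read these variables differently.
    """
    name = os.environ.get('AZURE_ACCOUNT_NAME')
    key = os.environ.get('AZURE_ACCOUNT_ACCESS_KEY')
    if not name or not key:
        raise RuntimeError('Azure credentials are not configured for this CI run')
    return (f'DefaultEndpointsProtocol=https;AccountName={name};'
            f'AccountKey={key};EndpointSuffix=core.windows.net')
```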
12 changes: 6 additions & 6 deletions composer/callbacks/checkpoint_saver.py
@@ -394,7 +394,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
return

metadata_local_file_path = None
- if dist.get_global_rank() == 0 and state.fsdp_elastic_sharded_enabled:
+ if dist.get_global_rank() == 0 and state.fsdp_sharded_state_dict_enabled:
metadata_local_file_path = format_name_with_dist_and_time(
os.path.join(Path(saved_path).parent, _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME), state.run_name,
state.timestamp)
@@ -407,11 +407,11 @@ def _save_checkpoint(self, state: State, logger: Logger):
except FileNotFoundError:
pass
# Sharded checkpoints for torch >2.0 use directories not files for load_paths
- if state.fsdp_elastic_sharded_enabled:
+ if state.fsdp_sharded_state_dict_enabled:
src_path = str(pathlib.Path(saved_path).parent)
else:
src_path = saved_path
- this_rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_elastic_sharded_enabled
+ this_rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_sharded_state_dict_enabled
if this_rank_saves_symlinks:
os.symlink(os.path.relpath(src_path, os.path.dirname(symlink)), symlink)

@@ -430,7 +430,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
remote_file_name = format_name_with_dist_and_time(remote_file_name, state.run_name, state.timestamp)
# Upload metadata file.
# The metadata file contains info related to which shards are saved where.
- if dist.get_global_rank() == 0 and state.fsdp_elastic_sharded_enabled:
+ if dist.get_global_rank() == 0 and state.fsdp_sharded_state_dict_enabled:
metadata_remote_file_name = format_name_with_dist_and_time(
os.path.join(Path(remote_file_name).parent, _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME),
state.run_name, state.timestamp)
@@ -463,12 +463,12 @@ def _save_checkpoint(self, state: State, logger: Logger):
with tempfile.TemporaryDirectory() as tmpdir:
symlink_filename = os.path.join(tmpdir, 'latest.symlink')
# Sharded checkpoints for torch >2.0 use directories not files for load_paths
- if state.fsdp_elastic_sharded_enabled:
+ if state.fsdp_sharded_state_dict_enabled:
src_path = str(pathlib.Path(remote_file_name).parent)
else:
src_path = remote_file_name
log.debug(f'Creating symlink file {symlink_filename} -> {src_path}')
- this_rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_elastic_sharded_enabled
+ this_rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_sharded_state_dict_enabled
if this_rank_saves_symlinks:
create_symlink_file(src_path, symlink_filename)
logger.upload_file(
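The checkpoint_saver.py changes are a pure rename from the deprecated `state.fsdp_elastic_sharded_enabled` to `state.fsdp_sharded_state_dict_enabled`; the symlink logic itself is untouched. A condensed sketch of that logic (simplified from the code above, not a drop-in replacement):

```python
import pathlib


def symlink_source(saved_path: str, sharded_state_dict_enabled: bool) -> str:
    """Sharded checkpoints on torch > 2.0 are directories, so the symlink
    targets the checkpoint's parent directory rather than a single file."""
    if sharded_state_dict_enabled:
        return str(pathlib.Path(saved_path).parent)
    return saved_path


def rank_saves_symlink(global_rank: int, sharded_state_dict_enabled: bool) -> bool:
    """Only rank 0 writes the symlink when sharded state dicts are enabled;
    otherwise every rank writes its own."""
    return global_rank == 0 or not sharded_state_dict_enabled
```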
9 changes: 5 additions & 4 deletions composer/core/state.py
@@ -731,6 +731,11 @@ def fsdp_state_dict_type(self):
def fsdp_sharded_state_dict_enabled(self):
return self.fsdp_config is not None and self.fsdp_enabled and self.fsdp_state_dict_type == 'sharded'

+ @property
+ def fsdp_elastic_sharded_enabled(self):
+ warnings.warn('state.fsdp_elastic_sharded_enabled is deprecated and will be removed in v0.21.0')
+ return self.fsdp_sharded_state_dict_enabled
+
@property
def fsdp_device_mesh(self):
if self.fsdp_enabled:
@@ -745,10 +750,6 @@ def load_fsdp_monolith_rank0_only(self):
return self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config[
'state_dict_type'] == 'full' and self.fsdp_config['load_monolith_rank0_only'] == True

- @property
- def fsdp_elastic_sharded_enabled(self):
- return self.fsdp_sharded_state_dict_enabled
-
def _get_integrations_state_dict(self) -> Dict[str, Any]:
"""Gets a dictionary of information about integrations to store in the state dict.
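The property is not removed, only moved and turned into a deprecation shim that delegates to `fsdp_sharded_state_dict_enabled`. A minimal usage sketch, assuming `state` is an already-constructed `composer.core.State`:

```python
import warnings

# `state` is assumed to be an existing composer.core.State instance.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    enabled = state.fsdp_elastic_sharded_enabled  # delegates to the new property

assert enabled == state.fsdp_sharded_state_dict_enabled
assert any('deprecated' in str(w.message) for w in caught)
```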
4 changes: 2 additions & 2 deletions composer/loggers/console_logger.py
@@ -109,15 +109,15 @@ def epoch_end(self, state: State, logger: Logger) -> None:
cur_epoch = int(state.timestamp.epoch) # epoch gets incremented right before EPOCH_END
unit = self.log_interval.unit

- if unit == TimeUnit.EPOCH and (cur_epoch % int(self.log_interval) == 0 or cur_epoch == 1):
+ if unit == TimeUnit.EPOCH and (cur_epoch % int(self.log_interval) == 0 or self.last_logged_batch == 0):
self.log_to_console(self.logged_metrics, prefix='Train ', state=state)
self.last_logged_batch = int(state.timestamp.batch)
self.logged_metrics = {} # Clear logged metrics.

def batch_end(self, state: State, logger: Logger) -> None:
cur_batch = int(state.timestamp.batch)
unit = self.log_interval.unit
- if unit == TimeUnit.BATCH and (cur_batch % int(self.log_interval) == 0 or cur_batch == 1):
+ if unit == TimeUnit.BATCH and (cur_batch % int(self.log_interval) == 0 or self.last_logged_batch == 0):
self.log_to_console(self.logged_metrics, prefix='Train ', state=state)
self.last_logged_batch = cur_batch
self.logged_metrics = {} # Clear logged metrics.
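The special case now keys off "nothing has been logged yet" (`self.last_logged_batch == 0`) rather than "this is epoch/batch 1", which presumably also covers runs whose counters do not start at 1, such as after a resumption. A small sketch of the predicate (the helper name is illustrative):

```python
def should_log(cur: int, log_interval: int, last_logged_batch: int) -> bool:
    """Log on every interval boundary, and also on the first opportunity
    if nothing has been logged yet."""
    return cur % log_interval == 0 or last_logged_batch == 0


# A run whose counter starts at 7 with an interval of 10 still logs once immediately:
assert should_log(cur=7, log_interval=10, last_logged_batch=0)
assert not should_log(cur=7, log_interval=10, last_logged_batch=10)
```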
1 change: 0 additions & 1 deletion composer/loggers/mosaicml_logger.py
@@ -162,7 +162,6 @@ def _flush_metadata(self, force_flush: bool = False, future: bool = True) -> Non
self.buffered_metadata = {}
self.time_last_logged = time.time()
done, incomplete = wait(self._futures, timeout=0.01)
- log.info(f'Logged {len(done)} metadata to MosaicML, waiting on {len(incomplete)}')
# Raise any exceptions
for f in done:
if f.exception() is not None:
96 changes: 81 additions & 15 deletions composer/metrics/nlp.py
@@ -8,7 +8,7 @@
import re
import string
import warnings
- from typing import Any, Dict, List, Mapping, Optional, Union
+ from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

import numpy as np
import torch
@@ -197,9 +197,20 @@ def compute(self) -> Tensor:

class InContextLearningMetric(Metric):

- def update(self, batch: dict, output_logits: torch.Tensor, labels: torch.Tensor):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.needs_batch = True
+
+ def update(self,
+ batch: dict,
+ output_logits: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ outputs: Optional[torch.Tensor] = None):
"""Abstract interface for computing an in-context learning metrics.
The `output_logits` argument is deprecated and will be removed in v0.21 while it's functionality will
be moved to `outputs`.
Args:
batch (dict): Batch must consist minimally of `input_ids` as well as any other structure needed
to compute the metric.
@@ -211,6 +222,27 @@ def update(self, batch: dict, output_logits: torch.Tensor, labels: torch.Tensor)
"""
raise NotImplementedError

+ @staticmethod
+ def rename_args(batch: dict,
+ output_logits: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ outputs: Optional[torch.Tensor] = None) -> Tuple[dict, torch.Tensor, torch.Tensor]:
+ if outputs is not None and output_logits is not None:
+ raise ValueError('Cannot use both `outputs` and `output_logits`')
+ if output_logits is not None:
+ warnings.warn(
+ ('`output_logits` has been renamed to `outputs` and will be removed in v0.21'),
+ DeprecationWarning,
+ )
+ outputs = output_logits
+
+ if labels is None:
+ raise ValueError('`labels` cannot be None')
+ if outputs is None:
+ raise ValueError('`outputs` cannot be None')
+
+ return batch, outputs, labels


class InContextLearningQAAccuracy(InContextLearningMetric):
r"""Computes accuracy for In-context learning (ICL) question answering (QA) tasks.
@@ -267,9 +299,7 @@ def replace_underscore(text: str) -> str:

return white_space_fix(remove_articles(handle_punc(lower(replace_underscore(answer))))).strip()

- def update(self, outputs: List[str], labels: List[List[str]], batch: Optional[Dict[str, Any]] = None):
- if batch is None:
- batch = {}
+ def update(self, outputs: List[str], labels: List[List[str]], batch: Dict[str, Any]):
cot_delimiter = batch.get('cot_delimiter', '')
do_normalization = batch.get('do_normalization', True)
stopping_criteria = batch.get('stopping_criteria', None)
@@ -329,9 +359,18 @@ def __init__(self, dist_sync_on_step: bool = False):
self.add_state('correct', default=torch.tensor(0.), dist_reduce_fx='sum')
self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum')

- def update(self, batch: dict, output_logits: torch.Tensor, labels: torch.Tensor):
+ def update(self,
+ batch: dict,
+ output_logits: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ outputs: Optional[torch.Tensor] = None):
+ batch, outputs, labels = InContextLearningMetric.rename_args(batch=batch,
+ output_logits=output_logits,
+ labels=labels,
+ outputs=outputs)
+
for batch_idx, cont_idx in enumerate(batch['continuation_indices']):
- cont_tok_pred = output_logits[batch_idx].index_select(dim=0, index=cont_idx - 1).argmax(dim=-1)
+ cont_tok_pred = outputs[batch_idx].index_select(dim=0, index=cont_idx - 1).argmax(dim=-1)
cont_tok_targ = labels[batch_idx].index_select(dim=0, index=cont_idx - 1)

self.correct += (cont_tok_pred == cont_tok_targ).all().int()
@@ -371,11 +410,20 @@ def __init__(self, dist_sync_on_step: bool = False):
self.add_state('correct', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('total', default=torch.tensor(0.0), dist_reduce_fx='sum')

- def update(self, batch: dict, output_logits: torch.Tensor, labels: torch.Tensor):
+ def update(self,
+ batch: dict,
+ output_logits: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ outputs: Optional[torch.Tensor] = None):
+ batch, outputs, labels = InContextLearningMetric.rename_args(batch=batch,
+ output_logits=output_logits,
+ labels=labels,
+ outputs=outputs)
+
perplexities = []
for batch_idx, cont_idx in enumerate(batch['continuation_indices']):
# continuation indices refer to indices in the original input's token space
- cont_tok_logits = output_logits[batch_idx].index_select(dim=0, index=cont_idx - 1)
+ cont_tok_logits = outputs[batch_idx].index_select(dim=0, index=cont_idx - 1)
# labels have been shifted left by one index, so the cont_idx needs to be shifted as well.
cont_tok_targ = labels[batch_idx].index_select(dim=0, index=cont_idx - 1)
cross_entropy = F.cross_entropy(cont_tok_logits, cont_tok_targ)
@@ -456,11 +504,20 @@ class InContextLearningMCExpectedCalibrationError(InContextLearningExpectedCalib
# Make torchmetrics call update only once
full_state_update = False

- def update(self, batch: Dict[str, Any], output_logits: torch.Tensor, labels: torch.Tensor):
- output_logits = torch.softmax(output_logits, dim=2)
+ def update(self,
+ batch: dict,
+ output_logits: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ outputs: Optional[torch.Tensor] = None):
+ batch, outputs, labels = InContextLearningMetric.rename_args(batch=batch,
+ output_logits=output_logits,
+ labels=labels,
+ outputs=outputs)
+
+ outputs = torch.softmax(outputs, dim=2)
probabilites = []
for batch_idx, cont_idx in enumerate(batch['continuation_indices']):
- cont_tok_logits = output_logits[batch_idx].index_select(dim=0, index=cont_idx - 1)
+ cont_tok_logits = outputs[batch_idx].index_select(dim=0, index=cont_idx - 1)
cont_tok_targ = labels[batch_idx].index_select(dim=0, index=cont_idx - 1)
probability = cont_tok_logits.index_select(dim=1, index=cont_tok_targ).diagonal().mean()
probabilites.append(probability)
@@ -492,10 +549,19 @@ class InContextLearningLMExpectedCalibrationError(InContextLearningExpectedCalib
# Make torchmetrics call update only once
full_state_update = False

- def update(self, batch: Dict[str, Any], output_logits: torch.Tensor, labels: torch.Tensor):
- output_logits = torch.softmax(output_logits, dim=2)
+ def update(self,
+ batch: dict,
+ output_logits: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ outputs: Optional[torch.Tensor] = None):
+ batch, outputs, labels = InContextLearningMetric.rename_args(batch=batch,
+ output_logits=output_logits,
+ labels=labels,
+ outputs=outputs)
+
+ outputs = torch.softmax(outputs, dim=2)
for batch_idx, cont_idx in enumerate(batch['continuation_indices']):
- cont_tok_logits = output_logits[batch_idx].index_select(dim=0, index=cont_idx - 1)
+ cont_tok_logits = outputs[batch_idx].index_select(dim=0, index=cont_idx - 1)
cont_tok_pred = cont_tok_logits.argmax(dim=-1)
confidence = cont_tok_logits.max(dim=-1).values.min()
cont_tok_targ = labels[batch_idx].index_select(dim=0, index=cont_idx - 1)
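A hedged usage sketch of the back-compat shim: both call styles below should update the metric, with the old keyword emitting a `DeprecationWarning` until v0.21 (the tensor shapes are illustrative only):

```python
import warnings

import torch

from composer.metrics.nlp import InContextLearningLMAccuracy

metric = InContextLearningLMAccuracy()
batch = {'continuation_indices': [torch.tensor([3, 4])]}
logits = torch.randn(1, 5, 32)            # (batch, seq_len, vocab) -- illustrative
labels = torch.randint(0, 32, (1, 5))

# New-style call.
metric.update(batch, outputs=logits, labels=labels)

# Old-style call still works, but warns.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    metric.update(batch, output_logits=logits, labels=labels)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```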
11 changes: 3 additions & 8 deletions composer/models/huggingface.py
@@ -21,7 +21,6 @@
import torch
from torchmetrics import Metric

- from composer.metrics import InContextLearningMetric, InContextLearningQAAccuracy
from composer.models.base import ComposerModel
from composer.utils import MissingConditionalImportError, dist, get_file, import_object, is_model_fsdp, safe_torch_load

@@ -532,14 +531,10 @@ def get_metrics(self, is_train: bool = False) -> Dict[str, Metric]:
return metrics if metrics else {}

def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
- if isinstance(metric, InContextLearningQAAccuracy):
- assert self.labels is not None
- metric.update(batch=batch, outputs=outputs, labels=self.labels) # pyright: ignore [reportGeneralTypeIssues]
- elif isinstance(metric, InContextLearningMetric):
- assert self.labels is not None
- metric.update(batch, outputs, self.labels) # pyright: ignore [reportGeneralTypeIssues]
+ if getattr(metric, 'needs_batch', False):
+ metric.update(batch=batch, outputs=outputs, labels=self.labels)
else:
- metric.update(outputs, self.labels) # pyright: ignore [reportGeneralTypeIssues]
+ metric.update(outputs, self.labels)

def get_metadata(self):
model_output = {}
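Metric dispatch in `update_metric` no longer depends on `isinstance` checks against the in-context-learning classes; any metric can opt in to receiving the batch by exposing a truthy `needs_batch` attribute (which `InContextLearningMetric.__init__` now sets). A hedged sketch of a custom metric using the hook (the class is illustrative, not from the repository):

```python
import torch
from torchmetrics import Metric


class BatchAwareCount(Metric):
    """Toy metric that wants the raw batch, e.g. to read masks or indices."""

    def __init__(self):
        super().__init__()
        self.needs_batch = True  # update_metric checks this flag
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, batch, outputs, labels):
        # Called as metric.update(batch=batch, outputs=outputs, labels=labels).
        del batch, outputs  # unused in this toy example
        self.count += labels.shape[0]

    def compute(self):
        return self.count
```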
13 changes: 11 additions & 2 deletions composer/profiler/profiler.py
@@ -9,6 +9,8 @@
import pathlib
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Tuple, Union

+ from composer.core import Callback
+ from composer.loggers import Logger
from composer.profiler.json_trace_handler import JSONTraceHandler
from composer.profiler.marker import Marker
from composer.profiler.profiler_action import ProfilerAction
@@ -18,14 +20,14 @@
from composer.utils import ensure_tuple, parse_uri

if TYPE_CHECKING:
- from composer.core import Callback, State
+ from composer.core import State

__all__ = ['Profiler']

log = logging.getLogger(__name__)


- class Profiler:
+ class Profiler(Callback):
"""Composer Profiler.
See the :doc:`Profiling Guide </trainer/performance_tutorials/profiling>` for additional information.
@@ -118,6 +120,8 @@ def __init__(
self.schedule = schedule
self.state = None
self._callbacks: List[Callback] = []
+ # Used to count skip_first starting from resumption timestamp
+ self.resumption_batch_idx: int = 0
self.remote_filenames: List[str] = []
# First, add each remote file name to self.remote_filenames to create RemoteUploaderDownloader logger in trainer. [s3://bucket/path/to/file]
# Then modify remote file name to be a local path to pass into torch_profiler and system_profiler. e.g: path/to/file
@@ -185,6 +189,7 @@ def bind_to_state(
state (State): The training state.
"""
self.state = state
+ self.state.callbacks.append(self)
self.state.callbacks.extend(self._callbacks)
self.state.callbacks.extend(self._trace_handlers)

@@ -289,3 +294,7 @@ def should_record(state: State) -> bool:
)
self._names_to_markers[name].categories = categories
return self._names_to_markers[name]

+ def after_load(self, state: State, logger: Logger) -> None:
+ del logger
+ self.resumption_batch_idx = int(state.timestamp.batch_in_epoch)
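Making `Profiler` a `Callback` lets it hook `after_load`, so `resumption_batch_idx` records where training resumed and `skip_first` can be counted from that point rather than from the start of the epoch. A hedged sketch of how a schedule might apply the offset (the real schedule lives elsewhere in composer; this helper is illustrative):

```python
def still_skipping(batch_in_epoch: int, skip_first: int, resumption_batch_idx: int) -> bool:
    """Return True while the profiler should keep skipping batches.

    resumption_batch_idx is 0 for a fresh run; after a resumption it holds the
    batch_in_epoch captured in Profiler.after_load, so skip_first is counted
    from the resumption point.
    """
    return batch_in_epoch < resumption_batch_idx + skip_first


# Fresh run: skip the first 5 batches of the epoch.
assert still_skipping(batch_in_epoch=3, skip_first=5, resumption_batch_idx=0)
# Resumed at batch 100: skip 5 batches counted from batch 100.
assert still_skipping(batch_in_epoch=102, skip_first=5, resumption_batch_idx=100)
assert not still_skipping(batch_in_epoch=106, skip_first=5, resumption_batch_idx=100)
```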
(Diffs for the remaining 14 changed files did not load and are not shown.)
