Commit 4b31ff0
clean up PR
bmosaicml committed Jun 22, 2023
1 parent 60dfc55 commit 4b31ff0
Showing 15 changed files with 739 additions and 837 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/callbacks/__init__.py
@@ -2,9 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0

 try:
-    from llmfoundry.callbacks.eval_taxonomy_callback import EvalTaxonomy
     from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
     from llmfoundry.callbacks.generate_callback import Generate
+    from llmfoundry.callbacks.model_gauntlet_callback import ModelGauntlet
     from llmfoundry.callbacks.monolithic_ckpt_callback import \
         MonolithicCheckpointSaver
     from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling,

@@ -18,5 +18,5 @@

 __all__ = [
     'FDiffMetrics', 'Generate', 'MonolithicCheckpointSaver', 'GlobalLRScaling',
-    'LayerFreezing', 'ScheduledGarbageCollector', 'EvalTaxonomy'
+    'LayerFreezing', 'ScheduledGarbageCollector', 'ModelGauntlet'
 ]
llmfoundry/callbacks/{eval_taxonomy_callback.py → model_gauntlet_callback.py}

@@ -6,42 +6,58 @@

"""Monitor gradients during training."""

from enum import Enum
import math
import re
from typing import Optional

import torch
from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import dist

__all__ = ['EvalTaxonomy']
__all__ = ['MoedlGauntlet']


class EvalTaxonomy(Callback):
class Weighting(Enum):
EQUAL = 1
SAMPLE_SZ = 2
LOG_SAMPLE_SZ = 3
class MoedlGauntlet(Callback):

     def __init__(
         self,
         logger_keys: dict,
         tasks: dict,
-        equal_weighting: bool = True,
+        weighting: Union[str, Weighting] = Weighting.EQUAL,
         subtract_random_baseline: bool = True,
         rescale_accuracy: bool = True,
+        benchmark_sizes: Optional[dict] = None
     ):
         self.tasks = tasks
-        self.equal_weighting = equal_weighting
+        # Accept either a Weighting member or its name (e.g. 'EQUAL' when
+        # built from a YAML config); Weighting[...] expects the string name.
+        self.weighting = (Weighting[weighting]
+                          if isinstance(weighting, str) else weighting)
         self.subtract_random_baseline = subtract_random_baseline
         self.rescale_accuracy = rescale_accuracy
         self.logger_keys = logger_keys
         for category in self.tasks:
-            if self.equal_weighting:
-                for benchmark in category['benchmarks']:
-                    benchmark['weighting'] = 1
-            else:
-                for benchmark in category['benchmarks']:
-                    benchmark['weighting'] = sum([
-                        v for k, v in benchmark['scorecard'].items()
-                        if k != 'random_baseline'
-                    ])
+            for benchmark in category['benchmarks']:
+                bench_name = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
+                # Pool sample counts across every logged split of this
+                # benchmark; fall back to 1 so the size-based weightings
+                # stay defined even when benchmark_sizes is missing.
+                cumulative_samples = max(
+                    sum(count
+                        for name, count in (benchmark_sizes or {}).items()
+                        if name.startswith(bench_name)), 1)
+
+                if self.weighting == Weighting.EQUAL:
+                    weight = 1
+                elif self.weighting == Weighting.SAMPLE_SZ:
+                    weight = cumulative_samples
+                elif self.weighting == Weighting.LOG_SAMPLE_SZ:
+                    weight = max(math.log(cumulative_samples, 2), 1)
+
+                benchmark['weighting'] = weight

     def compute_averages(self, logger_data):
         results = {}
[diff truncated: the remaining changed files in this commit are not shown]
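To make the new weighting scheme concrete, here is a self-contained sketch of the weight each Weighting mode assigns to a benchmark. The benchmark names and sample counts are hypothetical, and benchmark_weight is a standalone helper written for illustration, not part of the callback:

import math
from enum import Enum


class Weighting(Enum):
    EQUAL = 1
    SAMPLE_SZ = 2
    LOG_SAMPLE_SZ = 3


def benchmark_weight(bench_name: str, benchmark_sizes: dict,
                     weighting: Weighting) -> float:
    # Pool sample counts across every logged split of this benchmark,
    # falling back to 1 so the size-based weightings stay defined.
    cumulative_samples = max(
        sum(count for name, count in benchmark_sizes.items()
            if name.startswith(bench_name)), 1)
    if weighting == Weighting.EQUAL:
        return 1
    if weighting == Weighting.SAMPLE_SZ:
        return cumulative_samples
    return max(math.log2(cumulative_samples), 1)  # LOG_SAMPLE_SZ


# Hypothetical sample counts for two benchmarks.
sizes = {'jeopardy/0-shot/world_history': 10000, 'lambada/0-shot': 100}
print(benchmark_weight('jeopardy/0-shot', sizes, Weighting.SAMPLE_SZ))      # 10000
print(benchmark_weight('jeopardy/0-shot', sizes, Weighting.LOG_SAMPLE_SZ))  # ~13.29
print(benchmark_weight('lambada/0-shot', sizes, Weighting.LOG_SAMPLE_SZ))   # ~6.64

LOG_SAMPLE_SZ is a middle ground between the other two modes: larger benchmarks still count for more, but a 100x difference in sample count shrinks to roughly a 2x difference in weight.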

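Since compute_averages is cut off above, the interaction of subtract_random_baseline and rescale_accuracy is not visible in this diff. A plausible reading, assuming the random_baseline values referenced by the removed scorecard code are per-benchmark chance accuracies, is that each raw accuracy is shifted and rescaled so random guessing maps to 0 and a perfect score to 1. This sketch is an assumption, not something the commit shows:

def normalized_score(acc: float, random_baseline: float,
                     subtract_random_baseline: bool = True,
                     rescale_accuracy: bool = True) -> float:
    # Assumed normalization: shift out chance accuracy, then rescale the
    # remaining range so a perfect model scores 1.0.
    if subtract_random_baseline:
        acc -= random_baseline
        if rescale_accuracy:
            acc /= (1 - random_baseline)
    return acc


print(normalized_score(0.5, 0.25))  # (0.5 - 0.25) / (1 - 0.25) = 0.33...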