fix VII: the fix awakens
milocress committed Apr 19, 2024
1 parent b7fb56a commit 1eb809e
Showing 2 changed files with 43 additions and 47 deletions.
2 changes: 1 addition & 1 deletion llmfoundry/data/finetuning/dataloader.py
@@ -141,7 +141,7 @@ def build_finetuning_dataloader(
         given a starting workload YAML.
     """
     dataset_cfg = dataset
-    _validate_config(dataset_cfg)
+    _validate_config(**dataset_cfg)
 
     # Use EOS as the pad token if none exists
     if tokenizer.pad_token is None:  # type: ignore (sometimes it's none and that's ok)
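
Note on the one-line dataloader fix: `dataset` is a plain dict after the config refactor, and `_validate_config` takes the dataset's fields as individual keyword arguments, so the dict has to be unpacked with `**`. A minimal sketch of the failure mode (the parameter names here are illustrative, not the real signature):

    # Hypothetical signature sketch; the real _validate_config has many more fields.
    def _validate_config(max_seq_len: int, decoder_only_format: bool, **kwargs: object) -> None:
        ...

    dataset_cfg = {'max_seq_len': 2048, 'decoder_only_format': True}
    _validate_config(**dataset_cfg)  # ok: each key binds to a keyword parameter
    _validate_config(dataset_cfg)    # TypeError: missing 'decoder_only_format'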
88 changes: 42 additions & 46 deletions llmfoundry/utils/builders.py
@@ -17,7 +17,6 @@
 from composer.models import ComposerModel
 from composer.optim.scheduler import ComposerScheduler
 from composer.utils import dist
-from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 from torch.optim.optimizer import Optimizer
 from torchmetrics import Metric
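
The dropped `DictConfig, ListConfig` import is the theme of the whole commit: `builders.py` now works on builtin containers instead of OmegaConf ones. The practical difference, as a quick sketch using the OmegaConf API:

    from omegaconf import OmegaConf as om

    cfg = om.create({'icl_task_type': 'schema', 'num_fewshot': [0, 5]})
    cfg.icl_task_type                     # DictConfig allows attribute access
    plain = om.to_container(cfg)          # plain dict: only plain['icl_task_type'] works
    assert isinstance(plain['num_fewshot'], list)  # ListConfig became a builtin list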
@@ -457,7 +456,7 @@ def build_tokenizer(
 
 
 def build_icl_evaluators(
-    icl_tasks: Union[str, ListConfig],
+    icl_tasks: Union[str, List[Dict[str, Any]]],
     tokenizer: PreTrainedTokenizerBase,
     default_max_seq_len: int,
     default_batch_size: int,
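
With the widened annotation, `icl_tasks` is either a path to a YAML file or a pre-parsed list of plain dicts rather than a `ListConfig`. Per the asserts in `_validate_cfg` below, each task dict needs `label`, `dataset_uri`, `icl_task_type`, and `num_fewshot`; a minimal entry might look like this (the dataset path is hypothetical):

    icl_tasks = [{
        'label': 'lambada',
        'dataset_uri': 'eval/local_data/lambada.jsonl',  # hypothetical path
        'icl_task_type': 'language_modeling',
        'num_fewshot': [0, 5],
    }]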
@@ -475,78 +474,79 @@ def build_icl_evaluators(
         log.info(f'Extracting ICL task config from path: {icl_tasks}')
         with open(icl_tasks, 'r') as icl_f:
             icl_task_cfg = om.load(icl_f)
-        icl_tasks_list = icl_task_cfg.icl_tasks
+        icl_tasks_list = to_str_dict(icl_task_cfg.icl_tasks)
     else:
         icl_tasks_list = icl_tasks
 
-    def _validate_cfg(icl_cfg: DictConfig):
+    def _validate_cfg(icl_cfg: Dict[str, Any]):
         assert 'label' in icl_cfg
-        assert 'dataset_uri' in icl_cfg and icl_cfg.dataset_uri is not None
+        assert 'dataset_uri' in icl_cfg and icl_cfg['dataset_uri'] is not None
         assert 'icl_task_type' in icl_cfg
         assert 'num_fewshot' in icl_cfg
 
         if 'metric_names' not in icl_cfg:
-            if icl_cfg.icl_task_type == 'language_modeling':
-                icl_cfg.metric_names = ['InContextLearningLMAccuracy']
-            elif icl_cfg.icl_task_type == 'multiple_choice':
-                icl_cfg.metric_names = [
+            if icl_cfg['icl_task_type'] == 'language_modeling':
+                icl_cfg['metric_names'] = ['InContextLearningLMAccuracy']
+            elif icl_cfg['icl_task_type'] == 'multiple_choice':
+                icl_cfg['metric_names'] = [
                     'InContextLearningMultipleChoiceAccuracy'
                 ]
-            elif icl_cfg.icl_task_type == 'schema':
-                icl_cfg.metric_names = [
+            elif icl_cfg['icl_task_type'] == 'schema':
+                icl_cfg['metric_names'] = [
                     'InContextLearningMultipleChoiceAccuracy'
                 ]
-            elif icl_cfg.icl_task_type == 'generation_task_with_answers' or icl_cfg.icl_task_type == 'question_answering':
-                if icl_cfg.icl_task_type == 'question_answering':
+            elif icl_cfg[
+                    'icl_task_type'] == 'generation_task_with_answers' or icl_cfg.icl_task_type == 'question_answering':
+                if icl_cfg['icl_task_type'] == 'question_answering':
                     warnings.warn(
                         VersionedDeprecationWarning(
                             "ICL task type 'question_answering' is now deprecated. Use identifier 'generation_task_with_answers'",
                             'v0.9.0'))
-                icl_cfg.metric_names = [
+                icl_cfg['metric_names'] = [
                     'InContextLearningGenerationExactMatchAccuracy'
                 ]
-            elif icl_cfg.icl_task_type == 'code_evaluation':
-                icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy']
+            elif icl_cfg['icl_task_type'] == 'code_evaluation':
+                icl_cfg['metric_names'] = ['InContextLearningCodeEvalAccuracy']
             else:
                 raise ValueError(
                     f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg.icl_task_type}.'
                 )
 
         if 'prompt_string' not in icl_cfg:
-            icl_cfg.prompt_string = ''
+            icl_cfg['prompt_string'] = ''
         if 'example_delimiter' not in icl_cfg:
-            icl_cfg.example_delimiter = '\n'
+            icl_cfg['example_delimiter'] = '\n'
         if 'continuation_delimiter' not in icl_cfg:
-            icl_cfg.continuation_delimiter = ' '
+            icl_cfg['continuation_delimiter'] = ' '
         if 'max_seq_len' not in icl_cfg:
-            icl_cfg.max_seq_len = default_max_seq_len
+            icl_cfg['max_seq_len'] = default_max_seq_len
         if 'batch_size' not in icl_cfg:
-            icl_cfg.batch_size = default_batch_size
+            icl_cfg['batch_size'] = default_batch_size
         if 'pass_at_k' not in icl_cfg:
-            icl_cfg.pass_at_k = 1
+            icl_cfg['pass_at_k'] = 1
         if 'fewshot_random_seed' not in icl_cfg:
-            icl_cfg.fewshot_random_seed = 1234
+            icl_cfg['fewshot_random_seed'] = 1234
         if 'generations_per_sample' not in icl_cfg:
-            icl_cfg.generations_per_sample = 1
+            icl_cfg['generations_per_sample'] = 1
 
         if 'num_beams' in icl_cfg:
             raise ValueError(
                 'num_beams is no longer supported as a top level icl_task parameter.' + \
                 'Please use generation_kwargs.num_beams instead.')
 
     for icl_cfg in icl_tasks_list:
-        assert isinstance(icl_cfg, DictConfig)
+        assert isinstance(icl_cfg, dict)
         _validate_cfg(icl_cfg)
-        for num_fewshot in list(icl_cfg.num_fewshot):
+        for num_fewshot in list(icl_cfg['num_fewshot']):
             if tokenizer.pad_token_id is None:
                 # Current workaround to support GPT2 tokenizer with `pad_token_id = None`
                 pad_tok_id = tokenizer.eos_token_id
             else:
                 pad_tok_id = tokenizer.pad_token_id
-            label = f'{icl_cfg.label}/{num_fewshot}-shot'
-            metric_names = list(icl_cfg.metric_names)
+            label = f'{icl_cfg["label"]}/{num_fewshot}-shot'
+            metric_names = list(icl_cfg['metric_names'])
             # TODO: fix Composer bug when copying local paths and destination exists
-            destination_path = f'{destination_dir}/{icl_cfg.label}-{num_fewshot}.jsonl'
+            destination_path = f'{destination_dir}/{icl_cfg["label"]}-{num_fewshot}.jsonl'
             if dist.get_local_rank() == 0 and os.path.exists(destination_path):
                 os.remove(destination_path)
             dist.barrier()
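
Since `icl_cfg` is now a plain `Dict[str, Any]`, every default is written with subscript assignment, mutating the dict in place. A sketch of the effect on a minimal task config (`_validate_cfg` is a nested helper, shown here as if it were callable directly; the dataset path is hypothetical):

    icl_cfg = {
        'label': 'lambada',
        'dataset_uri': 'eval/local_data/lambada.jsonl',  # hypothetical path
        'icl_task_type': 'language_modeling',
        'num_fewshot': [0],
    }
    _validate_cfg(icl_cfg)

    # Defaults from the block above were filled in on the same dict:
    assert icl_cfg['metric_names'] == ['InContextLearningLMAccuracy']
    assert icl_cfg['example_delimiter'] == '\n'
    assert icl_cfg['pass_at_k'] == 1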
@@ -556,38 +556,34 @@ def _validate_cfg(icl_cfg: DictConfig):
 
             early_stopping_criteria = icl_cfg.get('early_stopping_criteria',
                                                   None)
-            if isinstance(early_stopping_criteria, ListConfig):
-                early_stopping_criteria = om.to_container(
-                    early_stopping_criteria)
             assert early_stopping_criteria is None or isinstance(
                 early_stopping_criteria, list)
             dataloaders = get_icl_task_dataloader(
-                icl_cfg.icl_task_type,
-                icl_cfg.dataset_uri,
+                icl_cfg['icl_task_type'],
+                icl_cfg['dataset_uri'],
                 tokenizer,
-                batch_size=icl_cfg.batch_size,
-                max_seq_len=icl_cfg.max_seq_len,
+                batch_size=icl_cfg['batch_size'],
+                max_seq_len=icl_cfg['max_seq_len'],
                 pad_tok_id=pad_tok_id,
                 num_fewshot=num_fewshot,
-                prompt_string=icl_cfg.prompt_string,
-                example_delimiter=icl_cfg.example_delimiter,
+                prompt_string=icl_cfg['prompt_string'],
+                example_delimiter=icl_cfg['example_delimiter'],
                 hf_loading_vars=hf_loading_vars,
                 hf_parsing_map=hf_parsing_map,
-                continuation_delimiter=icl_cfg.continuation_delimiter,
+                continuation_delimiter=icl_cfg['continuation_delimiter'],
                 question_prelimiter=icl_cfg.get('question_prelimiter', ''),
                 destination_path=destination_path,
-                fewshot_random_seed=icl_cfg.fewshot_random_seed,
-                pass_at_k=icl_cfg.pass_at_k,
-                generations_per_sample=icl_cfg.generations_per_sample,
+                fewshot_random_seed=icl_cfg['fewshot_random_seed'],
+                pass_at_k=icl_cfg['pass_at_k'],
+                generations_per_sample=icl_cfg['generations_per_sample'],
                 has_categories=icl_cfg.get('has_categories', False),
                 cot_delimiter=icl_cfg.get('cot_delimiter', ''),
                 generation_kwargs=icl_cfg.get('generation_kwargs', {}),
                 early_stopping_criteria=early_stopping_criteria,
                 do_normalization=icl_cfg.get('do_normalization', True))
-            if hasattr(
-                    icl_cfg,
-                    'has_categories') and icl_cfg.has_categories and isinstance(
-                        dataloaders, dict):
+            if hasattr(icl_cfg, 'has_categories'
+                      ) and icl_cfg['has_categories'] and isinstance(
+                          dataloaders, dict):
                 for category in dataloaders.keys():
                     logger_keys.extend([
                         f'metrics/{label}/{category}/{m}' for m in metric_names
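
Taken together, the hunks let `build_icl_evaluators` run end to end on builtin containers. A usage sketch, assuming the remaining parameters (e.g. `destination_dir`) keep their defaults and the function's existing two-tuple return of evaluators and logger keys:

    from transformers import AutoTokenizer

    from llmfoundry.utils.builders import build_icl_evaluators

    tokenizer = AutoTokenizer.from_pretrained('gpt2')  # pad_token is None; the EOS fallback above kicks in
    evaluators, logger_keys = build_icl_evaluators(
        icl_tasks,  # the plain list of dicts sketched earlier
        tokenizer,
        default_max_seq_len=1024,
        default_batch_size=8,
    )
    # logger_keys names the metrics that will be logged,
    # e.g. 'metrics/lambada/0-shot/InContextLearningLMAccuracy'.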
