Skip to content

Commit

Permalink
Merge branch 'main' into release/v0.3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Sep 26, 2023
2 parents 2a9b8ac + fd36398 commit 0ef130a
Show file tree
Hide file tree
Showing 12 changed files with 605 additions and 7 deletions.
4 changes: 3 additions & 1 deletion llmfoundry/models/hf/hf_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

# required for loading a python model into composer
import transformers
from composer.metrics.nlp import (InContextLearningLMAccuracy,
from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy,
InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
Expand Down Expand Up @@ -74,6 +75,7 @@ def __init__(self, om_model_config: Union[DictConfig,
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
InContextLearningCodeEvalAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError()
]
Expand Down
4 changes: 3 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from composer.metrics import (InContextLearningLMAccuracy,
from composer.metrics import (InContextLearningCodeEvalAccuracy,
InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
Expand Down Expand Up @@ -700,6 +701,7 @@ def __init__(
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
InContextLearningCodeEvalAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError(),
]
Expand Down
8 changes: 8 additions & 0 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def _validate_cfg(icl_cfg: DictConfig):
]
elif icl_cfg.icl_task_type == 'question_answering':
icl_cfg.metric_names = ['InContextLearningQAAccuracy']
elif icl_cfg.icl_task_type == 'code_evaluation':
icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy']
else:
raise ValueError(
f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg.icl_task_type}.'
Expand All @@ -244,6 +246,10 @@ def _validate_cfg(icl_cfg: DictConfig):
icl_cfg.max_seq_len = default_max_seq_len
if 'batch_size' not in icl_cfg:
icl_cfg.batch_size = default_batch_size
if 'pass_at_k' not in icl_cfg:
icl_cfg.pass_at_k = 1
if 'num_beams' not in icl_cfg:
icl_cfg.num_beams = 20

for icl_cfg in icl_tasks_list:
assert isinstance(icl_cfg, DictConfig)
Expand Down Expand Up @@ -274,6 +280,8 @@ def _validate_cfg(icl_cfg: DictConfig):
example_delimiter=icl_cfg.example_delimiter,
continuation_delimiter=icl_cfg.continuation_delimiter,
destination_path=destination_path,
pass_at_k=icl_cfg.pass_at_k,
generations_per_sample=icl_cfg.num_beams,
has_categories=icl_cfg.get('has_categories', False),
)
if hasattr(
Expand Down
35 changes: 30 additions & 5 deletions scripts/eval/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,13 @@ This document explains the ICL formats compatible with [Composer](https://github

## Supported ICL formats

Composer currently supports four ICL formats
Composer currently supports five ICL formats:

1. [InContextLearningQATaskDataset](https://github.com/mosaicml/composer/blob/v0.14.0/composer/datasets/in_context_learning_evaluation.py#L92-L253)
2. [InContextLearningLMTaskDataset](https://github.com/mosaicml/composer/blob/v0.14.0/composer/datasets/in_context_learning_evaluation.py#L256-L402)
3. [InContextLearningMultipleChoiceTaskDataset](https://github.com/mosaicml/composer/blob/v0.14.0/composer/datasets/in_context_learning_evaluation.py#L405-L599)
4. [InContextLearningSchemaTaskDataset](https://github.com/mosaicml/composer/blob/v0.14.0/composer/datasets/in_context_learning_evaluation.py#L602-L773)
1. [InContextLearningQATaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L103)
2. [InContextLearningLMTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L293)
3. [InContextLearningMultipleChoiceTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L444)
4. [InContextLearningSchemaTaskDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L676)
5. [InContextLearningCodeEvalDataset](https://github.com/mosaicml/composer/blob/336bf8db3e2c09ae942d4bf8a819935106589d1a/composer/datasets/in_context_learning_evaluation.py#L852)

----

Expand Down Expand Up @@ -346,6 +347,30 @@ Below is a YAML section that works with the Winograd dataset in [`scripts/eval/l
continuation_delimiter: ' ' # this separates questions from answers
>

----

### InContextLearningCodeEvalDataset

The ICL CodeEvalDataset takes a prompt, and, working with the NLP metric [InContextLearningCodeEvalAccuracy](https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.metrics.InContextLearningCodeEvalAccuracy.html), generates code which gets run against the supplied tests, as in HumanEval ([Evaluating Large Language Models Trained on Code](https://arxiv.org/abs/2107.03374)) and MBPP ([Program Synthesis with Large Language Models](https://arxiv.org/abs/2108.07732)). This generation involves many decoding steps, so can take longer per sample than other ICL tasks. An example datum:

```json
{"task_id": "JavaScript/2", "prompt": "/* Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncateNumber(3.5)\n 0.5\n */\nconst truncateNumber = (number) => {\n", "canonical_solution": " return number % 1.0;\n}\n\n", "test": "const testTruncateNumber = () => {\n console.assert(truncateNumber(3.5) === 0.5)\n\n console.assert(Math.abs(truncateNumber(1.33) - 0.33) < 1e-6)\n\n console.assert(Math.abs(truncateNumber(123.456 - 0.456) < 1e-6))\n}\n\ntestTruncateNumber()\n", "entry_point": "truncateNumber", "test_inputs": ["3.5", "1.33", "123.456"], "test_outputs": ["0.5", "0.33", "0.456"], "language": "javascript"}
```

Required keys for each datum:

* `prompt: str`
* `test: str`
* `entry_point: str`
* `test_inputs: List[str]`
* `test_outputs: List[str]`
* `language: str`

Code evaluation can happen locally (insecure) or inside an AWS Lambda function sandbox. This is controlled by setting the environment variable `CODE_EVAL_DEVICE` to `LOCAL` or `LAMBDA`. If set to `LAMBDA`, you must also provide `CODE_EVAL_URL` and `CODE_EVAL_APIKEY` to query the API gateway in the AWS Sandbox.

----

### Build your own dataset (BYOD)
Building a dataset compatible with our eval suite is very easy if it fits with one of the four supported task types. Simply choose the appropriate task type (LM, MC, QA, or Schema) and process each dataset into a jsonl format in which each row has the format described above.

Expand Down
4 changes: 4 additions & 0 deletions scripts/eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def evaluate_model(
model_cfg: DictConfig,
dist_timeout: Union[float, int],
run_name: str,
seed: int,
icl_tasks: Union[str, ListConfig],
max_seq_len: int,
device_eval_batch_size: int,
Expand All @@ -107,6 +108,7 @@ def evaluate_model(
eval_gauntlet_df: Optional[pd.DataFrame],
icl_subset_num_batches: Optional[int],
):

print(f'Evaluating model: {model_cfg.model_name}', flush=True)
# Build tokenizer and model
tokenizer_cfg: Dict[str,
Expand Down Expand Up @@ -158,6 +160,7 @@ def evaluate_model(

trainer = Trainer(
run_name=run_name,
seed=seed,
model=composer_model,
callbacks=callbacks,
loggers=loggers,
Expand Down Expand Up @@ -276,6 +279,7 @@ def main(cfg: DictConfig):
model_cfg=model_cfg,
dist_timeout=dist_timeout,
run_name=run_name,
seed=seed,
icl_tasks=icl_tasks,
max_seq_len=max_seq_len,
device_eval_batch_size=device_eval_batch_size,
Expand Down
Loading

0 comments on commit 0ef130a

Please sign in to comment.