small changes
artemorloff committed Aug 25, 2024
1 parent aab42ba · commit 6dd81ae
Showing 9 changed files with 104 additions and 24 deletions.
lm_eval/__main__.py (26 changes: 22 additions & 4 deletions)
@@ -236,6 +236,12 @@ def setup_parser() -> argparse.ArgumentParser:
default=False,
help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
)
parser.add_argument(
"--filter_device",
type=str,
default=None,
help="Device to use (e.g. cuda, cuda:0, cpu) for models in filters. By default, equals to LM device.",
)
default_seed_string = "0,1234,1234,1234"
parser.add_argument(
"--seed",
@@ -295,15 +301,22 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
)

if args.fewshot_as_multiturn and args.apply_chat_template is False:
raise ValueError(
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
# no chat template means no multiturn fewshot
args.fewshot_as_multiturn = False
eval_logger.warning(
(
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
"Disabling `fewshot_as_multiturn`."
)
)

if (
args.num_fewshot is None or args.num_fewshot == 0
) and args.fewshot_as_multiturn:
raise ValueError(
"If fewshot_as_multiturn is set, num_fewshot must be greater than 0."
# with no fewshot examples, multiturn makes no sense; disable it
args.fewshot_as_multiturn = False
eval_logger.warning(
"If fewshot_as_multiturn is set, num_fewshot must be greater than 0. Disabling `fewshot_as_multiturn`."
)

if args.include_path is not None:
@@ -321,6 +334,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)

if args.filter_device is None:
# default behaviour; no warning needed
args.filter_device = args.device

if args.tasks is None:
eval_logger.error("Need to specify task to evaluate.")
sys.exit()
@@ -411,6 +428,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
numpy_random_seed=args.seed[1],
torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3],
filter_device=args.filter_device,
**request_caching_args,
)

lm_eval/api/filter.py (8 changes: 5 additions & 3 deletions)
@@ -20,7 +20,9 @@ def __init__(self, **kwargs) -> None:
"""

@abstractmethod
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
def apply(
self, resps: Union[List, Iterable], docs: List[dict], **kwargs
) -> Iterable:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
@@ -42,13 +44,13 @@ class FilterEnsemble:
name: str
filters: List[Callable[[], Filter]]

def apply(self, instances: List[Instance]) -> None:
def apply(self, instances: List[Instance], **kwargs) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
resps, docs = list(resps), list(docs)

for f in self.filters:
# apply filters in sequence
resps = f().apply(resps, docs)
resps = f().apply(resps, docs, **kwargs)

# add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
lm_eval/api/metrics.py (44 changes: 44 additions & 0 deletions)
@@ -60,6 +60,30 @@ def f1_score(items):
return np.max(fscore)


@register_aggregation("f1_macro")
def f1_macro_score(items):
from sklearn.metrics import f1_score

unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = f1_score(golds, preds, average="macro")

return fscore


@register_aggregation("f1_micro")
def f1_micro_score(items):
from sklearn.metrics import f1_score

unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = f1_score(golds, preds, average="micro")

return fscore


@register_aggregation("matthews_corrcoef")
def matthews_corrcoef(items):
from sklearn.metrics import matthews_corrcoef
@@ -319,6 +343,26 @@ def f1_fn(items): # This is a passthrough function
return items


@register_metric(
metric="f1_macro",
higher_is_better=True,
output_type="multiple_choice",
aggregation="f1_macro",
)
def f1_macro_fn(items): # This is a passthrough function
return items


@register_metric(
metric="f1_micro",
higher_is_better=True,
output_type="multiple_choice",
aggregation="f1_micro",
)
def f1_micro_fn(items): # This is a passthrough function
return items


@register_metric(
metric="bleu",
higher_is_better=True,
lm_eval/api/task.py (8 changes: 4 additions & 4 deletions)
@@ -612,11 +612,11 @@ def fewshot_context(
example = self.doc_to_text(doc)
return description + labeled_examples + example

def apply_filters(self) -> Optional[List[Instance]]:
def apply_filters(self, **kwargs) -> Optional[List[Instance]]:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
for f in self._filters:
f.apply(self._instances)
f.apply(self._instances, **kwargs)
else:
eval_logger.warning("No filter defined, passing through instances")
return self._instances
@@ -1131,11 +1131,11 @@ def fewshot_context(
else:
return labeled_examples + str(example)

def apply_filters(self):
def apply_filters(self, **kwargs):
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
for f in self._filters:
f.apply(self._instances)
f.apply(self._instances, **kwargs)
else:
eval_logger.warning("No filter defined, passing through instances")
return self._instances
lm_eval/evaluator.py (12 changes: 11 additions & 1 deletion)
@@ -74,6 +74,7 @@ def simple_evaluate(
numpy_random_seed: int = 1234,
torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234,
filter_device: Optional[str] = None,
):
"""Instantiate and evaluate a model on a list of tasks.
@@ -92,6 +93,8 @@
Maximal batch size to try with automatic batch size detection
:param device: str, optional
PyTorch device (e.g. "cpu" or "cuda:0") for running models
:param filter_device: str, optional
PyTorch device (e.g. "cpu" or "cuda:0") for running auxiliary models used by filters (e.g. reward models)
:param use_cache: str, optional
A path to a sqlite db file for caching model responses. `None` if not caching.
:param cache_requests: bool, optional
@@ -309,6 +312,8 @@ def _adjust_config(task_dict):
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
verbosity=verbosity,
predict_only=predict_only,
filter_device=filter_device,
)

if lm.rank == 0:
@@ -368,6 +373,8 @@ def evaluate(
apply_chat_template: Union[bool, str] = False,
fewshot_as_multiturn: bool = False,
verbosity: str = "INFO",
predict_only: bool = False,
filter_device: Optional[str] = None,
):
"""Instantiate and evaluate a model on a list of tasks.
@@ -392,6 +399,8 @@
Defaults to False (no chat template applied).
:param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param filter_device: str, optional
PyTorch device (e.g. "cpu" or "cuda:0") for running auxiliary models used by filters (e.g. reward models)
:return
Dictionary of results
"""
@@ -486,7 +495,8 @@ def evaluate(
# TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
for task_output in eval_tasks:
task = task_output.task
task.apply_filters()
# no need to run reward models when `predict_only=True`
task.apply_filters(predict_only=predict_only, filter_device=filter_device)

### Collect values of metrics on all datapoints ###
# # unpack results and sort back in order and return control to Task
lm_eval/filters/extraction.py (6 changes: 3 additions & 3 deletions)
@@ -25,7 +25,7 @@ def __init__(
self.group_select = group_select
self.fallback = fallback

def apply(self, resps, docs):
def apply(self, resps, docs, **kwargs):
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
@@ -58,7 +58,7 @@ class WhitespaceFilter(Filter):
def __init__(self) -> None:
pass

def apply(self, resps, docs):
def apply(self, resps, docs, **kwargs):
def filter_set(inst):
filtered_resp = []
for resp in inst:
@@ -103,7 +103,7 @@ def __init__(
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore

def apply(self, resps, docs):
def apply(self, resps, docs, **kwargs):
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
lm_eval/filters/selection.py (6 changes: 3 additions & 3 deletions)
@@ -16,7 +16,7 @@ def __init__(self) -> None:
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""

def apply(self, resps, docs):
def apply(self, resps, docs, **kwargs):
"""
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
"""
@@ -30,7 +30,7 @@ def __init__(self, **kwargs) -> None:

super().__init__(**kwargs)

def apply(self, resps, docs):
def apply(self, resps, docs, **kwargs):
# need resp to be subscriptable to check below
resps = list(resps)
# check we have at least k responses per doc, else we can't take the first k
@@ -47,7 +47,7 @@ def __init__(self) -> None:
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""

def apply(self, resps, docs):
def apply(self, resps, docs, **kwargs):
"""
Each entry of `resps` is a list of model responses.
We select the response that occurs most frequently in each entry of `resps`.
lm_eval/filters/transformation.py (6 changes: 3 additions & 3 deletions)
@@ -7,7 +7,7 @@ class LowercaseFilter(Filter):
def __init__(self) -> None:
pass

def apply(self, resps, docs):
def apply(self, resps, docs, **kwargs):
def filter_set(inst):
return [resp.lower() for resp in inst]

@@ -19,7 +19,7 @@ class UppercaseFilter(Filter):
def __init__(self) -> None:
pass

def apply(self, resps, docs):
def apply(self, resps, docs, **kwargs):
def filter_set(inst):
return [resp.upper() for resp in inst]

@@ -49,7 +49,7 @@ def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
self.mapping_dict = mapping_dict
self.default_value = default_value

def apply(self, resps, docs):
def apply(self, resps, docs, **kwargs):
def filter_set(inst):
return [self.mapping_dict.get(resp, self.default_value) for resp in inst]

lm_eval/models/api_models.py (12 changes: 9 additions & 3 deletions)
@@ -21,7 +21,7 @@

try:
import requests
from aiohttp import ClientSession, TCPConnector
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from tenacity import RetryError, retry, stop_after_attempt, wait_exponential
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
@@ -76,6 +76,8 @@ def __init__(
custom_prefix_token_id=None,
# send the requests as tokens or strings
tokenized_requests=True,
# timeout for async client
timeout: Optional[int] = None,
**kwargs,
) -> None:
super().__init__()
@@ -115,6 +117,7 @@ def __init__(
self.custom_prefix_token_id = custom_prefix_token_id
self.tokenized_requests = tokenized_requests
self.max_retries = int(max_retries)
self.timeout = timeout

eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
if self.tokenizer_backend is None:
@@ -249,7 +252,8 @@ def apply_chat_template(
)
else:
# bit of a hack. We'll load back before sending to the API
return JsonChatStr(json.dumps(chat_history))
# to keep Cyrillic symbols readable, disable `ensure_ascii`
return JsonChatStr(json.dumps(chat_history, ensure_ascii=False))

@cached_property
def eot_token_id(self) -> Optional[int]:
@@ -438,7 +442,9 @@
) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
ctxlens = ctxlens if ctxlens else [None] * len(requests)
conn = TCPConnector(limit=self._concurrent)
async with ClientSession(connector=conn) as session:
async with ClientSession(
connector=conn, timeout=ClientTimeout(total=self.timeout)
) as session:
retry_: Callable[..., Awaitable[Any]] = retry(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=0.5, min=1, max=10),
