Commit d48519d: add encodings test, cleanup
AlexTMallen committed Sep 21, 2023
1 parent 98a8dea commit d48519d
Showing 6 changed files with 77 additions and 23 deletions.
elk/extraction/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -1,5 +1,5 @@
 from .balanced_sampler import BalancedSampler, FewShotSampler
-from .extraction import Extract, extract
+from .extraction import Extract, extract, get_encodings
 from .generator import _GeneratorBuilder, _GeneratorConfig
 from .inference_server import InferenceServer
 from .prompt_loading import get_prompter, load_prompts
@@ -14,4 +14,5 @@
     "_GeneratorBuilder",
     "load_prompts",
     "get_prompter",
+    "get_encodings",
 ]
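For reference, the newly exported helper can be driven end to end. A minimal sketch, mirroring the arguments used in tests/test_encodings.py below (the tiny model and IMDB dataset come from that test, not from this file):

from elk.extraction import Extract, get_encodings

# Config mirroring the new test: tiny model, 10 IMDB examples per split
cfg = Extract(
    model="sshleifer/tiny-gpt2",
    datasets=("imdb",),
    max_examples=(10, 10),
    template_path="_default",
    get_lm_preds=True,
    statement_column="text",
    balance=False,
    seed=42,
)
encodings = get_encodings(cfg, "train")  # tokenized LM inputs as a Dataset
print(encodings.column_names)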
elk/extraction/extraction.py (34 changes: 17 additions & 17 deletions)
@@ -58,9 +58,6 @@ class Extract(Serializable):
     get_lm_preds: bool = True
     """Whether to extract the LM predictions."""

-    binarize: bool = False
-    """Whether to binarize the dataset labels for multi-class datasets."""
-
     int8: bool = False
     """Whether to perform inference in mixed int8 precision with `bitsandbytes`."""

@@ -148,14 +145,14 @@ def explode(self) -> list["Extract"]:
         ]


 @torch.inference_mode()
 def get_encodings(
     cfg: "Extract",
     split_type: Literal["train", "val"] = "train",
 ) -> Dataset:
     """Apply the prompt templates to the dataset and return the tokenized LM inputs.
-    Each dict contains the keys `input_ids`, `attention_mask`, `labels`,
-    `output_hidden_states`, `variant_id`, `row_id`, `text`, and `label`.
+    Each dict contains the keys `input_ids`, `output_hidden_states`, `variant_id`,
+    `row_id`, `text`, and `label`. If lm_preds is True, we also include `answer_ids`
+    and `num_suffix_tokens`.
     """
     os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -200,7 +197,6 @@ def get_encodings(
         for i, statement in enumerate(example["statements"]):
             if cfg.get_lm_preds:
                 suffix = example["suffixes"][i]
-                text = statement + suffix
                 answer_choices = example["answer_choices"][i]
                 assert len(answer_choices) == 2
                 answer_ids = []
@@ -213,39 +209,43 @@
                             "first token only."
                         )
                     answer_ids.append(a_id[0])
-                num_suffix_tokens = len(
-                    tokenizer.encode(suffix, add_special_tokens=False)
-                )
             else:
-                text = statement
+                suffix = ""
+
+            suffix_tokens = torch.tensor(
+                tokenizer.encode(suffix, add_special_tokens=False),
+                dtype=torch.long,
+            )

             encoding = tokenizer(
-                text,
+                statement,
                 # Keep [CLS] and [SEP] for BERT-style models
                 add_special_tokens=True,
                 return_tensors="pt",
             )

-            ids = assert_type(Tensor, encoding.input_ids)
+            # suffix comes right after the last statement token, before the answer
+            ids = torch.cat([encoding.input_ids, suffix_tokens])

             # If this input is too long, skip it
             if ids.shape[-1] > max_length:
                 any_too_long = True
                 break

-            inputs: dict[str, Tensor | None | bool] = dict(input_ids=ids.long())
-            inputs["output_hidden_states"] = True
+            inputs: dict[str, Tensor | None] = dict(input_ids=ids.long())

             out_record: dict[str, Any] = dict(
                 row_id=example["row_id"],
                 variant_id=example["template_names"][i],
                 label=example["label"],
-                text=text,
+                text=statement + suffix,
+                output_hidden_states=True,
                 **inputs,
             )
             if cfg.get_lm_preds:
                 out_record["answer_ids"] = answer_ids  # type: ignore
-                out_record["num_suffix_tokens"] = num_suffix_tokens  # type: ignore
+                # keep track of where to extract hiddens from
+                out_record["num_suffix_tokens"] = len(suffix_tokens)
             record_variants.append(out_record)

             if any_too_long:
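The num_suffix_tokens field recorded above marks where the statement ends and the suffix begins. A minimal sketch of a downstream consumer; the helper name is hypothetical, not part of this commit, and it assumes input_ids is a flat sequence of token ids:

from typing import Any

def last_statement_token_index(record: dict[str, Any]) -> int:
    # The suffix is appended right after the last statement token (see the
    # comment in the diff above), so the statement's final position sits
    # num_suffix_tokens + 1 places from the end of input_ids.
    return len(record["input_ids"]) - record["num_suffix_tokens"] - 1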
elk/extraction/generator.py (2 changes: 1 addition & 1 deletion)

@@ -30,7 +30,7 @@ def create_config_id(
         config_kwargs["gen_kwargs"] = {
             k: v[0]
             for k, v in config_kwargs.get("gen_kwargs", {}).items()
-            if k not in ("device", "rank", "world_size", "server")
+            if k not in ("device", "rank", "world_size", "server", "fsdp", "int8")
         }
         config_kwargs.pop("generator")  # pickling InferenceServer fails
         return super().create_config_id(config_kwargs, custom_features)
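The widened exclusion list means flags that only control how inference is executed no longer perturb the cached dataset's config id. A self-contained sketch of the same filtering, with illustrative sample values:

# gen_kwargs arrive with each value wrapped in a 1-tuple
gen_kwargs = {"temperature": (0.0,), "fsdp": (True,), "int8": (False,)}

# Execution-only keys are dropped before hashing, as in the diff above
filtered = {
    k: v[0]
    for k, v in gen_kwargs.items()
    if k not in ("device", "rank", "world_size", "server", "fsdp", "int8")
}
assert filtered == {"temperature": 0.0}  # fsdp/int8 no longer change the id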
elk/promptsource/templates.py (2 changes: 1 addition & 1 deletion)
@@ -138,7 +138,7 @@ def get_fixed_answer_choices_list(self):
         else:
             return None

-    def apply(self, example, truncate=True, highlight_variables=False):
+    def apply(self, example, truncate=False, highlight_variables=False):
         """
         Creates a prompt by applying this template to an example
elk/promptsource/templates/_default/templates.yaml (6 changes: 3 additions & 3 deletions)
@@ -1,8 +1,8 @@
-dataset: azaria_mitchell
+dataset: None
 templates:
   7eab7254-bd71-4b1d-9f8a-0fc7110f8371: !Template
     answer_choices: False ||| True
-    id: 7eab7254-bd71-4b1d-9f8a-0fc7110f8371
+    id: 7eab7254-bd41-4b1d-9f8a-0fc7110f8371
     jinja: "{{ statement }}"
     metadata: !TemplateMetadata
       choices_in_prompt: true
@@ -11,5 +11,5 @@ templates:
     metrics:
     - Accuracy
     original_task: true
-    name: 7eab7254-bd71-4b1d-9f8a-0fc7110f8371
+    name: _default
     reference: ''
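As a sanity check, the jinja field above passes the statement through unchanged. A minimal sketch of the rendering, using plain jinja2 (which promptsource templates build on):

from jinja2 import Template

prompt = Template("{{ statement }}").render(statement="Paris is in France.")
print(prompt)  # -> Paris is in France.
# the answer_choices string "False ||| True" splits on "|||" into ["False", "True"]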
tests/test_encodings.py (53 changes: 53 additions & 0 deletions)
@@ -0,0 +1,53 @@
from datasets import load_dataset
from transformers import AutoTokenizer

from elk.extraction import Extract, get_encodings


def test_get_encodings():
dataset_name = "imdb"
model_path = "sshleifer/tiny-gpt2"

seed = 42
cfg = Extract(
model=model_path,
datasets=(dataset_name,),
max_examples=(10, 10),
template_path="_default",
get_lm_preds=True,
statement_column="text",
balance=False,
seed=seed,
)
split_type = "train"
encodings = get_encodings(cfg, split_type)

tokenizer = AutoTokenizer.from_pretrained(model_path, truncation_side="left")
ds = load_dataset(dataset_name, split=split_type)
ds = ds.add_column("row_id", range(len(ds))) # type: ignore
ds = ds.shuffle(seed=seed).select(range(10)) # type: ignore

def map_fn(ex: dict) -> dict:
out_record = {
"row_id": ex["row_id"],
"label": ex["label"],
"variant_id": "_default",
"text": ex["text"],
"num_suffix_tokens": 0,
"output_hidden_states": True, # TODO: we might remove this
}
input_ids = [tokenizer(ex["text"], add_special_tokens=True)["input_ids"]]
out_record["input_ids"] = input_ids
answer_ids = [
tokenizer.encode(s, add_special_tokens=False)[0] for s in ["False", "True"]
]
out_record["answer_ids"] = answer_ids
return out_record

ds = ds.map(map_fn, batched=False, remove_columns=ds.column_names, num_proc=1)
gt_ds = ds.filter(lambda ex: len(ex["input_ids"]) <= tokenizer.model_max_length)

assert len(encodings) == len(gt_ds)
assert set(encodings.column_names) == set(gt_ds.column_names)
for col in encodings.column_names:
assert encodings[col] == gt_ds[col]
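The new test can also be exercised directly from Python rather than the pytest CLI; a minimal sketch, assuming a standard pytest setup at the repository root:

import pytest

# Run only the new encodings test, quietly; returns a pytest exit code
exit_code = pytest.main(["tests/test_encodings.py", "-q"])
assert exit_code == 0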
