diff --git a/elk/extraction/__init__.py b/elk/extraction/__init__.py
index 5144566f..8a305ae1 100644
--- a/elk/extraction/__init__.py
+++ b/elk/extraction/__init__.py
@@ -1,5 +1,5 @@
 from .balanced_sampler import BalancedSampler, FewShotSampler
-from .extraction import Extract, extract
+from .extraction import Extract, extract, get_encodings
 from .generator import _GeneratorBuilder, _GeneratorConfig
 from .inference_server import InferenceServer
 from .prompt_loading import get_prompter, load_prompts
@@ -14,4 +14,5 @@
     "_GeneratorBuilder",
     "load_prompts",
     "get_prompter",
+    "get_encodings",
 ]
diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py
index 9d8c73e4..7463b397 100644
--- a/elk/extraction/extraction.py
+++ b/elk/extraction/extraction.py
@@ -58,9 +58,6 @@ class Extract(Serializable):
     get_lm_preds: bool = True
     """Whether to extract the LM predictions."""
 
-    binarize: bool = False
-    """Whether to binarize the dataset labels for multi-class datasets."""
-
     int8: bool = False
     """Whether to perform inference in mixed int8 precision with `bitsandbytes`."""
 
@@ -148,14 +145,14 @@ def explode(self) -> list["Extract"]:
         ]
 
 
-@torch.inference_mode()
 def get_encodings(
     cfg: "Extract",
     split_type: Literal["train", "val"] = "train",
 ) -> Dataset:
     """Apply the prompt templates to the dataset and return the tokenized LM inputs.
 
-    Each dict contains the keys `input_ids`, `attention_mask`, `labels`,
-    `output_hidden_states`, `variant_id`, `row_id`, `text`, and `label`.
+    Each dict contains the keys `input_ids`, `output_hidden_states`, `variant_id`,
+    `row_id`, `text`, and `label`. If lm_preds is True, we also include `answer_ids`
+    and `num_suffix_tokens`.
     """
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -200,7 +197,6 @@ def get_encodings(
         for i, statement in enumerate(example["statements"]):
             if cfg.get_lm_preds:
                 suffix = example["suffixes"][i]
-                text = statement + suffix
                 answer_choices = example["answer_choices"][i]
                 assert len(answer_choices) == 2
                 answer_ids = []
@@ -213,39 +209,43 @@
                         "first token only."
                     )
                     answer_ids.append(a_id[0])
-                num_suffix_tokens = len(
-                    tokenizer.encode(suffix, add_special_tokens=False)
-                )
             else:
-                text = statement
+                suffix = ""
+
+            suffix_tokens = torch.tensor(
+                tokenizer.encode(suffix, add_special_tokens=False),
+                dtype=torch.long,
+            )
 
             encoding = tokenizer(
-                text,
+                statement,
                 # Keep [CLS] and [SEP] for BERT-style models
                 add_special_tokens=True,
                 return_tensors="pt",
             )
-            ids = assert_type(Tensor, encoding.input_ids)
+            # suffix comes right after the last statement token, before the answer
+            ids = torch.cat([encoding.input_ids, suffix_tokens])
 
             # If this input is too long, skip it
            if ids.shape[-1] > max_length:
                 any_too_long = True
                 break
 
-            inputs: dict[str, Tensor | None | bool] = dict(input_ids=ids.long())
-            inputs["output_hidden_states"] = True
+            inputs: dict[str, Tensor | None] = dict(input_ids=ids.long())
 
             out_record: dict[str, Any] = dict(
                 row_id=example["row_id"],
                 variant_id=example["template_names"][i],
                 label=example["label"],
-                text=text,
+                text=statement + suffix,
+                output_hidden_states=True,
                 **inputs,
             )
 
             if cfg.get_lm_preds:
                 out_record["answer_ids"] = answer_ids  # type: ignore
-                out_record["num_suffix_tokens"] = num_suffix_tokens  # type: ignore
+                # keep track of where to extract hiddens from
+                out_record["num_suffix_tokens"] = len(suffix_tokens)
             record_variants.append(out_record)
 
         if any_too_long:
diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py
index 9e6e2777..06ccbe06 100644
--- a/elk/extraction/generator.py
+++ b/elk/extraction/generator.py
@@ -30,7 +30,7 @@ def create_config_id(
         config_kwargs["gen_kwargs"] = {
             k: v[0]
             for k, v in config_kwargs.get("gen_kwargs", {}).items()
-            if k not in ("device", "rank", "world_size", "server")
+            if k not in ("device", "rank", "world_size", "server", "fsdp", "int8")
         }
         config_kwargs.pop("generator")  # pickling InferenceServer fails
         return super().create_config_id(config_kwargs, custom_features)
diff --git a/elk/promptsource/templates.py b/elk/promptsource/templates.py
index 60a311d9..1d5fd86b 100644
--- a/elk/promptsource/templates.py
+++ b/elk/promptsource/templates.py
@@ -138,7 +138,7 @@ def get_fixed_answer_choices_list(self):
         else:
             return None
 
-    def apply(self, example, truncate=True, highlight_variables=False):
+    def apply(self, example, truncate=False, highlight_variables=False):
         """
         Creates a prompt by applying this template to an example
 
diff --git a/elk/promptsource/templates/_default/templates.yaml b/elk/promptsource/templates/_default/templates.yaml
index 26206e70..6240f650 100644
--- a/elk/promptsource/templates/_default/templates.yaml
+++ b/elk/promptsource/templates/_default/templates.yaml
@@ -1,8 +1,8 @@
-dataset: azaria_mitchell
+dataset: None
 templates:
   7eab7254-bd71-4b1d-9f8a-0fc7110f8371: !Template
     answer_choices: False ||| True
-    id: 7eab7254-bd71-4b1d-9f8a-0fc7110f8371
+    id: 7eab7254-bd41-4b1d-9f8a-0fc7110f8371
     jinja: "{{ statement }}"
     metadata: !TemplateMetadata
       choices_in_prompt: true
@@ -11,5 +11,5 @@ templates:
       metrics:
       - Accuracy
       original_task: true
-    name: 7eab7254-bd71-4b1d-9f8a-0fc7110f8371
+    name: _default
     reference: ''
diff --git a/tests/test_encodings.py b/tests/test_encodings.py
new file mode 100644
index 00000000..9f8a2155
--- /dev/null
+++ b/tests/test_encodings.py
@@ -0,0 +1,53 @@
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+from elk.extraction import Extract, get_encodings
+
+
+def test_get_encodings():
+    dataset_name = "imdb"
+    model_path = "sshleifer/tiny-gpt2"
+
+    seed = 42
+    cfg = Extract(
+        model=model_path,
+        datasets=(dataset_name,),
+        max_examples=(10, 10),
+        template_path="_default",
+        get_lm_preds=True,
+        statement_column="text",
+        balance=False,
+        seed=seed,
+    )
+    split_type = "train"
+    encodings = get_encodings(cfg, split_type)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path, truncation_side="left")
+    ds = load_dataset(dataset_name, split=split_type)
+    ds = ds.add_column("row_id", range(len(ds)))  # type: ignore
+    ds = ds.shuffle(seed=seed).select(range(10))  # type: ignore
+
+    def map_fn(ex: dict) -> dict:
+        out_record = {
+            "row_id": ex["row_id"],
+            "label": ex["label"],
+            "variant_id": "_default",
+            "text": ex["text"],
+            "num_suffix_tokens": 0,
+            "output_hidden_states": True,  # TODO: we might remove this
+        }
+        input_ids = [tokenizer(ex["text"], add_special_tokens=True)["input_ids"]]
+        out_record["input_ids"] = input_ids
+        answer_ids = [
+            tokenizer.encode(s, add_special_tokens=False)[0] for s in ["False", "True"]
+        ]
+        out_record["answer_ids"] = answer_ids
+        return out_record
+
+    ds = ds.map(map_fn, batched=False, remove_columns=ds.column_names, num_proc=1)
+    gt_ds = ds.filter(lambda ex: len(ex["input_ids"]) <= tokenizer.model_max_length)
+
+    assert len(encodings) == len(gt_ds)
+    assert set(encodings.column_names) == set(gt_ds.column_names)
+    for col in encodings.column_names:
+        assert encodings[col] == gt_ds[col]