From dc2e060b77ef863405097cfb171e5b3eb246ff8b Mon Sep 17 00:00:00 2001 From: Alex Mallen <35092692+AlexTMallen@users.noreply.github.com> Date: Tue, 2 May 2023 14:02:31 -0700 Subject: [PATCH 1/6] add post_init checking for at least one dataset (#237) --- elk/extraction/extraction.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 70abad64..fa3f49e5 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -100,6 +100,11 @@ class Extract(Serializable): case of encoder-decoder models.""" def __post_init__(self, layer_stride: int): + if len(self.datasets) == 0: + raise ValueError( + "Must specify at least one dataset to extract hiddens from." + ) + if len(self.max_examples) > 2: raise ValueError( "max_examples should be a list of length 0, 1, or 2," From b889473a68a84c1d32fcdcb180c13de1d8fd96e3 Mon Sep 17 00:00:00 2001 From: Nora Belrose <39116809+norabelrose@users.noreply.github.com> Date: Tue, 2 May 2023 15:37:54 -0700 Subject: [PATCH 2/6] `burns` shortcut dataset in sweep (#236) * Burns shortcut dataset * Amend message --- elk/training/sweep.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/elk/training/sweep.py b/elk/training/sweep.py index 49fa99e8..2b1871df 100644 --- a/elk/training/sweep.py +++ b/elk/training/sweep.py @@ -36,6 +36,32 @@ def __post_init__(self, add_pooled: bool): if not self.models: raise ValueError("No models specified") + # Check for the magic dataset "burns" which is a shortcut for all of the + # datasets used in Burns et al., except Story Cloze, which is not available + # on the Huggingface Hub. + if "burns" in self.datasets: + self.datasets.remove("burns") + self.datasets.extend( + [ + "ag_news", + "amazon_polarity", + "dbpedia_14", + "glue:qnli", + "imdb", + "piqa", + "super_glue:boolq", + "super_glue:copa", + "super_glue:rte", + ] + ) + print( + "Interpreting `burns` as all datasets used in Burns et al. (2022) " + "available on the HuggingFace Hub" + ) + + # Remove duplicates just in case + self.datasets = sorted(set(self.datasets)) + # Add an additional dataset that pools all of the datasets together. if add_pooled: self.datasets.append("+".join(self.datasets)) From 2d88580f44efd764346c69671386cb5c53dba9bc Mon Sep 17 00:00:00 2001 From: Nora Belrose <39116809+norabelrose@users.noreply.github.com> Date: Tue, 2 May 2023 19:02:49 -0700 Subject: [PATCH 3/6] Load fp32 models in bfloat16 when possible (#231) * Automatically use bfloat16 in some cases * Use bfloat16 in more cases; sanity check for int8 --- elk/extraction/extraction.py | 16 +++------------ elk/utils/__init__.py | 4 ++-- elk/utils/hf_utils.py | 38 +++++++++++++++++++++++++++++++++++- elk/utils/typing.py | 4 ++-- 4 files changed, 44 insertions(+), 18 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index fa3f49e5..0d2f7bf4 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -31,7 +31,7 @@ Color, assert_type, colorize, - float32_to_int16, + float_to_int16, infer_label_column, infer_num_classes, instantiate_model, @@ -165,20 +165,10 @@ def extract_hiddens( ds_names = cfg.datasets assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time." 
- if cfg.int8: - # Required by `bitsandbytes` - dtype = torch.float16 - elif device == "cpu": - dtype = torch.float32 - else: - dtype = "auto" - # We use contextlib.redirect_stdout to prevent `bitsandbytes` from printing its # welcome message on every rank with redirect_stdout(None) if rank != 0 else nullcontext(): - model = instantiate_model( - cfg.model, device_map={"": device}, load_in_8bit=cfg.int8, torch_dtype=dtype - ) + model = instantiate_model(cfg.model, device=device, load_in_8bit=cfg.int8) tokenizer = instantiate_tokenizer( cfg.model, truncation_side="left", verbose=rank == 0 ) @@ -313,7 +303,7 @@ def extract_hiddens( raise ValueError(f"Invalid token_loc: {cfg.token_loc}") for layer_idx, hidden in zip(layer_indices, hiddens): - hidden_dict[f"hidden_{layer_idx}"][i, j] = float32_to_int16(hidden) + hidden_dict[f"hidden_{layer_idx}"][i, j] = float_to_int16(hidden) text_questions.append(variant_questions) diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index d279e282..22b92b75 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -13,7 +13,7 @@ from .math_util import batch_cov, cov_mean_fused, stochastic_round_constrained from .pretty import Color, colorize from .tree_utils import pytree_map -from .typing import assert_type, float32_to_int16, int16_to_float32 +from .typing import assert_type, float_to_int16, int16_to_float32 __all__ = [ "assert_type", @@ -21,7 +21,7 @@ "Color", "colorize", "cov_mean_fused", - "float32_to_int16", + "float_to_int16", "get_columns_all_equal", "get_layer_indices", "has_multiple_configs", diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py index cdf50e57..9f429921 100644 --- a/elk/utils/hf_utils.py +++ b/elk/utils/hf_utils.py @@ -1,3 +1,4 @@ +import torch import transformers from transformers import ( AutoConfig, @@ -19,10 +20,45 @@ _AUTOREGRESSIVE_SUFFIXES = ["ConditionalGeneration"] + _DECODER_ONLY_SUFFIXES -def instantiate_model(model_str: str, **kwargs) -> PreTrainedModel: +def instantiate_model( + model_str: str, + device: str | torch.device = "cpu", + **kwargs, +) -> PreTrainedModel: """Instantiate a model string with the appropriate `Auto` class.""" + device = torch.device(device) + kwargs["device_map"] = {"": device} + with prevent_name_conflicts(): model_cfg = AutoConfig.from_pretrained(model_str) + + # When the torch_dtype is None, this generally means the model is fp32, because + # the config was probably created before the `torch_dtype` field was added. + fp32_weights = model_cfg.torch_dtype in (None, torch.float32) + + # Required by `bitsandbytes` to load in 8-bit. + if kwargs.get("load_in_8bit"): + # Sanity check: we probably shouldn't be loading in 8-bit if the checkpoint + # is in fp32. `bitsandbytes` only supports mixed fp16/int8 inference, and + # we can't guarantee that there won't be overflow if we downcast to fp16. + if fp32_weights: + raise ValueError("Cannot load in 8-bit if weights are fp32") + + kwargs["torch_dtype"] = torch.float16 + + # CPUs generally don't support anything other than fp32. + elif device.type == "cpu": + kwargs["torch_dtype"] = torch.float32 + + # If the model is fp32 but bf16 is available, convert to bf16. + # Usually models with fp32 weights were actually trained in bf16, and + # converting them doesn't hurt performance. + elif fp32_weights and torch.cuda.is_bf16_supported(): + kwargs["torch_dtype"] = torch.bfloat16 + print("Weights seem to be fp32, but bf16 is available. 
Loading in bf16.") + else: + kwargs["torch_dtype"] = "auto" + archs = model_cfg.architectures if not isinstance(archs, list): return AutoModel.from_pretrained(model_str, **kwargs) diff --git a/elk/utils/typing.py b/elk/utils/typing.py index f0b10d52..19a0027b 100644 --- a/elk/utils/typing.py +++ b/elk/utils/typing.py @@ -13,8 +13,8 @@ def assert_type(typ: Type[T], obj: Any) -> T: return cast(typ, obj) -def float32_to_int16(x: torch.Tensor) -> torch.Tensor: - """Converts float32 to float16, then reinterprets as int16.""" +def float_to_int16(x: torch.Tensor) -> torch.Tensor: + """Converts a floating point tensor to float16, then reinterprets as int16.""" downcast = x.type(torch.float16) if not downcast.isfinite().all(): raise ValueError("Cannot convert to 16 bit: values are not finite") From fe58c7624fea16821bb5c6f996fbc3ad67272a61 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Wed, 3 May 2023 08:00:30 +0000 Subject: [PATCH 4/6] Add use_centroids option for CRC-TPC --- elk/training/eigen_reporter.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 686ab900..028a8347 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -28,6 +28,7 @@ class EigenReporterConfig(ReporterConfig): neg_cov_weight: float = 0.5 num_heads: int = 1 + use_centroids: bool = True def __post_init__(self): if not (0 <= self.neg_cov_weight <= 1): @@ -168,8 +169,13 @@ def update(self, hiddens: Tensor) -> None: intra_cov = cov_mean_fused(rearrange(hiddens, "n v k d -> (n k) v d")) self.intracluster_cov += (n / self.n) * (intra_cov - self.intracluster_cov) - # [n, v, k, d] -> [n, k, d] - centroids = hiddens.mean(1) + if self.config.use_centroids: + # VINC style + centroids = hiddens.mean(1) + else: + # CRC-TPC style + centroids = rearrange(hiddens, "n v k d -> (n v) k d") + deltas, deltas2 = [], [] # Iterating over classes From ed71d70ebd8f0b0b2eaf89e175ecf55379e3efa3 Mon Sep 17 00:00:00 2001 From: Alex Mallen <35092692+AlexTMallen@users.noreply.github.com> Date: Wed, 3 May 2023 01:20:11 -0700 Subject: [PATCH 5/6] Reduce VINC reporter file size by >1000x (#219) * reduced reporter filesize by 4x; still unsure why the pickle file stores 1 remaining cov matrix * add save_reporter_stats CLA * VINC reporters are now >1000x smaller on disk * Temporarily disable Platt scaling --------- Co-authored-by: Nora Belrose --- elk/evaluation/evaluate.py | 2 +- elk/run.py | 13 ++-- elk/training/ccs_reporter.py | 19 +++++- elk/training/eigen_reporter.py | 110 ++++++++++++++------------------- elk/training/reporter.py | 79 +++++++++++++++++++---- elk/training/train.py | 23 +++++-- 6 files changed, 156 insertions(+), 90 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 051717f5..034bde14 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -40,7 +40,7 @@ def apply_to_layer( experiment_dir = elk_reporter_dir() / self.source reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt" - reporter: Reporter = torch.load(reporter_path, map_location=device) + reporter = Reporter.load(reporter_path, map_location=device) reporter.eval() row_bufs = defaultdict(list) diff --git a/elk/run.py b/elk/run.py index 65573895..85f244c7 100644 --- a/elk/run.py +++ b/elk/run.py @@ -13,6 +13,7 @@ import torch.multiprocessing as mp import yaml from simple_parsing.helpers import Serializable, field +from simple_parsing.helpers.serialization import save from torch import 
Tensor from tqdm import tqdm @@ -37,12 +38,14 @@ class Run(ABC, Serializable): """Directory to save results to. If None, a directory will be created automatically.""" - datasets: list[DatasetDictWithName] = field(default_factory=list, init=False) + datasets: list[DatasetDictWithName] = field( + default_factory=list, init=False, to_dict=False + ) """Datasets containing hidden states and labels for each layer.""" concatenated_layer_offset: int = 0 debug: bool = False - min_gpu_mem: int | None = None + min_gpu_mem: int | None = None # in bytes num_gpus: int = -1 out_dir: Path | None = None disable_cache: bool = field(default=False, to_dict=False) @@ -76,9 +79,9 @@ def execute( print(f"Output directory at \033[1m{self.out_dir}\033[0m") self.out_dir.mkdir(parents=True, exist_ok=True) - path = self.out_dir / "cfg.yaml" - with open(path, "w") as f: - self.dump_yaml(f) + # save_dc_types really ought to be the default... We simply can't load + # properly without this flag enabled. + save(self, self.out_dir / "cfg.yaml", save_dc_types=True) path = self.out_dir / "fingerprints.yaml" with open(path, "w") as meta_f: diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py index 24941f7e..579c2f52 100644 --- a/elk/training/ccs_reporter.py +++ b/elk/training/ccs_reporter.py @@ -3,6 +3,7 @@ import math from copy import deepcopy from dataclasses import dataclass, field +from pathlib import Path from typing import Literal, Optional, cast import torch @@ -59,7 +60,6 @@ class CcsReporterConfig(ReporterConfig): loss_dict: dict[str, float] = field(default_factory=dict, init=False) num_layers: int = 1 pre_ln: bool = False - seed: int = 42 supervised_weight: float = 0.0 lr: float = 1e-2 @@ -68,6 +68,10 @@ class CcsReporterConfig(ReporterConfig): optimizer: Literal["adam", "lbfgs"] = "lbfgs" weight_decay: float = 0.01 + @classmethod + def reporter_class(cls) -> type[Reporter]: + return CcsReporter + def __post_init__(self): self.loss_dict = parse_loss(self.loss) @@ -94,6 +98,11 @@ def __init__( ): super().__init__() self.config = cfg + self.in_features = in_features + + # Learnable Platt scaling parameters + self.bias = nn.Parameter(torch.zeros(1, device=device, dtype=dtype)) + self.scale = nn.Parameter(torch.ones(1, device=device, dtype=dtype)) hidden_size = cfg.hidden_size or 4 * in_features // 3 @@ -239,7 +248,7 @@ def forward(self, x: Tensor) -> Tensor: def raw_forward(self, x: Tensor) -> Tensor: """Apply the probe to the provided input, without normalization.""" - return self.probe(x).squeeze(-1) + return self.probe(x).mul(self.scale).add(self.bias).squeeze(-1) def loss( self, @@ -401,3 +410,9 @@ def closure(): optimizer.step(closure) return float(loss) + + def save(self, path: Path | str) -> None: + """Save the reporter to a file.""" + state = {k: v.cpu() for k, v in self.state_dict().items()} + state.update(in_features=self.in_features) + torch.save(state, path) diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 028a8347..2e67ba6f 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -1,13 +1,12 @@ """An ELK reporter network.""" from dataclasses import dataclass -from typing import Optional +from pathlib import Path import torch -from einops import rearrange, repeat -from torch import Tensor, nn, optim +from einops import rearrange +from torch import Tensor, nn -from ..metrics import to_one_hot from ..truncated_eigh import truncated_eigh from ..utils.math_util import cov_mean_fused from .reporter import Reporter, ReporterConfig @@ 
-15,20 +14,23 @@ @dataclass class EigenReporterConfig(ReporterConfig): - """Configuration for an EigenReporter. - - Args: - var_weight: The weight of the variance term in the loss. - neg_cov_weight: The weight of the negative covariance term in the loss. - num_heads: The number of reporter heads to fit. In other words, the number - of eigenvectors to compute from the VINC matrix. - """ + """Configuration for an EigenReporter.""" var_weight: float = 0.0 + """The weight of the variance term in the loss.""" + neg_cov_weight: float = 0.5 + """The weight of the negative covariance term in the loss.""" num_heads: int = 1 + """The number of reporter heads to fit.""" + + save_reporter_stats: bool = False + """Whether to save the reporter statistics to disk in EigenReporter.save(). This + is useful for debugging and analysis, but can take up a lot of disk space.""" + use_centroids: bool = True + """Whether to average hiddens within each cluster before computing covariance.""" def __post_init__(self): if not (0 <= self.neg_cov_weight <= 1): @@ -36,6 +38,10 @@ def __post_init__(self): if self.num_heads <= 0: raise ValueError("num_heads must be positive") + @classmethod + def reporter_class(cls) -> type[Reporter]: + return EigenReporter + class EigenReporter(Reporter): """A linear reporter whose weights are computed via eigendecomposition. @@ -71,6 +77,7 @@ class EigenReporter(Reporter): intercluster_cov_M2: Tensor # variance intracluster_cov: Tensor # invariance contrastive_xcov_M2: Tensor # negative covariance + n: Tensor class_means: Tensor | None weight: Tensor @@ -79,20 +86,26 @@ def __init__( self, cfg: EigenReporterConfig, in_features: int, - num_classes: int | None = 2, + num_classes: int | None = None, *, device: str | torch.device | None = None, dtype: torch.dtype | None = None, ): super().__init__() self.config = cfg + self.in_features = in_features + self.num_classes = num_classes # Learnable Platt scaling parameters self.bias = nn.Parameter(torch.zeros(cfg.num_heads, device=device, dtype=dtype)) self.scale = nn.Parameter(torch.ones(cfg.num_heads, device=device, dtype=dtype)) # Running statistics - self.register_buffer("n", torch.zeros((), device=device, dtype=torch.long)) + self.register_buffer( + "n", + torch.zeros((), device=device, dtype=torch.long), + persistent=cfg.save_reporter_stats, + ) self.register_buffer( "class_means", ( @@ -100,19 +113,23 @@ def __init__( if num_classes is not None else None ), + persistent=cfg.save_reporter_stats, ) self.register_buffer( "contrastive_xcov_M2", torch.zeros(in_features, in_features, device=device, dtype=dtype), + persistent=cfg.save_reporter_stats, ) self.register_buffer( "intercluster_cov_M2", torch.zeros(in_features, in_features, device=device, dtype=dtype), + persistent=cfg.save_reporter_stats, ) self.register_buffer( "intracluster_cov", torch.zeros(in_features, in_features, device=device, dtype=dtype), + persistent=cfg.save_reporter_stats, ) # Reporter weights @@ -128,10 +145,12 @@ def forward(self, hiddens: Tensor) -> Tensor: @property def contrastive_xcov(self) -> Tensor: + assert self.n > 0, "Stats not initialized; did you set save_reporter_stats?" return self.contrastive_xcov_M2 / self.n @property def intercluster_cov(self) -> Tensor: + assert self.n > 0, "Stats not initialized; did you set save_reporter_stats?" return self.intercluster_cov_M2 / self.n @property @@ -140,19 +159,13 @@ def confidence(self) -> Tensor: @property def invariance(self) -> Tensor: + assert self.n > 0, "Stats not initialized; did you set save_reporter_stats?" 
return -self.weight @ self.intracluster_cov @ self.weight.mT @property def consistency(self) -> Tensor: return -self.weight @ self.contrastive_xcov @ self.weight.mT - def clear(self) -> None: - """Clear the running statistics of the reporter.""" - self.contrastive_xcov_M2.zero_() - self.intracluster_cov.zero_() - self.intercluster_cov_M2.zero_() - self.n.zero_() - @torch.no_grad() def update(self, hiddens: Tensor) -> None: (n, _, k, d) = hiddens.shape @@ -239,55 +252,26 @@ def fit_streaming(self, truncated: bool = False) -> float: self.weight.data = Q.T return -float(L[-1]) - def fit( - self, - hiddens: Tensor, - labels: Optional[Tensor] = None, - ) -> float: + def fit(self, hiddens: Tensor) -> float: """Fit the probe to the contrast set `hiddens`. Args: hiddens: The contrast set of shape [batch, variants, choices, dim]. - labels: The ground truth labels if available. Returns: loss: Negative eigenvalue associated with the VINC direction. """ self.update(hiddens) - loss = self.fit_streaming() - - if labels is not None: - (_, v, k, _) = hiddens.shape - hiddens = rearrange(hiddens, "n v k d -> (n v k) d") - labels = to_one_hot(repeat(labels, "n -> (n v)", v=v), k).flatten() - - self.platt_scale(labels, hiddens) - - return loss - - def platt_scale(self, labels: Tensor, hiddens: Tensor, max_iter: int = 100): - """Fit the scale and bias terms to data with LBFGS. - - Args: - labels: Binary labels of shape [batch]. - hiddens: Hidden states of shape [batch, dim]. - max_iter: Maximum number of iterations for LBFGS. - """ - opt = optim.LBFGS( - [self.bias, self.scale], - line_search_fn="strong_wolfe", - max_iter=max_iter, - tolerance_change=torch.finfo(hiddens.dtype).eps, - tolerance_grad=torch.finfo(hiddens.dtype).eps, + return self.fit_streaming() + + def save(self, path: Path | str) -> None: + """Save the reporter to a file.""" + # We basically never want to instantiate the reporter on the same device + # it happened to be trained on, so we save the state dict as CPU tensors. + # Bizarrely, this also seems to save a LOT of disk space in some cases. + state = {k: v.cpu() for k, v in self.state_dict().items()} + state.update( + in_features=self.in_features, + num_classes=self.num_classes, ) - - def closure(): - opt.zero_grad() - loss = nn.functional.binary_cross_entropy_with_logits( - self(hiddens), labels.float() - ) - - loss.backward() - return float(loss) - - opt.step(closure) + torch.save(state, path) diff --git a/elk/training/reporter.py b/elk/training/reporter.py index e6e84f96..1372d329 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -6,13 +6,13 @@ from typing import Optional import torch -import torch.nn as nn from simple_parsing.helpers import Serializable -from torch import Tensor +from simple_parsing.helpers.serialization import load +from torch import Tensor, nn, optim @dataclass -class ReporterConfig(Serializable): +class ReporterConfig(ABC, Serializable, decode_into_subclasses=True): """ Args: seed: The random seed to use. Defaults to 42. 
@@ -20,23 +20,22 @@ class ReporterConfig(Serializable): seed: int = 42 + @classmethod + @abstractmethod + def reporter_class(cls) -> type["Reporter"]: + """Get the reporter class associated with this config.""" + class Reporter(nn.Module, ABC): """An ELK reporter network.""" + # Learned Platt scaling parameters + bias: nn.Parameter + scale: nn.Parameter + def reset_parameters(self): """Reset the parameters of the probe.""" - # TODO: These methods will do something fancier in the future - @classmethod - def load(cls, path: Path | str): - """Load a reporter from a file.""" - return torch.load(path) - - def save(self, path: Path | str): - # TODO: Save separate JSON and PT files for the reporter. - torch.save(self, path) - @abstractmethod def fit( self, @@ -44,3 +43,57 @@ def fit( labels: Optional[Tensor] = None, ) -> float: ... + + @classmethod + def load(cls, path: Path | str, *, map_location: str = "cpu"): + """Load a reporter from a file.""" + obj = torch.load(path, map_location=map_location) + if isinstance(obj, Reporter): # Backwards compatibility + return obj + + # Loading a state dict rather than the full object + elif isinstance(obj, dict): + cls_path = Path(path).parent / "cfg.yaml" + cfg = load(ReporterConfig, cls_path) + + # Non-tensor values get passed to the constructor as kwargs + kwargs = {} + special_keys = {k for k, v in obj.items() if not isinstance(v, Tensor)} + for k in special_keys: + kwargs[k] = obj.pop(k) + + reporter_cls = cfg.reporter_class() + reporter = reporter_cls(cfg, device=map_location, **kwargs) + reporter.load_state_dict(obj) + return reporter + else: + raise TypeError( + f"Expected a `dict` or `Reporter` object, but got {type(obj)}." + ) + + def platt_scale(self, labels: Tensor, hiddens: Tensor, max_iter: int = 100): + """Fit the scale and bias terms to data with LBFGS. + + Args: + labels: Binary labels of shape [batch]. + hiddens: Hidden states of shape [batch, dim]. + max_iter: Maximum number of iterations for LBFGS. + """ + opt = optim.LBFGS( + [self.bias, self.scale], + line_search_fn="strong_wolfe", + max_iter=max_iter, + tolerance_change=torch.finfo(hiddens.dtype).eps, + tolerance_grad=torch.finfo(hiddens.dtype).eps, + ) + + def closure(): + opt.zero_grad() + loss = nn.functional.binary_cross_entropy_with_logits( + self(hiddens), labels.float() + ) + + loss.backward() + return float(loss) + + opt.step(closure) diff --git a/elk/training/train.py b/elk/training/train.py index d0ba8611..dcd978ca 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -9,6 +9,7 @@ import torch from einops import rearrange, repeat from simple_parsing import subgroups +from simple_parsing.helpers.serialization import save from ..metrics import evaluate_preds, to_one_hot from ..run import Run @@ -41,6 +42,10 @@ def create_models_dir(self, out_dir: Path): lr_dir.mkdir(parents=True, exist_ok=True) reporter_dir.mkdir(parents=True, exist_ok=True) + # Save the reporter config separately in the reporter directory + # for convenient loading of reporters later. 
+ save(self.net, reporter_dir / "cfg.yaml", save_dc_types=True) + return reporter_dir, lr_dir def apply_to_layer( @@ -57,7 +62,7 @@ def apply_to_layer( train_dict = self.prepare_data(device, layer, "train") val_dict = self.prepare_data(device, layer, "val") - (first_train_h, train_labels, _), *rest = train_dict.values() + (first_train_h, train_gt, _), *rest = train_dict.values() d = first_train_h.shape[-1] if not all(other_h.shape[-1] == d for other_h, _, _ in rest): raise ValueError("All datasets must have the same hidden state size") @@ -67,7 +72,7 @@ def apply_to_layer( assert len(train_dict) == 1, "CCS only supports single-task training" reporter = CcsReporter(self.net, d, device=device) - train_loss = reporter.fit(first_train_h, train_labels) + train_loss = reporter.fit(first_train_h, train_gt) (val_h, val_gt, _) = next(iter(val_dict.values())) x0, x1 = first_train_h.unbind(2) @@ -77,6 +82,13 @@ def apply_to_layer( val_pair=(val_x0, val_x1), ) + # TODO: Enable Platt scaling for CCS once normalization is fixed + # (_, v, k, _) = first_train_h.shape + # reporter.platt_scale( + # to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(), + # rearrange(first_train_h, "n v k d -> (n v k) d"), + # ) + elif isinstance(self.net, EigenReporterConfig): # We set num_classes to None to enable training on datasets with different # numbers of classes. Under the hood, this causes the covariance statistics @@ -84,14 +96,14 @@ def apply_to_layer( reporter = EigenReporter(self.net, d, num_classes=None, device=device) hidden_list, label_list = [], [] - for ds_name, (train_h, train_labels, _) in train_dict.items(): + for ds_name, (train_h, train_gt, _) in train_dict.items(): (_, v, k, _) = train_h.shape # Datasets can have different numbers of variants and different numbers # of classes, so we need to flatten them here before concatenating hidden_list.append(rearrange(train_h, "n v k d -> (n v k) d")) label_list.append( - to_one_hot(repeat(train_labels, "n -> (n v)", v=v), k).flatten() + to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten() ) reporter.update(train_h) @@ -105,8 +117,7 @@ def apply_to_layer( raise ValueError(f"Unknown reporter config type: {type(self.net)}") # Save reporter checkpoint to disk - with open(reporter_dir / f"layer_{layer}.pt", "wb") as file: - torch.save(reporter, file) + reporter.save(reporter_dir / f"layer_{layer}.pt") # Fit supervised logistic regression model if self.supervised != "none": From 8ba18c3076873d6ca8d8a332bff7358042d4ce94 Mon Sep 17 00:00:00 2001 From: Nora Belrose <39116809+norabelrose@users.noreply.github.com> Date: Wed, 3 May 2023 01:31:13 -0700 Subject: [PATCH 6/6] Don't left truncate stuff anymore (#239) --- elk/extraction/extraction.py | 45 +++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 0d2f7bf4..2a4c36e2 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -3,7 +3,7 @@ import os from contextlib import nullcontext, redirect_stdout from dataclasses import InitVar, dataclass, replace -from itertools import islice, zip_longest +from itertools import zip_longest from typing import Any, Iterable, Literal from warnings import filterwarnings @@ -198,13 +198,25 @@ def extract_hiddens( layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers + 1)) global_max_examples = cfg.max_examples[0 if split_type == "train" else 1] + # break `max_examples` among the processes roughly equally max_examples = 
global_max_examples // world_size + max_length = assert_type(int, tokenizer.model_max_length) + + # Keep track of the number of examples we've yielded so far. We can't do something + # clean like `islice` the dataset, because we skip examples that are too long, and + # we can't predict how many of those there will be. + num_yielded = 0 + # the last process gets the remainder (which is usually small) if rank == world_size - 1: max_examples += global_max_examples % world_size - for example in islice(prompt_ds, max_examples): + for example in prompt_ds: + # Check if we've yielded enough examples + if num_yielded >= max_examples: + break + num_variants = len(example["prompts"]) num_choices = len(example["prompts"][0]) @@ -236,19 +248,15 @@ def extract_hiddens( # Only feed question, not the answer, to the encoder for enc-dec models target = choice["answer"] if is_enc_dec else None - - # Record the EXACT question we fed to the model - variant_questions.append(text) encoding = tokenizer( text, # Keep [CLS] and [SEP] for BERT-style models add_special_tokens=True, return_tensors="pt", text_target=target, # type: ignore[arg-type] - truncation=True, ).to(device) - input_ids = assert_type(Tensor, encoding.input_ids) + input_ids = assert_type(Tensor, encoding.input_ids) if is_enc_dec: answer = assert_type(Tensor, encoding.labels) else: @@ -258,12 +266,16 @@ def extract_hiddens( add_special_tokens=False, return_tensors="pt", ).to(device) - answer = assert_type(Tensor, encoding2.input_ids) + answer = assert_type(Tensor, encoding2.input_ids) input_ids = torch.cat([input_ids, answer], dim=-1) - if max_len := tokenizer.model_max_length: - cur_len = input_ids.shape[-1] - input_ids = input_ids[..., -min(cur_len, max_len) :] + + # If this input is too long, skip it + if input_ids.shape[-1] > max_length: + break + else: + # Record the EXACT question we fed to the model + variant_questions.append(text) # Make sure we only pass the arguments that the model expects inputs = dict(input_ids=input_ids.long()) @@ -305,8 +317,18 @@ def extract_hiddens( for layer_idx, hidden in zip(layer_indices, hiddens): hidden_dict[f"hidden_{layer_idx}"][i, j] = float_to_int16(hidden) + # We skipped a pseudolabel because it was too long; break out of this whole + # example and move on to the next one + if len(variant_questions) != num_choices: + break + + # Usual case: we have the expected number of pseudolabels text_questions.append(variant_questions) + # We skipped a variant because it was too long; move on to the next example + if len(text_questions) != num_variants: + continue + out_record: dict[str, Any] = dict( label=example["label"], variant_ids=example["template_names"], @@ -316,6 +338,7 @@ def extract_hiddens( if has_lm_preds: out_record["model_logits"] = lm_logits + num_yielded += 1 yield out_record
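
Notes on selected techniques from the patches above. The snippets below are illustrative sketches written against the hunks, not code from the repository; anything not visible in a hunk is called out as an assumption.

The `sorted(set(...))` pass in patch 2 is what makes the `burns` shortcut safe to combine with an explicit dataset list: after expansion, a dataset such as `imdb` can appear twice. A minimal sketch of the expansion, with the dataset list copied from the hunk:

BURNS_DATASETS = [
    "ag_news", "amazon_polarity", "dbpedia_14", "glue:qnli", "imdb",
    "piqa", "super_glue:boolq", "super_glue:copa", "super_glue:rte",
]

def expand_datasets(datasets: list[str]) -> list[str]:
    # Mirrors Sweep.__post_init__: expand the magic name, then dedup and sort.
    if "burns" in datasets:
        datasets.remove("burns")
        datasets.extend(BURNS_DATASETS)
    return sorted(set(datasets))

print(expand_datasets(["burns", "imdb"]))  # "imdb" appears only once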
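Patch 3 renames `float32_to_int16` to `float_to_int16` because, with bf16 loading, the hidden states fed to it are no longer always fp32. The trick itself: downcast to fp16, then reinterpret the same 16 bits as int16 so they can be stored in an integer column. The helper's return statement falls outside the hunk, so reading it as a `Tensor.view(torch.int16)` bit reinterpretation is an assumption here, as is the body of the `int16_to_float32` inverse that `elk.utils` exports alongside it:

import torch

def float_to_int16(x: torch.Tensor) -> torch.Tensor:
    downcast = x.type(torch.float16)
    if not downcast.isfinite().all():
        raise ValueError("Cannot convert to 16 bit: values are not finite")
    return downcast.view(torch.int16)  # assumed: reinterpret bits, not values

def int16_to_float32(x: torch.Tensor) -> torch.Tensor:
    # Assumed inverse: undo the bit reinterpretation, then widen to fp32.
    return x.view(torch.float16).type(torch.float32)

h = torch.randn(8)
# Round trip loses only fp16 precision.
assert torch.allclose(h, int16_to_float32(float_to_int16(h)), rtol=1e-3, atol=1e-3)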
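Patch 4's `use_centroids` flag changes what counts as a data point when accumulating the inter-cluster covariance. With `use_centroids=True` (the VINC behavior) the `v` prompt variants of each example are averaged into a single centroid per class; with `False` (the CRC-TPC behavior) every variant stays a separate point. A shape-only sketch of the two branches from the hunk:

import torch
from einops import rearrange

n, v, k, d = 10, 3, 2, 8           # examples, variants, classes, hidden dim
hiddens = torch.randn(n, v, k, d)

vinc = hiddens.mean(1)                            # [n, k, d]: one centroid per example
crc = rearrange(hiddens, "n v k d -> (n v) k d")  # [n*v, k, d]: variants kept separate
print(vinc.shape, crc.shape)  # torch.Size([10, 2, 8]) torch.Size([30, 2, 8])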
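The ">1000x smaller on disk" claim in patch 5 rests on two changes visible in the hunks: reporters are now saved as plain CPU state dicts rather than pickled modules, and the d x d running-statistic buffers are registered with `persistent=cfg.save_reporter_stats`, so by default they never enter the state dict at all. A self-contained sketch of that mechanism, with a hypothetical `TinyReporter` class and illustrative sizes:

import torch
from torch import nn

class TinyReporter(nn.Module):
    def __init__(self, d: int, save_stats: bool = False):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, d))
        # Like EigenReporter's covariance buffers: needed in memory for
        # streaming updates, but serialized only on request.
        self.register_buffer("cov_M2", torch.zeros(d, d), persistent=save_stats)

d = 4096
print(sorted(TinyReporter(d).state_dict()))                   # ['weight']
print(sorted(TinyReporter(d, save_stats=True).state_dict()))  # ['cov_M2', 'weight']
# A single d x d buffer is d times larger than the weight row, which is
# roughly where the three-orders-of-magnitude saving comes from.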
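With patch 5 applied, loading a reporter involves two files: the `layer_<N>.pt` state dict and the `cfg.yaml` that `Train.create_models_dir` now saves beside it. `Reporter.load` reads the config to pick the right subclass, passes the non-tensor entries (`in_features`, `num_classes`) to its constructor, and restores the tensors as a state dict. A usage sketch; the directory layout shown is hypothetical:

from pathlib import Path
from elk.training.reporter import Reporter

reporter_dir = Path("my-sweep/reporters")  # assumed output location
reporter = Reporter.load(reporter_dir / "layer_12.pt", map_location="cpu")
reporter.eval()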
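Patch 6 drops `islice` because examples exceeding `tokenizer.model_max_length` are now skipped rather than truncated, so the number of dataset rows consumed per rank can no longer be known up front and has to be tracked with `num_yielded`. The per-rank budget itself is plain floor division, with the remainder folded into the last rank:

global_max_examples, world_size = 1000, 3
budgets = [global_max_examples // world_size] * world_size
budgets[-1] += global_max_examples % world_size  # last rank takes the remainder
assert budgets == [333, 333, 334]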