Skip to content

Commit

Permalink
save logprobs
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexTMallen committed Sep 21, 2023
1 parent 5c1f656 commit 1730303
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 88 deletions.
29 changes: 24 additions & 5 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from simple_parsing.helpers import field

from ..files import elk_reporter_dir
from ..metrics import evaluate_preds
from ..metrics import evaluate_preds, get_logprobs
from ..run import Run
from ..utils import Color

Expand All @@ -30,7 +30,7 @@ def execute(self, highlight_color: Color = "cyan"):
@torch.inference_mode()
def apply_to_layer(
self, layer: int, devices: list[str], world_size: int
) -> dict[str, pd.DataFrame]:
) -> tuple[dict[str, pd.DataFrame], dict]:
"""Evaluate a single reporter on a single layer."""
device = self.get_device(devices, world_size)
val_output = self.prepare_data(device, layer, "val")
Expand All @@ -43,19 +43,38 @@ def apply_to_layer(
if not isinstance(lr_models, list): # backward compatibility
lr_models = [lr_models]

out_logprobs = defaultdict(dict)
row_bufs = defaultdict(list)
for ds_name, (val_h, val_gt) in val_output.items():
for ds_name, val_data in val_output.items():
meta = {"dataset": ds_name, "layer": layer}

if self.save_logprobs:
out_logprobs[ds_name] = dict(
row_ids=val_data.row_ids,
variant_ids=val_data.variant_ids,
texts=val_data.texts,
labels=val_data.labels,
lm=dict(),
lr=dict(),
)
for mode in ("none", "full"):
# TODO save lm logprobs and add to buf
for i, model in enumerate(lr_models):
model.eval()
val_credences = model(val_data.hiddens)
if self.save_logprobs:
out_logprobs[ds_name]["lr"][mode][i] = get_logprobs(
val_credences, mode
).cpu()
row_bufs["lr_eval"].append(
{
"ensembling": mode,
"inlp_iter": i,
**meta,
**evaluate_preds(val_gt, model(val_h), mode).to_dict(),
**evaluate_preds(
val_data.labels, val_credences, mode
).to_dict(),
}
)

return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
return {k: pd.DataFrame(v) for k, v in row_bufs.items()}, out_logprobs
1 change: 1 addition & 0 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def select_hiddens(outputs: Any) -> dict:
if len(buffer[row_id]) == num_variants:
# we have a complete example
ex = buffer[row_id]
ex = sorted(ex, key=lambda d: d["variant_id"])
assert all(d["label"] == ex[0]["label"] for d in ex)
assert len(set(d["variant_id"] for d in ex)) == num_variants
out_record = dict(
Expand Down
3 changes: 2 additions & 1 deletion elk/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .accuracy import accuracy_ci
from .calibration import CalibrationError, CalibrationEstimate
from .eval import EvalResult, evaluate_preds
from .eval import EvalResult, evaluate_preds, get_logprobs
from .roc_auc import RocAucResult, roc_auc, roc_auc_ci

__all__ = [
Expand All @@ -9,6 +9,7 @@
"CalibrationEstimate",
"EvalResult",
"evaluate_preds",
"get_logprobs",
"roc_auc",
"roc_auc_ci",
"RocAucResult",
Expand Down
17 changes: 17 additions & 0 deletions elk/metrics/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Literal

import torch
import torch.nn.functional as F
from einops import repeat
from torch import Tensor

Expand Down Expand Up @@ -41,6 +42,22 @@ def to_dict(self, prefix: str = "") -> dict[str, float]:
return {**auroc_dict, **cal_acc_dict, **acc_dict, **cal_dict}


def get_logprobs(
    y_logits: Tensor, ensembling: Literal["none", "full"] = "none"
) -> Tensor:
    """
    Get the log-probability of the positive class from a tensor of logits.

    Args:
        y_logits: Predicted log-odds of the positive class, tensor of shape (n, v).
        ensembling: How to treat the v variants of each example. "full"
            averages the log-odds over the variant dimension before
            converting; any other value keeps every variant separate.

    Returns:
        Tensor of logprobs: If ensembling is "none", a tensor of shape (n, v).
        If ensembling is "full", a tensor of shape (n,).
    """
    if ensembling == "full":
        # Ensemble in log-odds space: average over the variant axis (dim 1).
        y_logits = y_logits.mean(dim=1)
    # log(sigmoid(x)) of the (possibly averaged) positive-class log-odds.
    return F.logsigmoid(y_logits)


def evaluate_preds(
y_true: Tensor,
y_logits: Tensor,
Expand Down
61 changes: 55 additions & 6 deletions elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@
)


@dataclass
class LayerData:
    """Hidden states for one dataset at one layer, plus per-example metadata.

    Built by ``Run.prepare_data`` (one instance per dataset name). In the shape
    comments below, n indexes examples and v indexes prompt variants.
    """

    hiddens: Tensor  # unpacked as (n, v, d) by train_supervised
    labels: Tensor  # (n,); repeated across the v variants by train_supervised
    lm_preds: Tensor | None  # prepare_data currently always passes None (TODO there)
    texts: list[list[str]]  # (n, v)
    row_ids: list[int]  # (n,)
    variant_ids: list[list[str]]  # (n, v)


@dataclass
class Run(ABC, Serializable):
data: Extract
Expand All @@ -46,6 +56,15 @@ class Run(ABC, Serializable):
prompt_indices: tuple[int, ...] = ()
"""The indices of the prompt templates to use. If empty, all prompts are used."""

save_logprobs: bool = field(default=False, to_dict=False)
""" saves logprobs.pt containing
{<dsname>: {"row_ids": [n,], "variant_ids": [n, v],
"labels": [n,], "texts": [n, v],
"lm": {"none": [n, v], "full": [n,]},
"lr": {<layer>: {<inlp_iter>: {"none": [n, v], "full": [n,]}}}
}}
"""

concatenated_layer_offset: int = 0
debug: bool = False
num_gpus: int = -1
Expand Down Expand Up @@ -96,15 +115,15 @@ def execute(

devices = select_usable_devices(self.num_gpus)
num_devices = len(devices)
func: Callable[[int], dict[str, pd.DataFrame]] = partial(
func: Callable[[int], tuple[dict[str, pd.DataFrame], dict]] = partial(
self.apply_to_layer, devices=devices, world_size=num_devices
)
self.apply_to_layers(func=func, num_devices=num_devices)

@abstractmethod
def apply_to_layer(
self, layer: int, devices: list[str], world_size: int
) -> dict[str, pd.DataFrame]:
) -> tuple[dict[str, pd.DataFrame], dict]:
"""Train or eval a reporter on a single layer."""

def make_reproducible(self, seed: int):
Expand All @@ -123,7 +142,7 @@ def get_device(self, devices, world_size: int) -> str:

def prepare_data(
self, device: str, layer: int, split_type: Literal["train", "val"]
) -> dict[str, tuple[Tensor, Tensor]]:
) -> dict[str, LayerData]:
"""Prepare data for the specified layer and split type."""
out = {}

Expand All @@ -137,7 +156,14 @@ def prepare_data(
if self.prompt_indices:
hiddens = hiddens[:, self.prompt_indices]

out[ds_name] = (hiddens, labels.to(hiddens.device))
out[ds_name] = LayerData(
hiddens=hiddens,
labels=labels,
lm_preds=None, # TODO: implement
texts=split["texts"],
row_ids=split["row_id"],
variant_ids=split["variant_ids"],
)

return out

Expand All @@ -150,7 +176,7 @@ def concatenate(self, layers):

def apply_to_layers(
self,
func: Callable[[int], dict[str, pd.DataFrame]],
func: Callable[[int], tuple[dict[str, pd.DataFrame], dict]],
num_devices: int,
):
"""Apply a function to each layer of the datasets in parallel
Expand All @@ -173,15 +199,38 @@ def apply_to_layers(
with ctx.Pool(num_devices) as pool:
mapper = pool.imap_unordered if num_devices > 1 else map
df_buffers = defaultdict(list)
logprobs_dicts = defaultdict(dict)

try:
for df_dict in tqdm(mapper(func, layers), total=len(layers)):
for layer, (df_dict, logprobs_dict) in tqdm(
zip(layers, mapper(func, layers)), total=len(layers)
):
for k, v in df_dict.items():
df_buffers[k].append(v)
for k, v in logprobs_dict.items():
logprobs_dicts[k][layer] = logprobs_dict[k]
finally:
# Make sure the CSVs are written even if we crash or get interrupted
for name, dfs in df_buffers.items():
df = pd.concat(dfs).sort_values(by=["layer", "ensembling"])
df.round(4).to_csv(self.out_dir / f"{name}.csv", index=False)
if self.debug:
save_debug_log(self.datasets, self.out_dir)
if self.save_logprobs:
save_dict = defaultdict(dict)
for ds_name, logprobs_dict in logprobs_dicts.items():
save_dict[ds_name]["texts"] = logprobs_dict[layers[0]]["texts"]
save_dict[ds_name]["labels"] = logprobs_dict[layers[0]][
"labels"
]
save_dict[ds_name]["lm"] = logprobs_dict[layers[0]]["lm"]
save_dict[ds_name]["reporter"] = dict()
save_dict[ds_name]["lr"] = dict()
for layer, logprobs_dict_by_mode in logprobs_dict.items():
save_dict[ds_name]["reporter"][
layer
] = logprobs_dict_by_mode["reporter"]
save_dict[ds_name]["lr"][layer] = logprobs_dict_by_mode[
"lr"
]
torch.save(save_dict, self.out_dir / "logprobs.pt")
2 changes: 1 addition & 1 deletion elk/training/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def fit(
x: Tensor,
y: Tensor,
*,
l2_penalty: float = 0.0,
l2_penalty: float = 0.001,
max_iter: int = 10_000,
) -> float:
"""Fits the model to the input data using L-BFGS with L2 regularization.
Expand Down
11 changes: 6 additions & 5 deletions elk/training/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from concept_erasure import LeaceFitter
from einops import rearrange, repeat

from ..run import LayerData
from .classifier import Classifier


def train_supervised(
data: dict[str, tuple], device: str, mode: str, erase_paraphrases: bool = False
data: dict[str, LayerData], device: str, mode: str, erase_paraphrases: bool = False
) -> list[Classifier]:
assert not (
erase_paraphrases and len(data) > 1
Expand All @@ -15,9 +16,9 @@ def train_supervised(

leace = None

for train_h, labels in data.values():
(n, v, d) = train_h.shape
train_h = rearrange(train_h, "n v d -> (n v) d")
for train_data in data.values():
(n, v, d) = train_data.hiddens.shape
train_h = rearrange(train_data.hiddens, "n v d -> (n v) d")

if erase_paraphrases:
if leace is None:
Expand All @@ -33,7 +34,7 @@ def train_supervised(
) # (n * v, v)
leace = leace.update(train_h, indicators)

labels = repeat(labels, "n -> (n v)", v=v)
labels = repeat(train_data.labels, "n -> (n v)", v=v)

Xs.append(train_h)
train_labels.append(labels)
Expand Down
Loading

0 comments on commit 1730303

Please sign in to comment.