diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py
index a855e8d8..cfc422d4 100644
--- a/elk/evaluation/evaluate.py
+++ b/elk/evaluation/evaluate.py
@@ -50,10 +50,10 @@ def apply_to_layer(
 
             if self.save_logprobs:
                 out_logprobs[ds_name] = dict(
-                    row_ids=val_data.row_ids,
+                    row_ids=val_data.row_ids.cpu(),
                     variant_ids=val_data.variant_ids,
                     texts=val_data.texts,
-                    labels=val_data.labels,
+                    labels=val_data.labels.cpu(),
                     lm=dict(),
                     lr=dict(),
                 )
diff --git a/elk/run.py b/elk/run.py
index 76bdd598..cdcf56bf 100644
--- a/elk/run.py
+++ b/elk/run.py
@@ -37,7 +37,7 @@ class LayerData:
     labels: Tensor
     lm_log_odds: Tensor | None
     texts: list[list[str]]  # (n, v)
-    row_ids: list[int]  # (n,)
+    row_ids: Tensor  # (n,)
     variant_ids: list[list[str]]  # (n, v)
 
 
@@ -167,7 +167,7 @@ def prepare_data(
                 labels=labels,
                 lm_log_odds=lm_preds,
                 texts=split["texts"],
-                row_ids=split["row_id"],
+                row_ids=assert_type(Tensor, split["row_id"]),
                 variant_ids=split["variant_ids"],
             )
 
@@ -208,9 +208,12 @@ def apply_to_layers(
         logprobs_dicts = defaultdict(dict)
 
         try:
-            for layer, (df_dict, logprobs_dict) in tqdm(
-                zip(layers, mapper(func, layers)), total=len(layers)
+            for df_dict, logprobs_dict in tqdm(
+                mapper(func, layers), total=len(layers)
             ):
+                # get arbitrary value
+                df_ = next(iter(df_dict.values()))
+                layer = df_["layer"].iloc[0]
                 for k, v in df_dict.items():
                     df_buffers[k].append(v)
                 for k, v in logprobs_dict.items():
diff --git a/elk/training/train.py b/elk/training/train.py
index 6129d9c9..baa21991 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -90,10 +90,10 @@ def apply_to_layer(
 
             if self.save_logprobs:
                 out_logprobs[ds_name] = dict(
-                    row_ids=val.row_ids,
+                    row_ids=val.row_ids.cpu(),
                     variant_ids=val.variant_ids,
                     texts=val.texts,
-                    labels=val.labels,
+                    labels=val.labels.cpu(),
                     lm=dict(),
                     lr=dict(),
                 )
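
Note on the .cpu() calls in evaluate.py and train.py: a CUDA tensor serialized
with torch.save can only be deserialized onto a machine with a GPU unless the
loader passes an explicit map_location, so moving row_ids and labels to host
memory before they go into the per-dataset logprobs dict presumably keeps the
saved file loadable anywhere. A minimal sketch of that behavior, using
illustrative names rather than the actual elk schema:

    import torch

    # Illustrative stand-ins for the per-dataset payload built in
    # apply_to_layer; these names are not the real elk fields.
    row_ids = torch.arange(4)
    labels = torch.tensor([0, 1, 1, 0])

    if torch.cuda.is_available():
        # During a GPU run the evaluation tensors live on the device...
        row_ids, labels = row_ids.cuda(), labels.cuda()

    out = dict(
        # ...so they are moved to host memory before serialization, which
        # keeps the saved file loadable on CPU-only machines without
        # needing torch.load(..., map_location="cpu").
        row_ids=row_ids.cpu(),
        labels=labels.cpu(),
    )
    torch.save(out, "logprobs.pt")
    print(torch.load("logprobs.pt")["labels"].device)  # cpu

Note on the apply_to_layers change in run.py: the loop now reads the layer
index back out of the returned dataframe instead of zipping results with the
input layers list, so each result stays attached to the right layer even if
the mapper yields results in a different order than the layers were submitted.
A small sketch of the pattern, with a hypothetical fake_func standing in for
the real per-layer function:

    import pandas as pd

    def fake_func(layer: int) -> dict[str, pd.DataFrame]:
        # Stand-in for the real per-layer evaluation; each dataframe
        # records which layer it came from in a "layer" column.
        return {"dataset_a": pd.DataFrame({"layer": [layer], "auroc": [0.9]})}

    # Results arrive out of order here on purpose.
    for df_dict in map(fake_func, reversed(range(3))):
        df_ = next(iter(df_dict.values()))  # any dataset's frame carries the layer
        layer = df_["layer"].iloc[0]
        print(layer)  # prints 2, 1, 0 -- recovered from the result itself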