Fix changing and training dataset issues with lm example notebook
WarmCyan committed Aug 30, 2023
1 parent bd76663 commit fc5ca70
Showing 1 changed file with 46 additions and 68 deletions.
114 changes: 46 additions & 68 deletions notebooks/lm_similarity_example.ipynb
@@ -80,51 +80,59 @@
" DESCRIPTION = \"Use cosine similarity between BERT embeddings of data and target.\"\n",
" # optional attribute to set description of anchor type in ICAT UI\n",
" \n",
" def embed(self, data: pd.DataFrame) -> pd.DataFrame:\n",
" \"\"\"This function takes some set of data and embeds the text column using\n",
" the transformer model stored in ``text_model``.\"\"\"\n",
" embedded_batches = []\n",
" \n",
" # run the tokenizer and model embedding on batches\n",
" max_batches = data.shape[0] // BATCH_SIZE + 1\n",
" last_batch = data.shape[0] // BATCH_SIZE\n",
" for batch in range(max_batches):\n",
" # compute range for this batch\n",
" batch_start = batch * BATCH_SIZE\n",
" batch_end = data.shape[0] if batch == last_batch else batch_start + BATCH_SIZE\n",
"\n",
" # get the texts within the batch range\n",
" batch_text = data[self.text_col].tolist()[batch_start:batch_end]\n",
"\n",
" # tokenize and embed with the model\n",
" tokenized = tokenizer(\n",
" batch_text, \n",
" return_tensors='pt', \n",
" truncation=True, \n",
" padding=\"max_length\",\n",
" )[\"input_ids\"].to(DEVICE).detach()\n",
" text_embeddings = text_model(tokenized).last_hidden_state.detach().cpu().numpy()\n",
" embedded_batches.append(text_embeddings)\n",
" \n",
" # stack all the embeddings and average the token embeddings to get the full text \n",
" # representation for each\n",
" embeddings = np.concatenate(embedded_batches, axis=0)\n",
" embeddings = embeddings.mean(axis=1)\n",
" embeddings_df = pd.DataFrame(embeddings, index=data.index)\n",
" return embeddings_df\n",
"\n",
" def featurize(self, data: pd.DataFrame) -> pd.Series:\n",
" target_text = self.reference_texts[0] # the target text we're computing similarity to.\n",
" # Note that for simplicity we only use the first\n",
" # referenced text, but in principle this function\n",
" # could be implemented to handle multiple targets,\n",
" # e.g. use the average embedding.\n",
" source_text = data[self.text_col].tolist()\n",
" \n",
" # if we haven't computed the embeddings for the dataframe yet, do so now.\n",
" # NOTE: this works on the assumption that the dataset isn't going to change,\n",
" # special considerations are required if we put new data through this function\n",
" # determine data that hasn't been embedded yet, note that we determine this exclusively \n",
" # by index\n",
" to_embed = data\n",
" cache_key = f\"similarity_embeddings_{MODEL_NAME}\"\n",
" if cache_key not in self.global_cache:\n",
" # we check for/save any embeddings in the global cache, which is stored on the\n",
" # anchorlist and accessible to all anchors (meaning any other anchors of this type\n",
" # can all access the same set of embeddings/only need to compute them once.\n",
" embedded_batches = []\n",
" \n",
" # run the tokenizer and model embedding on batches\n",
" max_batches = data.shape[0] // BATCH_SIZE + 1\n",
" last_batch = data.shape[0] // BATCH_SIZE\n",
" for batch in range(max_batches):\n",
" # compute range for this batch\n",
" batch_start = batch * BATCH_SIZE\n",
" batch_end = data.shape[0] if batch == last_batch else batch_start + BATCH_SIZE\n",
"\n",
" # get the texts within the batch range\n",
" batch_text = source_text[batch_start:batch_end]\n",
"\n",
" # tokenize and embed with the model\n",
" tokenized = tokenizer(\n",
" batch_text, \n",
" return_tensors='pt', \n",
" truncation=True, \n",
" padding=\"max_length\",\n",
" )[\"input_ids\"].to(DEVICE).detach()\n",
" text_embeddings = text_model(tokenized).last_hidden_state.detach().cpu().numpy()\n",
" embedded_batches.append(text_embeddings)\n",
" \n",
" # stack all the embeddings and average the token embeddings to get the full text \n",
" # representation for each\n",
" embeddings = np.concatenate(embedded_batches, axis=0)\n",
" embeddings = embeddings.mean(axis=1)\n",
"\n",
" self.global_cache[cache_key] = embeddings\n",
" if cache_key in self.global_cache:\n",
" to_embed = data[~data.index.isin(self.global_cache[cache_key].index)]\n",
" else:\n",
" # make sure the series exists to place our embeddings into later\n",
" self.global_cache[cache_key] = pd.DataFrame()\n",
" \n",
" # perform any necessary embeddings and store into global cache.\n",
" if len(to_embed) > 0:\n",
" self.global_cache[cache_key] = pd.concat([self.global_cache[cache_key], self.embed(to_embed)])\n",
" \n",
" # tokenize and get the full text embedding for the target text\n",
" tokenized_target = tokenizer(\n",
@@ -138,7 +146,7 @@
"\n",
" # compute cosine similarity between the target text embedding and all the embeddings\n",
" # from the dataframe\n",
" similarities = cosine_similarity(target_embedding, self.global_cache[cache_key])\n",
" similarities = cosine_similarity(target_embedding, self.global_cache[cache_key].loc[data.index].values)\n",
"\n",
" # massage the similarity values a little to get better spread in the visualization \n",
" # and put a minimum threshold on \"activation\"\n",
@@ -224,36 +232,6 @@
"source": [
"model.view"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30d6bfd1-5632-4038-85f6-89e76231eb27",
"metadata": {},
"outputs": [],
"source": [
"model.anchor_list.save(\"wip\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff5e1561-aa20-4783-b2b2-1beccce8bf25",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model.anchor_list.load(\"wip\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3ee1b7c-1dae-4ff5-9985-b677944499e1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
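For reference, here is a minimal standalone sketch of the index-based caching pattern this commit introduces in `featurize`. The helper names (`embed_fn`, `cached_embeddings`), the toy data, and the demo cache key are hypothetical stand-ins, not part of the notebook; the real `embed()` runs a batched BERT forward pass with mean-pooled token embeddings.

```python
import numpy as np
import pandas as pd


def embed_fn(texts: pd.Series) -> pd.DataFrame:
    # Placeholder embedding, assumed stand-in for the notebook's batched
    # BERT embed() method; returns one row per input text, index-aligned.
    rng = np.random.default_rng(0)
    return pd.DataFrame(rng.normal(size=(len(texts), 4)), index=texts.index)


def cached_embeddings(data: pd.DataFrame, text_col: str,
                      global_cache: dict, cache_key: str) -> np.ndarray:
    """Return embeddings row-aligned with `data`, computing only missing rows."""
    to_embed = data  # default: embed everything (first call for this cache key)
    if cache_key in global_cache:
        # only rows whose index is not yet cached still need embedding
        to_embed = data[~data.index.isin(global_cache[cache_key].index)]
    else:
        # make sure the frame exists to place our embeddings into
        global_cache[cache_key] = pd.DataFrame()

    if len(to_embed) > 0:
        global_cache[cache_key] = pd.concat(
            [global_cache[cache_key], embed_fn(to_embed[text_col])]
        )

    # .loc[data.index] keeps the cached rows aligned with the rows currently
    # being featurized, even if the cache covers a larger dataset.
    return global_cache[cache_key].loc[data.index].values


# Usage: the second call only embeds the newly added row.
cache: dict = {}
df = pd.DataFrame({"text": ["a", "b", "c"]}, index=[0, 1, 2])
emb_all = cached_embeddings(df, "text", cache, "similarity_embeddings_demo")
df_new = pd.concat([df, pd.DataFrame({"text": ["d"]}, index=[3])])
emb_new = cached_embeddings(df_new, "text", cache, "similarity_embeddings_demo")
```

Keying the cache by row index rather than by "has this been computed yet" is what lets new (e.g. training) rows flow through the anchor without re-embedding the original dataset, which is the failure mode the commit message refers to.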
