Skip to content

Commit

Permalink
Modify lm similarity example notebook to work without cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
WarmCyan committed Sep 27, 2024
1 parent 7a34fcf commit b93c661
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions notebooks/lm_similarity_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,22 @@
"source": [
"# Language Model Similarity Example\n",
"\n",
"This notebook shows how to provide a language model to a similarity anchor, allowing the utilization of knowledge inside embedding spaces as part of the ICAT model."
"This notebook shows how to provide a language model to a similarity anchor, allowing the utilization of knowledge inside embedding spaces as part of the ICAT model.\n",
"\n",
"You will need to install the huggingface transformers and pytorch libraries for this notebook to run, please use\n",
"```\n",
"pip install transformers torch\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb6a33c6-e0f2-414f-9356-97f6fb47e2b9",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
Expand All @@ -21,7 +36,7 @@
"source": [
"# change these constants as needed based on your hardware constraints\n",
"BATCH_SIZE = 16\n",
"DEVICE = \"cuda\"\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"MODEL_NAME = \"bert-base-uncased\""
]
},
Expand Down Expand Up @@ -175,7 +190,7 @@
"\n",
"dataset = fetch_20newsgroups(subset=\"train\")\n",
"df = pd.DataFrame({\"text\": dataset[\"data\"], \"category\": [dataset[\"target_names\"][i] for i in dataset[\"target\"]]})\n",
"#df = df.iloc[0:1999]\n",
"df = df.iloc[0:1999] # NOTE: if running on CPU or weaker GPU, recommend uncommenting this to avoid long processing times on first BERT anchor creation.\n",
"df.head()"
]
},
Expand All @@ -196,7 +211,7 @@
},
"outputs": [],
"source": [
"icat.initialize(offline=True)"
"icat.initialize(offline=False)"
]
},
{
Expand Down Expand Up @@ -279,7 +294,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.15"
}
},
"nbformat": 4,
Expand Down

0 comments on commit b93c661

Please sign in to comment.