Skip to content

Commit

Permalink
Add sine notebook.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelburbulla committed Dec 4, 2023
1 parent aa61de5 commit dc2415d
Showing 1 changed file with 155 additions and 2 deletions.
157 changes: 155 additions & 2 deletions notebooks/sine.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,146 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Sine Example"
"# Sine Example\n",
"\n",
"This example shows how to use a neural operator to learn the sine function in\n",
"a self-supervised manner."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"Import modules."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import matplotlib.pyplot as plt\n",
"from continuity.data.sine import SineWaves\n",
"from continuity.plotting.plotting import *\n",
"from continuity.model.neuraloperator import NeuralOperator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data set\n",
"\n",
"Create a data set of sine waves with 32 sensors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"size = 8\n",
"dataset = SineWaves(\n",
" num_sensors=32,\n",
" size=size,\n",
" batch_size=size,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Neural operator\n",
"\n",
"Create a neural operator."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = NeuralOperator(\n",
" coordinate_dim=dataset.coordinate_dim,\n",
" num_channels=dataset.num_channels,\n",
" depth=2,\n",
" kernel_width=32,\n",
" kernel_depth=1,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training\n",
"\n",
"Train the neural operator in a self-supervised manner."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
"criterion = torch.nn.MSELoss()\n",
"\n",
"model.compile(optimizer, criterion)\n",
"model.fit(dataset, epochs=1000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plotting\n",
"\n",
"Plot training data and predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axs = plt.subplots(size // 4, 4, figsize=(16, 3 * size // 4))\n",
"if size // 4 == 1:\n",
" axs = [axs]\n",
"\n",
"for i in range(size):\n",
" ax = axs[i // 4][i % 4]\n",
" obs = dataset.get_observation(i)\n",
" plot_evaluation(model, dataset, obs, ax=ax)\n",
" plot_observation(obs, ax=ax)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plot test observations and predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(1, 1, figsize=(4, 3))\n",
"obs = dataset.generate_observation((size-1) / 2)\n",
"plot_evaluation(model, dataset, obs, ax=ax)\n",
"plot_observation(obs, ax=ax)"
]
},
{
Expand All @@ -16,8 +155,22 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
Expand Down

0 comments on commit dc2415d

Please sign in to comment.