Merge pull request #164 from kazewong/163-implement-optimization-strategy
Showing 5 changed files with 493 additions and 2 deletions.
@@ -0,0 +1,315 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Customizing sampling strategy\n",
    "\n",
    "The default strategy of flowMC has two stages: tuning the global sampler by training the normalizing flow, then freezing the flow to produce production-level samples.\n",
    "Sometimes, however, the user may want to add steps to this strategy or rearrange it. Since flowMC 0.3.1, the internal API has been refactored to make the sampling strategy easier to customize.\n",
    "In this notebook, we show an example that adds an extra step to the sampling strategy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax.scipy.special import logsumexp\n",
    "from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline\n",
    "from flowMC.proposal.MALA import MALA\n",
    "from flowMC.Sampler import Sampler\n",
    "import corner\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from flowMC.strategy.optimization import optimization_Adam\n",
    "\n",
    "n_dim = 5\n",
    "\n",
    "def target_dual_moon(x, data=None):\n",
    "    \"\"\"\n",
    "    Dual-moon log-density: term 1 keeps samples near a shell of radius 2,\n",
    "    while terms 2 and 3 split the distribution into modes along the first\n",
    "    and second dimensions.\n",
    "    \"\"\"\n",
    "    term1 = 0.5 * ((jnp.linalg.norm(x) - 2) / 0.1) ** 2\n",
    "    term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2\n",
    "    term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2\n",
    "    return -(term1 - logsumexp(term2) - logsumexp(term3))"
   ]
  },
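  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check (added here for illustration; it is not part of the original example), we can evaluate the log-density and its gradient at a test point. MALA uses this gradient in its proposal, so it is worth confirming that `jax.grad` works on `target_dual_moon` before sampling. The test point `x_test` is ours and arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative check: evaluate the target log-density and its gradient\n",
    "x_test = jnp.ones(n_dim)\n",
    "print(\"log p(x_test) =\", target_dual_moon(x_test))\n",
    "print(\"grad log p(x_test) =\", jax.grad(target_dual_moon)(x_test))"
   ]
  },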
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's say our initialization is way off"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_chains = 20\n",
    "\n",
    "rng_key, subkey = jax.random.split(jax.random.PRNGKey(42))\n",
    "# Instead of initializing with a unit Gaussian, we initialize with a\n",
    "# Gaussian with a much larger standard deviation\n",
    "initial_position = jax.random.normal(subkey, shape=(n_chains, n_dim)) * 100\n",
    "\n",
    "n_layers = 4\n",
    "hidden_size = [32, 32]\n",
    "num_bins = 8\n",
    "data = jnp.zeros(n_dim)\n",
    "rng_key, subkey = jax.random.split(rng_key)\n",
    "model = MaskedCouplingRQSpline(n_dim, n_layers, hidden_size, num_bins, subkey)\n",
    "MALA_Sampler = MALA(target_dual_moon, True, step_size=0.1)"
   ]
  },
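  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see just how far off this initialization is (an extra check, not part of the original example), we can look at the distance of each chain from the origin: the target concentrates its mass near a shell of radius 2, while the chains start hundreds of units away."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative check: distance of each initial position from the origin,\n",
    "# compared with the radius-2 shell where the target's mass lives\n",
    "print(jnp.linalg.norm(initial_position, axis=1))"
   ]
  },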
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_loop_training = 20\n",
    "n_loop_production = 20\n",
    "n_local_steps = 100\n",
    "n_global_steps = 10\n",
    "num_epochs = 5\n",
    "\n",
    "learning_rate = 0.005\n",
    "momentum = 0.9\n",
    "batch_size = 5000\n",
    "max_samples = 5000\n",
    "\n",
    "rng_key, subkey = jax.random.split(rng_key)\n",
    "nf_sampler = Sampler(\n",
    "    n_dim,\n",
    "    subkey,\n",
    "    {'data': data},\n",
    "    MALA_Sampler,\n",
    "    model,\n",
    "    n_loop_training=n_loop_training,\n",
    "    n_loop_production=n_loop_production,\n",
    "    n_local_steps=n_local_steps,\n",
    "    n_global_steps=n_global_steps,\n",
    "    n_chains=n_chains,\n",
    "    n_epochs=num_epochs,\n",
    "    learning_rate=learning_rate,\n",
    "    momentum=momentum,\n",
    "    batch_size=batch_size,\n",
    "    use_global=True,\n",
    ")\n",
    "print(nf_sampler.strategies)\n",
    "nf_sampler.sample(initial_position, data={'data': data})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see the chains started off really far from the typical set. This is not a huge problem for this example, since the posterior is rather simple and MALA uses gradients in its proposal. Still, there is a huge jump in the NF loss at some point during training, essentially because the distribution the flow is approximating changes a lot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_train = nf_sampler.get_sampler_state(training=True)\n",
    "chains = np.array(out_train[\"chains\"])\n",
    "global_accs = np.array(out_train[\"global_accs\"])\n",
    "local_accs = np.array(out_train[\"local_accs\"])\n",
    "loss_vals = np.array(out_train[\"loss_vals\"])\n",
    "rng_key, subkey = jax.random.split(rng_key)\n",
    "nf_samples = np.array(nf_sampler.sample_flow(subkey, 3000))\n",
    "\n",
    "# Plot 2 chains in the plane of 2 coordinates for a first visual check\n",
    "plt.figure(figsize=(6, 6))\n",
    "axs = [plt.subplot(2, 2, i + 1) for i in range(4)]\n",
    "plt.sca(axs[0])\n",
    "plt.title(\"2d proj of 2 chains\")\n",
    "\n",
    "plt.plot(chains[0, :, 0], chains[0, :, 1], \"o-\", alpha=0.5, ms=2)\n",
    "plt.plot(chains[1, :, 0], chains[1, :, 1], \"o-\", alpha=0.5, ms=2)\n",
    "plt.xlabel(\"$x_1$\")\n",
    "plt.ylabel(\"$x_2$\")\n",
    "\n",
    "plt.sca(axs[1])\n",
    "plt.title(\"NF loss\")\n",
    "plt.plot(loss_vals.reshape(-1))\n",
    "plt.xlabel(\"iteration\")\n",
    "\n",
    "plt.sca(axs[2])\n",
    "plt.title(\"Local Acceptance\")\n",
    "plt.plot(local_accs.mean(0))\n",
    "plt.xlabel(\"iteration\")\n",
    "\n",
    "plt.sca(axs[3])\n",
    "plt.title(\"Global Acceptance\")\n",
    "plt.plot(global_accs.mean(0))\n",
    "plt.xlabel(\"iteration\")\n",
    "plt.tight_layout()\n",
    "plt.show(block=False)\n",
    "\n",
    "labels = [\"$x_1$\", \"$x_2$\", \"$x_3$\", \"$x_4$\", \"$x_5$\"]\n",
    "# Plot all chains\n",
    "figure = corner.corner(chains.reshape(-1, n_dim), labels=labels)\n",
    "figure.set_size_inches(7, 7)\n",
    "figure.suptitle(\"Visualize samples\")\n",
    "plt.show(block=False)\n",
    "\n",
    "# Plot NF samples\n",
    "figure = corner.corner(nf_samples, labels=labels)\n",
    "figure.set_size_inches(7, 7)\n",
    "figure.suptitle(\"Visualize NF samples\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's run the same example with an extra step in the sampling strategy: we will run Adam for a number of steps before starting the normalizing flow training. This should move the chains closer to the target distribution before training begins, so the flow starts from a better approximation."
   ]
  },
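  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "flowMC ships this step as `optimization_Adam`. To make the idea concrete, here is a minimal sketch of what such a pre-optimization does, written directly with `optax` (assuming `optax` is available). The names `pre_optimize`, `n_steps`, and `lr` are ours, and flowMC's actual implementation may differ in detail, for example in how it uses `noise_level`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import optax\n",
    "\n",
    "def pre_optimize(position, n_steps=100, lr=1e-1):\n",
    "    # Sketch only: maximize the log-density of every chain by\n",
    "    # minimizing its summed negative log-density with Adam\n",
    "    loss_fn = lambda p: -jnp.sum(jax.vmap(target_dual_moon)(p))\n",
    "    optimizer = optax.adam(lr)\n",
    "    opt_state = optimizer.init(position)\n",
    "\n",
    "    @jax.jit\n",
    "    def step(pos, state):\n",
    "        grads = jax.grad(loss_fn)(pos)\n",
    "        updates, state = optimizer.update(grads, state)\n",
    "        return optax.apply_updates(pos, updates), state\n",
    "\n",
    "    for _ in range(n_steps):\n",
    "        position, opt_state = step(position, opt_state)\n",
    "    return position\n",
    "\n",
    "# Example usage: move the badly initialized chains toward the typical set\n",
    "# print(jnp.linalg.norm(pre_optimize(initial_position), axis=1))"
   ]
  },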
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_loop_training = 20\n",
    "n_loop_production = 20\n",
    "n_local_steps = 100\n",
    "n_global_steps = 10\n",
    "num_epochs = 5\n",
    "\n",
    "learning_rate = 0.005\n",
    "momentum = 0.9\n",
    "batch_size = 5000\n",
    "max_samples = 5000\n",
    "\n",
    "Adam_opt = optimization_Adam(n_steps=10000, learning_rate=1, noise_level=1)\n",
    "\n",
    "rng_key, subkey = jax.random.split(rng_key)\n",
    "nf_sampler = Sampler(\n",
    "    n_dim,\n",
    "    subkey,\n",
    "    {'data': data},\n",
    "    MALA_Sampler,\n",
    "    model,\n",
    "    n_loop_training=n_loop_training,\n",
    "    n_loop_production=n_loop_production,\n",
    "    n_local_steps=n_local_steps,\n",
    "    n_global_steps=n_global_steps,\n",
    "    n_chains=n_chains,\n",
    "    n_epochs=num_epochs,\n",
    "    learning_rate=learning_rate,\n",
    "    momentum=momentum,\n",
    "    batch_size=batch_size,\n",
    "    use_global=True,\n",
    "    strategies=[Adam_opt, 'default'],\n",
    ")\n",
    "print(nf_sampler.strategies)\n",
    "nf_sampler.sample(initial_position, data={'data': data})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the chains are much closer to the target distribution from the start, hence the normalizing flow training is much smoother."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_train = nf_sampler.get_sampler_state(training=True)\n",
    "chains = np.array(out_train[\"chains\"])\n",
    "global_accs = np.array(out_train[\"global_accs\"])\n",
    "local_accs = np.array(out_train[\"local_accs\"])\n",
    "loss_vals = np.array(out_train[\"loss_vals\"])\n",
    "rng_key, subkey = jax.random.split(rng_key)\n",
    "nf_samples = np.array(nf_sampler.sample_flow(subkey, 3000))\n",
    "\n",
    "# Plot 2 chains in the plane of 2 coordinates for a first visual check\n",
    "plt.figure(figsize=(6, 6))\n",
    "axs = [plt.subplot(2, 2, i + 1) for i in range(4)]\n",
    "plt.sca(axs[0])\n",
    "plt.title(\"2d proj of 2 chains\")\n",
    "\n",
    "plt.plot(chains[0, :, 0], chains[0, :, 1], \"o-\", alpha=0.5, ms=2)\n",
    "plt.plot(chains[1, :, 0], chains[1, :, 1], \"o-\", alpha=0.5, ms=2)\n",
    "plt.xlabel(\"$x_1$\")\n",
    "plt.ylabel(\"$x_2$\")\n",
    "\n",
    "plt.sca(axs[1])\n",
    "plt.title(\"NF loss\")\n",
    "plt.plot(loss_vals.reshape(-1))\n",
    "plt.xlabel(\"iteration\")\n",
    "\n",
    "plt.sca(axs[2])\n",
    "plt.title(\"Local Acceptance\")\n",
    "plt.plot(local_accs.mean(0))\n",
    "plt.xlabel(\"iteration\")\n",
    "\n",
    "plt.sca(axs[3])\n",
    "plt.title(\"Global Acceptance\")\n",
    "plt.plot(global_accs.mean(0))\n",
    "plt.xlabel(\"iteration\")\n",
    "plt.tight_layout()\n",
    "plt.show(block=False)\n",
    "\n",
    "labels = [\"$x_1$\", \"$x_2$\", \"$x_3$\", \"$x_4$\", \"$x_5$\"]\n",
    "# Plot all chains\n",
    "figure = corner.corner(chains.reshape(-1, n_dim), labels=labels)\n",
    "figure.set_size_inches(7, 7)\n",
    "figure.suptitle(\"Visualize samples\")\n",
    "plt.show(block=False)\n",
    "\n",
    "# Plot NF samples\n",
    "figure = corner.corner(nf_samples, labels=labels)\n",
    "figure.set_size_inches(7, 7)\n",
    "figure.suptitle(\"Visualize NF samples\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jim",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}