Skip to content

Commit

Permalink
Merge pull request #164 from kazewong/163-implement-optimization-stra…
Browse files Browse the repository at this point in the history
…tegy

163 implement optimization strategy
  • Loading branch information
kazewong authored Apr 17, 2024
2 parents 845e2ca + 88e62d5 commit 0ec53fb
Show file tree
Hide file tree
Showing 5 changed files with 493 additions and 2 deletions.
315 changes: 315 additions & 0 deletions example/notebook/custom_strategy.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Customizing sampling strategy\n",
"\n",
"The default strategy of flowMC has two stages: Tuning the global sampler by training the normalizing flow, then freeze the normalizing flow to produce production level samples.\n",
"But sometimes the user might want to add steps to this strategy or change things around. Since flowMC-0.3.1, we have refactored the internal API to make it easier to customize the sampling strategy.\n",
"In this notebook, we will show an example to leverage extra steps in the sampling strategy."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from jax.scipy.special import logsumexp\n",
"from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline\n",
"from flowMC.proposal.MALA import MALA\n",
"from flowMC.Sampler import Sampler\n",
"import corner\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from flowMC.strategy.optimization import optimization_Adam\n",
"\n",
"n_dim = 5\n",
"\n",
"def target_dual_moon(x, data=None):\n",
" \"\"\"\n",
" Term 2 and 3 separate the distribution and smear it along the first and second dimension\n",
" \"\"\"\n",
" term1 = 0.5 * ((jnp.linalg.norm(x) - 2) / 0.1) ** 2\n",
" term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2\n",
" term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2\n",
" return -(term1 - logsumexp(term2) - logsumexp(term3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Let say our initialization is way off"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n_chains = 20\n",
"\n",
"rng_key, subkey = jax.random.split(jax.random.PRNGKey(42))\n",
"# Instead of initializing with a unit gaussian, we initialize with a gaussian with a larger variance\n",
"initial_position = jax.random.normal(subkey, shape=(n_chains, n_dim)) * 100\n",
"\n",
"n_dim = 5\n",
"n_layers = 4\n",
"hidden_size = [32, 32]\n",
"num_bins = 8\n",
"data = jnp.zeros(n_dim)\n",
"rng_key, subkey = jax.random.split(rng_key)\n",
"model = MaskedCouplingRQSpline(\n",
" n_dim, n_layers, hidden_size, num_bins, subkey\n",
")\n",
"MALA_Sampler = MALA(target_dual_moon, True, step_size=0.1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"n_loop_training = 20\n",
"n_loop_production = 20\n",
"n_local_steps = 100\n",
"n_global_steps = 10\n",
"num_epochs = 5\n",
"\n",
"learning_rate = 0.005\n",
"momentum = 0.9\n",
"batch_size = 5000\n",
"max_samples = 5000\n",
"\n",
"\n",
"rng_key, subkey = jax.random.split(rng_key)\n",
"nf_sampler = Sampler(\n",
" n_dim,\n",
" subkey,\n",
" {'data': data},\n",
" MALA_Sampler,\n",
" model,\n",
" n_loop_training=n_loop_training,\n",
" n_loop_production=n_loop_production,\n",
" n_local_steps=n_local_steps,\n",
" n_global_steps=n_global_steps,\n",
" n_chains=n_chains,\n",
" n_epochs=num_epochs,\n",
" learning_rate=learning_rate,\n",
" momentum=momentum,\n",
" batch_size=batch_size,\n",
" use_global=True,\n",
")\n",
"print(nf_sampler.strategies)\n",
"nf_sampler.sample(initial_position, data={'data':data})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We should see the chain need to started off really far from the most probable set. This is not a huge problem for this example since the posterior is rather simple and MALA uses gradient in its proposal. Still, one can see there is a huge jump in the NF loss at some point in the during, basically because the distribution the flow is approximating changes a lot."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"out_train = nf_sampler.get_sampler_state(training=True)\n",
"chains = np.array(out_train[\"chains\"])\n",
"global_accs = np.array(out_train[\"global_accs\"])\n",
"local_accs = np.array(out_train[\"local_accs\"])\n",
"loss_vals = np.array(out_train[\"loss_vals\"])\n",
"rng_key, subkey = jax.random.split(rng_key)\n",
"nf_samples = np.array(nf_sampler.sample_flow(subkey, 3000))\n",
"\n",
"\n",
"# Plot 2 chains in the plane of 2 coordinates for first visual check\n",
"plt.figure(figsize=(6, 6))\n",
"axs = [plt.subplot(2, 2, i + 1) for i in range(4)]\n",
"plt.sca(axs[0])\n",
"plt.title(\"2d proj of 2 chains\")\n",
"\n",
"plt.plot(chains[0, :, 0], chains[0, :, 1], \"o-\", alpha=0.5, ms=2)\n",
"plt.plot(chains[1, :, 0], chains[1, :, 1], \"o-\", alpha=0.5, ms=2)\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$x_2$\")\n",
"\n",
"plt.sca(axs[1])\n",
"plt.title(\"NF loss\")\n",
"plt.plot(loss_vals.reshape(-1))\n",
"plt.xlabel(\"iteration\")\n",
"\n",
"plt.sca(axs[2])\n",
"plt.title(\"Local Acceptance\")\n",
"plt.plot(local_accs.mean(0))\n",
"plt.xlabel(\"iteration\")\n",
"\n",
"plt.sca(axs[3])\n",
"plt.title(\"Global Acceptance\")\n",
"plt.plot(global_accs.mean(0))\n",
"plt.xlabel(\"iteration\")\n",
"plt.tight_layout()\n",
"plt.show(block=False)\n",
"\n",
"labels = [\"$x_1$\", \"$x_2$\", \"$x_3$\", \"$x_4$\", \"$x_5$\"]\n",
"# Plot all chains\n",
"figure = corner.corner(chains.reshape(-1, n_dim), labels=labels)\n",
"figure.set_size_inches(7, 7)\n",
"figure.suptitle(\"Visualize samples\")\n",
"plt.show(block=False)\n",
"\n",
"# Plot Nf samples\n",
"figure = corner.corner(nf_samples, labels=labels)\n",
"figure.set_size_inches(7, 7)\n",
"figure.suptitle(\"Visualize NF samples\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's try to run the same example but with an extra step in the sampling strategy: we will run Adam some number of steps before starting the normalizing flow training. This should help the normalizing flow to start closer to the target distribution."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"n_loop_training = 20\n",
"n_loop_production = 20\n",
"n_local_steps = 100\n",
"n_global_steps = 10\n",
"num_epochs = 5\n",
"\n",
"learning_rate = 0.005\n",
"momentum = 0.9\n",
"batch_size = 5000\n",
"max_samples = 5000\n",
"\n",
"Adam_opt = optimization_Adam(n_steps=10000, learning_rate=1, noise_level= 1)\n",
"\n",
"rng_key, subkey = jax.random.split(rng_key)\n",
"nf_sampler = Sampler(\n",
" n_dim,\n",
" subkey,\n",
" {'data': data},\n",
" MALA_Sampler,\n",
" model,\n",
" n_loop_training=n_loop_training,\n",
" n_loop_production=n_loop_production,\n",
" n_local_steps=n_local_steps,\n",
" n_global_steps=n_global_steps,\n",
" n_chains=n_chains,\n",
" n_epochs=num_epochs,\n",
" learning_rate=learning_rate,\n",
" momentum=momentum,\n",
" batch_size=batch_size,\n",
" use_global=True,\n",
" strategies=[Adam_opt, 'default'],\n",
")\n",
"print(nf_sampler.strategies)\n",
"nf_sampler.sample(initial_position, data={'data':data})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, the chains are much closer to the target distribution from the start, hence the normalizing flow training is much smoother."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"out_train = nf_sampler.get_sampler_state(training=True)\n",
"chains = np.array(out_train[\"chains\"])\n",
"global_accs = np.array(out_train[\"global_accs\"])\n",
"local_accs = np.array(out_train[\"local_accs\"])\n",
"loss_vals = np.array(out_train[\"loss_vals\"])\n",
"rng_key, subkey = jax.random.split(rng_key)\n",
"nf_samples = np.array(nf_sampler.sample_flow(subkey, 3000))\n",
"\n",
"\n",
"# Plot 2 chains in the plane of 2 coordinates for first visual check\n",
"plt.figure(figsize=(6, 6))\n",
"axs = [plt.subplot(2, 2, i + 1) for i in range(4)]\n",
"plt.sca(axs[0])\n",
"plt.title(\"2d proj of 2 chains\")\n",
"\n",
"plt.plot(chains[0, :, 0], chains[0, :, 1], \"o-\", alpha=0.5, ms=2)\n",
"plt.plot(chains[1, :, 0], chains[1, :, 1], \"o-\", alpha=0.5, ms=2)\n",
"plt.xlabel(\"$x_1$\")\n",
"plt.ylabel(\"$x_2$\")\n",
"\n",
"plt.sca(axs[1])\n",
"plt.title(\"NF loss\")\n",
"plt.plot(loss_vals.reshape(-1))\n",
"plt.xlabel(\"iteration\")\n",
"\n",
"plt.sca(axs[2])\n",
"plt.title(\"Local Acceptance\")\n",
"plt.plot(local_accs.mean(0))\n",
"plt.xlabel(\"iteration\")\n",
"\n",
"plt.sca(axs[3])\n",
"plt.title(\"Global Acceptance\")\n",
"plt.plot(global_accs.mean(0))\n",
"plt.xlabel(\"iteration\")\n",
"plt.tight_layout()\n",
"plt.show(block=False)\n",
"\n",
"labels = [\"$x_1$\", \"$x_2$\", \"$x_3$\", \"$x_4$\", \"$x_5$\"]\n",
"# Plot all chains\n",
"figure = corner.corner(chains.reshape(-1, n_dim), labels=labels)\n",
"figure.set_size_inches(7, 7)\n",
"figure.suptitle(\"Visualize samples\")\n",
"plt.show(block=False)\n",
"\n",
"# Plot Nf samples\n",
"figure = corner.corner(nf_samples, labels=labels)\n",
"figure.set_size_inches(7, 7)\n",
"figure.suptitle(\"Visualize NF samples\")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jim",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
12 changes: 10 additions & 2 deletions src/flowMC/Sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(
)
self.optim_state = self.optim.init(eqx.filter(self.nf_model, eqx.is_array))

self.strategies = [
default_strategies = [
GlobalTuning(
n_dim=self.n_dim,
n_chains=self.n_chains,
Expand Down Expand Up @@ -161,7 +161,15 @@ def __init__(
if kwargs.get("strategies") is not None:
kwargs_strategies = kwargs.get("strategies")
assert isinstance(kwargs_strategies, list)
self.strategies = kwargs_strategies
self.strategies = []
for strategy in kwargs_strategies:
if isinstance(strategy, str):
if strategy.lower() == "default":
self.strategies += default_strategies
else:
self.strategies.append(strategy)
else:
self.strategies = default_strategies

self.summary = {}

Expand Down
1 change: 1 addition & 0 deletions src/flowMC/nfmodel/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def train(
pbar = range(num_epochs)

best_model = model = self
best_state = state
best_loss = 1e9
for epoch in pbar:
# Use a separate PRNG key to permute image data during shuffling
Expand Down
Loading

0 comments on commit 0ec53fb

Please sign in to comment.