Skip to content

Commit

Permalink
hyperparams for ex
Browse files Browse the repository at this point in the history
  • Loading branch information
yallup committed Jul 10, 2024
1 parent 83a2515 commit db0e2f7
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 25 deletions.
8 changes: 3 additions & 5 deletions clax/clax.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,9 @@ def _init_state(self, **kwargs):
if not optimizer:
optimizer = optax.chain(
# optax.clip_by_global_norm(1.0),
# optax.adaptive_grad_clip(0.1),
optax.adaptive_grad_clip(1.0),
# optax.adam(lr),
# optax.adamw(self.schedule),
optax.adamw(lr),
optax.adaptive_grad_clip(0.01),
optax.adamw(self.schedule),
# optax.adamw(lr),
)

# self.state = train_state.TrainState.create(
Expand Down
1 change: 0 additions & 1 deletion clax/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class Network(nn.Module):
@nn.compact
def __call__(self, x, train: bool):
x = nn.Dense(self.n_initial)(x)
# nn.BatchNorm(use_running_average=not train)(x)
x = nn.BatchNorm(use_running_average=not train)(x)
x = nn.silu(x)
for i in range(self.n_layers):
Expand Down
57 changes: 38 additions & 19 deletions examples/bayes_factors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,18 @@

import matplotlib.pyplot as plt
import numpy as np
import optax
from flax import linen as nn
from mpl_toolkits.axes_grid1.inset_locator import mark_inset, zoomed_inset_axes
from scipy.stats import multivariate_normal
from sklearn.datasets import make_sparse_spd_matrix
from sklearn.model_selection import train_test_split

from clax import Classifier

# from clax.network import Network

np.random.seed(2025)
np.random.seed(2024)
dim = 100
n_sample = 100000
n_sample = 500000


c1 = np.random.rand(dim) - 0.5
Expand All @@ -31,29 +30,49 @@
midpoint = (m1 + m2) / 2
error = 0.025

C1 = make_sparse_spd_matrix(dim, norm_diag=True, smallest_coef=0.01, largest_coef=0.25)
C2 = make_sparse_spd_matrix(dim, norm_diag=True, smallest_coef=0.01, largest_coef=0.25)

M_0 = multivariate_normal(mean=m1, cov=np.eye(dim) * error)
M_1 = multivariate_normal(mean=m2, cov=np.eye(dim) * error)
M_2 = multivariate_normal(mean=midpoint, cov=np.eye(dim) * error)


X = np.concatenate((M_0.rvs(n_sample), M_1.rvs(n_sample)))
y = np.concatenate((np.zeros(n_sample), np.ones(n_sample)))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.01)


# # Arg is the number classes
classifier = Classifier()

chain = optax.chain(
optax.adaptive_grad_clip(1.0),
optax.adamw(1e-3),
)
# optionally specify the optimizer manually
# chain = optax.chain(
# optax.adaptive_grad_clip(1.0),
# optax.adamw(1e-3),
# )


class Network(nn.Module):
"""A simple MLP classifier."""

classifier.fit(X_train, y_train, epochs=500, optimizer=chain, batch_size=1000)
n_initial: int = 256
n_hidden: int = 64
n_layers: int = 3
n_out: int = 1
# act = nn.silu

@nn.compact
def __call__(self, x, train: bool):
x = nn.Dense(self.n_initial)(x)
# hacky way to make batchnorm have no impact
nn.BatchNorm(use_running_average=not train)(x)
x = nn.silu(x)
for i in range(self.n_layers):
x = nn.Dense(self.n_hidden)(x)
x = nn.silu(x)
x = nn.Dense(self.n_out)(x)
return x


lr = 1e-4
classifier.network = Network(n_out=1, n_initial=1056, n_hidden=128, n_layers=3)
classifier.fit(X_train, y_train, epochs=100, lr=lr, batch_size=10000)

true_k = M_1.logpdf(X_test) - M_0.logpdf(X_test)
network_k = classifier.predict(X_test).squeeze()
Expand All @@ -65,12 +84,12 @@


def plot():
f, a = plt.subplots(1, 1)
f, a = plt.subplots(1, 1, figsize=(6, 4))
a.scatter(
true_k_m2,
network_k_m2,
alpha=0.5,
c="C4",
c="C1",
label=r"$M_2$ test",
marker=".",
rasterized=True,
Expand Down Expand Up @@ -172,12 +191,12 @@ def plot():
a.set_xlabel(r"True $\ln K$")
a.set_ylabel(r"Network $\ln K$")
f.tight_layout()
f.savefig("en_metal.pdf")
f.savefig("en.pdf")


plot()

f, a = plt.subplots()
a.plot(classifier.trace.losses)
a.set_yscale("log")
f.savefig("losses_metal.pdf")
f.savefig("losses.pdf")

0 comments on commit db0e2f7

Please sign in to comment.