diff --git a/clax/_version.py b/clax/_version.py
index b1a19e3..034f46c 100644
--- a/clax/_version.py
+++ b/clax/_version.py
@@ -1 +1 @@
-__version__ = "0.0.5"
+__version__ = "0.0.6"
diff --git a/clax/clax.py b/clax/clax.py
index edf3e47..aba93e9 100644
--- a/clax/clax.py
+++ b/clax/clax.py
@@ -38,7 +38,9 @@ def __init__(self, n=1, **kwargs):
         self.rng = random.PRNGKey(kwargs.get("seed", 2024))
         self.n = n
         self.network = Network(n_out=n)
+        self.n = n
         self.state = None
+        self.dl = DataLoader
         if n == 1:
             self.loss_fn = optax.sigmoid_binary_cross_entropy
         else:
@@ -60,7 +62,7 @@ def _train(self, samples, labels, batches_per_epoch, **kwargs):
         self.trace = Trace()
         batch_size = kwargs.get("batch_size", 1024)
         epochs = kwargs.get("epochs", 10)
-        epochs *= batches_per_epoch
+        # epochs *= batches_per_epoch
 
         @jit
         def update_step(state, samples, labels, rng):
@@ -74,23 +76,25 @@ def update_step(state, samples, labels, rng):
         train_size = samples.shape[0]
         batch_size = min(batch_size, train_size)
         losses = []
-        map = DataLoader(samples, labels)
+
+        dl = self.dl(samples.shape[0], labels.shape[0], **kwargs)
         tepochs = tqdm(range(epochs))
         for k in tepochs:
-            self.rng, step_rng = random.split(self.rng)
-            perm, _ = map.sample(batch_size)
-            batch = samples[perm]
-            batch_label = labels[perm]
-            loss, self.state = update_step(self.state, batch, batch_label, step_rng)
-            losses.append(loss)
-            # self.state.losses.append(loss)
-            if (k + 1) % 50 == 0:
-                ma = jnp.mean(jnp.array(losses[-50:]))
-                self.trace.losses.append(ma)
-                tepochs.set_postfix(loss="{:.2e}".format(ma))
-            self.trace.iteration += 1
-            # lr_scale = otu.tree_get(self.state, "scale")
-            # self.trace.lr.append(lr_scale)
+            epoch_losses = []
+            for _ in range(batches_per_epoch):
+                self.rng, step_rng = random.split(self.rng)
+                perm, perm_label = dl.sample(batch_size)
+                batch = samples[perm]
+                batch_label = labels[perm_label]
+                loss, self.state = update_step(self.state, batch, batch_label, step_rng)
+                epoch_losses.append(loss)
+
+            epoch_summary_loss = jnp.mean(jnp.asarray(epoch_losses))
+            tepochs.set_postfix(loss="{:.2e}".format(epoch_summary_loss))
+            losses.append(epoch_summary_loss)
+            # if losses[::-1][:patience] < epoch_summary_loss:
+            #     break
+        self.trace.losses = jnp.asarray(losses)
 
     def _init_state(self, **kwargs):
         """Initialise the training state and setup the optimizer."""
diff --git a/clax/network.py b/clax/network.py
index 857c696..6a4504e 100644
--- a/clax/network.py
+++ b/clax/network.py
@@ -6,15 +6,15 @@
 
 
 class DataLoader(object):
-    def __init__(self, x0, x1, rng=0):
-        self.x0 = np.atleast_2d(x0)
-        self.x1 = np.atleast_2d(x1)
+    def __init__(self, x0, x1, rng=0, **kwargs):
+        self.x0 = x0
+        self.x1 = x1
         self.rng = np.random.default_rng(rng)
 
     def sample(self, batch_size=128, *args):
-        idx = self.rng.choice(self.x0.shape[0], size=(batch_size), replace=True)
-        idx_p = self.rng.choice(self.x1.shape[0], size=(batch_size), replace=True)
-        return idx, idx_p
+        idx = self.rng.choice(self.x0, size=(batch_size), replace=True)
+        # idx_p = self.rng.choice(self.x1.shape[0], size=(batch_size), replace=True)
+        return idx, idx
 
 
 class TrainState(train_state.TrainState):
@@ -33,6 +33,7 @@ class Network(nn.Module):
     @nn.compact
     def __call__(self, x, train: bool):
         x = nn.Dense(self.n_initial)(x)
+        # nn.BatchNorm(use_running_average=not train)(x)
         x = nn.BatchNorm(use_running_average=not train)(x)
         x = nn.silu(x)
         for i in range(self.n_layers):
diff --git a/examples/bayes_factors.py b/examples/bayes_factors.py
index 89de247..1f2be27 100644
--- a/examples/bayes_factors.py
+++ b/examples/bayes_factors.py
@@ -49,11 +49,11 @@
 classifier = Classifier()
 
 chain = optax.chain(
-    optax.adaptive_grad_clip(10.0),
-    optax.adamw(1e-5),
+    optax.adaptive_grad_clip(1.0),
+    optax.adamw(1e-3),
 )
 
-classifier.fit(X_train, y_train, epochs=2000, optimizer=chain, batch_size=10000)
+classifier.fit(X_train, y_train, epochs=500, optimizer=chain, batch_size=1000)
 
 true_k = M_1.logpdf(X_test) - M_0.logpdf(X_test)
 network_k = classifier.predict(X_test).squeeze()
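
Note (not part of the patch): after this change DataLoader is constructed from integer sizes (samples.shape[0], labels.shape[0]) rather than the arrays themselves, and numpy.random.Generator.choice with an integer first argument samples from np.arange(n), so sample now returns a single shared index array that _train uses to slice both the sample and label batches. A minimal sketch of that indexing pattern, using hypothetical stand-in arrays rather than the library's own data:

import numpy as np

# Hypothetical stand-ins for the samples/labels arrays passed to _train.
samples = np.random.normal(size=(1000, 5))
labels = np.random.randint(0, 2, size=(1000,))

rng = np.random.default_rng(0)

# Generator.choice with an integer first argument draws from np.arange(n),
# so idx holds row indices, mirroring the refactored DataLoader.sample.
idx = rng.choice(samples.shape[0], size=(64,), replace=True)
batch, batch_label = samples[idx], labels[idx]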