Merge pull request #2 from yallup/anre
Dataloader
yallup authored Jul 5, 2024
2 parents 3e35788 + 23e459b commit b2d75fa
Showing 4 changed files with 31 additions and 26 deletions.
2 changes: 1 addition & 1 deletion clax/_version.py
@@ -1 +1 @@
__version__ = "0.0.5"
__version__ = "0.0.6"
36 changes: 20 additions & 16 deletions clax/clax.py
@@ -38,7 +38,9 @@ def __init__(self, n=1, **kwargs):
         self.rng = random.PRNGKey(kwargs.get("seed", 2024))
         self.n = n
         self.network = Network(n_out=n)
         self.n = n
+        self.state = None
+        self.dl = DataLoader
         if n == 1:
             self.loss_fn = optax.sigmoid_binary_cross_entropy
         else:
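
The constructor now stores the loader class itself on self.dl rather than an instance, so _train can build a fresh loader per call and a caller can, in principle, swap in a different sampling strategy first. A minimal sketch of that hook, assuming Classifier is importable from the clax package; HalfBatchLoader is hypothetical and not part of this commit:

    # Hypothetical loader: same index-pair contract as DataLoader.sample,
    # but drawing half-size batches without replacement.
    from clax import Classifier          # assumed public import path
    from clax.network import DataLoader

    class HalfBatchLoader(DataLoader):
        def sample(self, batch_size=128, *args):
            idx = self.rng.choice(self.x0, size=batch_size // 2, replace=False)
            return idx, idx

    classifier = Classifier()
    classifier.dl = HalfBatchLoader      # _train instantiates self.dl(...) itself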
@@ -60,7 +62,7 @@ def _train(self, samples, labels, batches_per_epoch, **kwargs):
         self.trace = Trace()
         batch_size = kwargs.get("batch_size", 1024)
         epochs = kwargs.get("epochs", 10)
-        epochs *= batches_per_epoch
+        # epochs *= batches_per_epoch

         @jit
         def update_step(state, samples, labels, rng):
@@ -74,23 +76,25 @@ def update_step(state, samples, labels, rng):
         train_size = samples.shape[0]
         batch_size = min(batch_size, train_size)
         losses = []
-        map = DataLoader(samples, labels)
+
+        dl = self.dl(samples.shape[0], labels.shape[0], **kwargs)
         tepochs = tqdm(range(epochs))
         for k in tepochs:
-            self.rng, step_rng = random.split(self.rng)
-            perm, _ = map.sample(batch_size)
-            batch = samples[perm]
-            batch_label = labels[perm]
-            loss, self.state = update_step(self.state, batch, batch_label, step_rng)
-            losses.append(loss)
-            # self.state.losses.append(loss)
-            if (k + 1) % 50 == 0:
-                ma = jnp.mean(jnp.array(losses[-50:]))
-                self.trace.losses.append(ma)
-                tepochs.set_postfix(loss="{:.2e}".format(ma))
-                self.trace.iteration += 1
-                # lr_scale = otu.tree_get(self.state, "scale")
-                # self.trace.lr.append(lr_scale)
+            epoch_losses = []
+            for _ in range(batches_per_epoch):
+                self.rng, step_rng = random.split(self.rng)
+                perm, perm_label = dl.sample(batch_size)
+                batch = samples[perm]
+                batch_label = labels[perm_label]
+                loss, self.state = update_step(self.state, batch, batch_label, step_rng)
+                epoch_losses.append(loss)
+
+            epoch_summary_loss = jnp.mean(jnp.asarray(epoch_losses))
+            tepochs.set_postfix(loss="{:.2e}".format(epoch_summary_loss))
+            losses.append(epoch_summary_loss)
+            # if losses[::-1][:patience] < epoch_summary_loss:
+            #     break
+        self.trace.losses = jnp.asarray(losses)

     def _init_state(self, **kwargs):
         """Initialise the training state and setup the optimizer."""
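Net effect in _train: the tqdm loop now iterates over epochs, each epoch draws batches_per_epoch index batches from the loader, and one mean loss is recorded per epoch instead of a 50-step moving average. A self-contained sketch of that loop shape, with a dummy loss standing in for the jitted update_step; every name here is illustrative:

    # Epochs x batches_per_epoch pattern from the new training loop,
    # reproduced with NumPy only: toy data, toy loss.
    import numpy as np

    rng = np.random.default_rng(0)
    samples = rng.normal(size=(1000, 2))
    labels = rng.integers(0, 2, size=1000)

    batch_size = 128
    batches_per_epoch = int(np.ceil(samples.shape[0] / batch_size))  # assumed derivation
    losses = []
    for epoch in range(5):
        epoch_losses = []
        for _ in range(batches_per_epoch):
            idx = rng.choice(samples.shape[0], size=batch_size, replace=True)
            batch, batch_label = samples[idx], labels[idx]
            epoch_losses.append(float(np.mean(batch**2)))  # stand-in for update_step loss
        losses.append(np.mean(epoch_losses))  # one summary loss per epoch
    print(losses)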
13 changes: 7 additions & 6 deletions clax/network.py
@@ -6,15 +6,15 @@


 class DataLoader(object):
-    def __init__(self, x0, x1, rng=0):
-        self.x0 = np.atleast_2d(x0)
-        self.x1 = np.atleast_2d(x1)
+    def __init__(self, x0, x1, rng=0, **kwargs):
+        self.x0 = x0
+        self.x1 = x1
         self.rng = np.random.default_rng(rng)

     def sample(self, batch_size=128, *args):
-        idx = self.rng.choice(self.x0.shape[0], size=(batch_size), replace=True)
-        idx_p = self.rng.choice(self.x1.shape[0], size=(batch_size), replace=True)
-        return idx, idx_p
+        idx = self.rng.choice(self.x0, size=(batch_size), replace=True)
+        # idx_p = self.rng.choice(self.x1.shape[0], size=(batch_size), replace=True)
+        return idx, idx


 class TrainState(train_state.TrainState):
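
The revised DataLoader is built from array lengths rather than the arrays themselves (note that _train now passes samples.shape[0] and labels.shape[0]), and sample returns the same index array twice so sample/label pairs stay aligned. A minimal sketch of that contract, assuming DataLoader is imported from clax.network as defined above:

    import numpy as np
    from clax.network import DataLoader

    x = np.random.normal(size=(500, 3))
    y = np.random.randint(0, 2, size=500)

    dl = DataLoader(x.shape[0], y.shape[0])       # lengths, not arrays
    idx, idx_label = dl.sample(batch_size=64)
    assert np.array_equal(idx, idx_label)         # paired indices
    batch, batch_label = x[idx], y[idx_label]     # how _train consumes them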
@@ -33,6 +33,7 @@ class Network(nn.Module):
     @nn.compact
     def __call__(self, x, train: bool):
         x = nn.Dense(self.n_initial)(x)
+        # nn.BatchNorm(use_running_average=not train)(x)
         x = nn.BatchNorm(use_running_average=not train)(x)
         x = nn.silu(x)
         for i in range(self.n_layers):
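Because the module keeps nn.BatchNorm gated on the train flag, a training-mode apply has to mark the running statistics as mutable; this is standard Flax behaviour rather than anything specific to this commit. A sketch, assuming Network's remaining attributes (n_initial, n_layers, ...) carry defaults, as the Network(n_out=n) call in clax.py suggests:

    import jax
    import jax.numpy as jnp
    from clax.network import Network

    net = Network(n_out=1)
    x = jnp.ones((8, 4))                          # toy batch, feature dim assumed
    variables = net.init(jax.random.PRNGKey(0), x, train=False)

    # Training mode: batch statistics update, so batch_stats must be mutable.
    out, updates = net.apply(variables, x, train=True, mutable=["batch_stats"])
    # Eval mode: running averages are read, nothing is mutated.
    eval_out = net.apply(variables, x, train=False)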
6 changes: 3 additions & 3 deletions examples/bayes_factors.py
@@ -49,11 +49,11 @@
 classifier = Classifier()

 chain = optax.chain(
-    optax.adaptive_grad_clip(10.0),
-    optax.adamw(1e-5),
+    optax.adaptive_grad_clip(1.0),
+    optax.adamw(1e-3),
 )

-classifier.fit(X_train, y_train, epochs=2000, optimizer=chain, batch_size=10000)
+classifier.fit(X_train, y_train, epochs=500, optimizer=chain, batch_size=1000)

 true_k = M_1.logpdf(X_test) - M_0.logpdf(X_test)
 network_k = classifier.predict(X_test).squeeze()
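For reference, optax.chain applies its transforms in sequence, so the retuned example clips gradients adaptively at 1.0 before the AdamW step at learning rate 1e-3. A standalone sketch of one update under those settings; the parameter pytree is illustrative:

    import jax.numpy as jnp
    import optax

    chain = optax.chain(
        optax.adaptive_grad_clip(1.0),
        optax.adamw(1e-3),
    )
    params = {"w": jnp.ones((3,))}
    opt_state = chain.init(params)
    grads = {"w": jnp.full((3,), 100.0)}          # deliberately large gradient
    updates, opt_state = chain.update(grads, opt_state, params)  # adamw needs params
    params = optax.apply_updates(params, updates)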
