Commit b870e55
fix tests
yallup committed Jul 15, 2024
1 parent f82c81b commit b870e55
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions clax/clax.py
@@ -60,8 +60,8 @@ def loss(self, params, batch_stats, batch, labels, rng):
     def _train(self, samples, labels, batches_per_epoch, **kwargs):
         """Internal wrapping of training loop."""
         self.trace = Trace()
-        batch_size = kwargs.get("batch_size", 1024)
-        epochs = kwargs.get("epochs", 10)
+        batch_size = kwargs.get("batch_size")
+        epochs = kwargs.get("epochs")
         # epochs *= batches_per_epoch

         @jit
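
Note: with this hunk `_train` keeps no defaults of its own, and `kwargs.get("batch_size")` returns None when the key is absent. The change is only safe because `fit` (further down in this commit) now writes both keys into `kwargs` before delegating. A minimal sketch of that contract, condensed from the diff rather than quoted verbatim:

    # fit owns the defaults and always forwards them, so _train can read
    # kwargs without fallbacks of its own.
    def fit(self, samples, labels, epochs=10, batch_size=1024, **kwargs):
        kwargs["epochs"] = epochs
        kwargs["batch_size"] = batch_size
        self._train(samples, labels, samples.shape[0] // batch_size, **kwargs)

    def _train(self, samples, labels, batches_per_epoch, **kwargs):
        batch_size = kwargs.get("batch_size")  # None only if fit was bypassed
        epochs = kwargs.get("epochs")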
@@ -114,8 +114,10 @@ def _init_state(self, **kwargs):
         self.schedule = optax.warmup_cosine_decay_schedule(
             init_value=0.0,
             peak_value=lr,
-            warmup_steps=int(warmup_fraction * target_batches_per_epoch * epochs),
-            decay_steps=int((1 - cold_fraction) * target_batches_per_epoch * epochs),
+            warmup_steps=int(warmup_fraction * target_batches_per_epoch * epochs + 1),
+            decay_steps=int(
+                (1 - cold_fraction) * target_batches_per_epoch * epochs + 1
+            ),
             end_value=lr * cold_lr,
             exponent=1.0,
         )
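
Note on the `+ 1`: for small runs the truncating `int()` cast can produce `warmup_steps == 0` or `decay_steps == 0`, and a zero-length segment makes the piecewise schedule degenerate (the linear warmup divides by its step count). A sketch with hypothetical numbers, not taken from the repo:

    # Hypothetical small run (e.g. a test suite): 2 batches/epoch, 5 epochs.
    warmup_fraction, target_batches_per_epoch, epochs = 0.05, 2, 5
    int(warmup_fraction * target_batches_per_epoch * epochs)      # int(0.5) == 0
    int(warmup_fraction * target_batches_per_epoch * epochs + 1)  # == 1, never zero

    # With strictly positive step counts the optax schedule is well defined:
    import optax
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=1e-3, warmup_steps=1, decay_steps=11
    )
    schedule(0)  # 0.0, ramping toward peak_value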
@@ -134,7 +136,7 @@ def _init_state(self, **kwargs):
             tx=optimizer,
         )

-    def fit(self, samples, labels, epochs=10, **kwargs):
+    def fit(self, samples, labels, epochs=10, batch_size=1024, **kwargs):
         """Fit the classifier on provided samples.

         Args:
@@ -156,8 +158,10 @@ def fit(self, samples, labels, epochs=10, **kwargs):
             cold_lr (float): The factor to reduce learning rate to use during the cold phase. Defaults to 1e-3.
         """
         restart = kwargs.get("restart", False)
-        batch_size = kwargs.get("batch_size", 1024)

         data_size = samples.shape[0]
+        batch_size = min(batch_size, data_size)
+        kwargs["batch_size"] = batch_size
         batches_per_epoch = data_size // batch_size
         self.ndims = samples.shape[-1]
         kwargs["epochs"] = epochs
@@ -210,7 +214,7 @@ def loss(self, params, batch_stats, batch, labels, rng):
         loss = self.loss_fn(output.squeeze(), labels).mean()
         return loss, updates

-    def fit(self, samples_a, samples_b, epochs=10, **kwargs):
+    def fit(self, samples_a, samples_b, epochs=10, batch_size=1024, **kwargs):
         """Fit the classifier on provided samples.

         Args:
@@ -233,9 +237,11 @@ def fit(self, samples_a, samples_b, epochs=10, **kwargs):
             cold_lr (float): The factor to reduce learning rate to use during the cold phase. Defaults to 1e-3.
         """
         restart = kwargs.get("restart", False)
-        batch_size = kwargs.get("batch_size", 1024)
         self.ndims = kwargs.get("ndims", samples_a.shape[-1])
         data_size = samples_a.shape[0]

+        batch_size = min(batch_size, data_size)
+        kwargs["batch_size"] = batch_size
         batches_per_epoch = data_size // batch_size
         kwargs["epochs"] = epochs
         if (not self.state) | restart:
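
Taken together, the commit promotes `batch_size` to an explicit argument on both `fit` methods and makes tiny datasets (the "fix tests" case) train for at least one step per epoch with a valid schedule. A usage sketch; the class name and import path are assumptions, not shown on this page:

    import numpy as np
    from clax import Classifier  # assumed entry point

    samples = np.random.randn(100, 4)  # fewer samples than the default batch
    labels = np.random.randint(0, 2, 100)
    clf = Classifier()
    clf.fit(samples, labels, epochs=5, batch_size=1024)  # clamped to 100 internally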
