make default schedule more robust
yallup committed Jul 15, 2024
1 parent 632f3cd commit f82c81b
Showing 1 changed file with 34 additions and 22 deletions.
56 changes: 34 additions & 22 deletions clax/clax.py
@@ -104,32 +104,37 @@ def _init_state(self, **kwargs):
optimizer = kwargs.get("optimizer", None)
params = _params["params"]
batch_stats = _params["batch_stats"]
transition_steps = kwargs.get("transition_steps", 1000)

target_batches_per_epoch = kwargs.pop("target_batches_per_epoch")
warmup_fraction = kwargs.get("warmup_fraction", 0.05)
cold_fraction = kwargs.get("cold_fraction", 0.05)
cold_lr = kwargs.get("cold_lr", 1e-3)
epochs = kwargs.pop("epochs")

self.schedule = optax.warmup_cosine_decay_schedule(
init_value=0.0,
peak_value=lr,
warmup_steps=transition_steps,
decay_steps=transition_steps * 10,
end_value=lr * 1e-4,
warmup_steps=int(warmup_fraction * target_batches_per_epoch * epochs),
decay_steps=int((1 - cold_fraction) * target_batches_per_epoch * epochs),
end_value=lr * cold_lr,
exponent=1.0,
)
if not optimizer:
optimizer = optax.chain(
# optax.clip_by_global_norm(1.0),
optax.adaptive_grad_clip(0.01),
# optax.contrib.schedule_free_adamw(lr, warmup_steps=transition_steps)
optax.adamw(self.schedule),
# optax.adamw(lr),
)

# self.state = train_state.TrainState.create(
self.state = TrainState.create(
apply_fn=self.network.apply,
params=params,
batch_stats=batch_stats,
tx=optimizer,
)

def fit(self, samples, labels, **kwargs):
def fit(self, samples, labels, epochs=10, **kwargs):
"""Fit the classifier on provided samples.
Args:
@@ -141,15 +146,23 @@ def fit(self, samples, labels, **kwargs):
batch_size (int): Size of the training batches. Defaults to 1024.
epochs (int): Number of training epochs. Defaults to 10.
lr (float): Learning rate. Defaults to 1e-2.
transition_steps (int): Number of steps to transition the learning rate.
Defaults to 100.
optimizer (optax): Optimizer to use. Defaults to None; if None, uses AdamW with a cosine schedule
whose parameters are adjustable via further kwargs:
warmup_fraction (float): Fraction of the training steps to warm up the learning rate.
Defaults to 0.05.
cold_fraction (float): Fraction of the training steps at the cold learning rate.
Defaults to 0.05.
cold_lr (float): Factor by which the learning rate is reduced during the cold phase (end value = lr * cold_lr). Defaults to 1e-3.
"""
restart = kwargs.get("restart", False)
batch_size = kwargs.get("batch_size", 1024)
data_size = samples.shape[0]
batches_per_epoch = data_size // batch_size
self.ndims = samples.shape[-1]
kwargs["epochs"] = epochs
if (not self.state) | restart:
kwargs["target_batches_per_epoch"] = batches_per_epoch
self._init_state(**kwargs)
labels = jnp.array(labels, dtype=int)
samples = jnp.array(samples, dtype=jnp.float32)
Expand All @@ -168,7 +181,6 @@ def predict(self, samples):
Args:
samples (np.ndarray): Samples to predict on.
log (bool): If True, return the log-probabilities. Defaults to True.
"""
return self._predict_weight(samples)

@@ -198,27 +210,36 @@ def loss(self, params, batch_stats, batch, labels, rng):
loss = self.loss_fn(output.squeeze(), labels).mean()
return loss, updates

def fit(self, samples_a, samples_b, **kwargs):
def fit(self, samples_a, samples_b, epochs=10, **kwargs):
"""Fit the classifier on provided samples.
Args:
samples_a (np.ndarray): First set of samples to train on.
samples_b (np.ndarray): Second set of samples to train on.
Keyword Args:
restart (bool): If True, reinitialise the model before training. Defaults to False.
batch_size (int): Size of the training batches. Defaults to 1024.
epochs (int): Number of training epochs. Defaults to 10.
lr (float): Learning rate. Defaults to 1e-2.
transition_steps (int): Number of steps to transition the learning rate.
Defaults to 100.
optimizer (optax): Optimizer to use. Defaults to None; if None, uses AdamW with a cosine schedule
whose parameters are adjustable via further kwargs:
warmup_fraction (float): Fraction of the training steps to warm up the learning rate.
Defaults to 0.05.
cold_fraction (float): Fraction of the training steps at the cold learning rate.
Defaults to 0.05.
cold_lr (float): Factor by which the learning rate is reduced during the cold phase (end value = lr * cold_lr). Defaults to 1e-3.
"""
restart = kwargs.get("restart", False)
batch_size = kwargs.get("batch_size", 1024)
self.ndims = kwargs.get("ndims", samples_a.shape[-1])
data_size = samples_a.shape[0]
batches_per_epoch = data_size // batch_size
kwargs["epochs"] = epochs
if (not self.state) | restart:
kwargs["target_batches_per_epoch"] = batches_per_epoch
self._init_state(**kwargs)
self._train(samples_a, samples_b, batches_per_epoch, **kwargs)
self._predict_weight = lambda x: self.state.apply_fn(
@@ -230,15 +251,6 @@ def fit(self, samples_a, samples_b, **kwargs):
train=False,
)

def predict(self, samples):
"""Predict the class (log) - probabilities for the provided samples.
Args:
samples (np.ndarray): Samples to predict on.
log (bool): If True, return the log-probabilities. Defaults to True.
"""
return self._predict_weight(samples)


class Regressor(Classifier):
"""Regressor class wrapping a basic jax multiclass regressor."""
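For context, here is a minimal standalone sketch of the new default schedule this commit wires up, built directly with optax; the data size, batch size, and epoch count below are illustrative assumptions, not values from the repository.

import optax

# Illustrative training setup (assumed values).
lr = 1e-2
epochs = 10
data_size = 50_000
batch_size = 1024
target_batches_per_epoch = data_size // batch_size

# New fraction-based defaults introduced by this commit.
warmup_fraction = 0.05  # warm up over the first 5% of total steps
cold_fraction = 0.05    # hold the end value for the last 5% of steps
cold_lr = 1e-3          # end value = lr * cold_lr

total_steps = target_batches_per_epoch * epochs
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=lr,
    warmup_steps=int(warmup_fraction * total_steps),
    decay_steps=int((1 - cold_fraction) * total_steps),
    end_value=lr * cold_lr,
    exponent=1.0,
)

# Default optimizer chain, mirroring _init_state above.
optimizer = optax.chain(
    optax.adaptive_grad_clip(0.01),
    optax.adamw(schedule),
)

# Inspect the learning rate at a few points along training.
for step in (0, int(warmup_fraction * total_steps), total_steps // 2, total_steps):
    print(step, float(schedule(step)))

Because the warmup and decay lengths now scale with target_batches_per_epoch * epochs rather than a fixed transition_steps, the schedule keeps the same shape regardless of dataset size or epoch count.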

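A hedged usage sketch of the updated fit API, as documented in the revised docstring; the toy data and the pre-constructed clf object are assumptions for illustration (the Classifier constructor is outside this diff).

import numpy as np

# Toy binary-classification data (illustrative shapes only).
rng = np.random.default_rng(0)
samples = rng.normal(size=(50_000, 4)).astype(np.float32)
labels = (samples.sum(axis=1) > 0).astype(int)

# Given an already-constructed clax Classifier instance `clf`,
# the new schedule knobs pass straight through fit():
clf.fit(
    samples,
    labels,
    epochs=20,             # now an explicit argument, defaulting to 10
    batch_size=1024,
    lr=1e-2,
    warmup_fraction=0.05,  # fraction of total steps spent warming up
    cold_fraction=0.05,    # fraction of total steps held at the cold value
    cold_lr=1e-3,          # end learning rate = lr * cold_lr
)
log_probs = clf.predict(samples)  # log-probabilities by default, per the docstring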