diff --git a/clax/clax.py b/clax/clax.py index 41b3378..9fc8c39 100644 --- a/clax/clax.py +++ b/clax/clax.py @@ -104,24 +104,29 @@ def _init_state(self, **kwargs): optimizer = kwargs.get("optimizer", None) params = _params["params"] batch_stats = _params["batch_stats"] - transition_steps = kwargs.get("transition_steps", 1000) + + target_batches_per_epoch = kwargs.pop("target_batches_per_epoch") + warmup_fraction = kwargs.get("warmup_fraction", 0.05) + cold_fraction = kwargs.get("cold_fraction", 0.05) + cold_lr = kwargs.get("cold_lr", 1e-3) + epochs = kwargs.pop("epochs") + self.schedule = optax.warmup_cosine_decay_schedule( init_value=0.0, peak_value=lr, - warmup_steps=transition_steps, - decay_steps=transition_steps * 10, - end_value=lr * 1e-4, + warmup_steps=int(warmup_fraction * target_batches_per_epoch * epochs), + decay_steps=int((1 - cold_fraction) * target_batches_per_epoch * epochs), + end_value=lr * cold_lr, exponent=1.0, ) if not optimizer: optimizer = optax.chain( - # optax.clip_by_global_norm(1.0), optax.adaptive_grad_clip(0.01), + # optax.contrib.schedule_free_adamw(lr, warmup_steps=transition_steps) optax.adamw(self.schedule), # optax.adamw(lr), ) - # self.state = train_state.TrainState.create( self.state = TrainState.create( apply_fn=self.network.apply, params=params, @@ -129,7 +134,7 @@ def _init_state(self, **kwargs): tx=optimizer, ) - def fit(self, samples, labels, **kwargs): + def fit(self, samples, labels, epochs=10, **kwargs): """Fit the classifier on provided samples. Args: @@ -141,15 +146,23 @@ def fit(self, samples, labels, **kwargs): batch_size (int): Size of the training batches. Defaults to 1024. epochs (int): Number of training epochs. Defaults to 10. lr (float): Learning rate. Defaults to 1e-2. - transition_steps (int): Number of steps to transition the learning rate. - Defaults to 100. + + optimizer (optax): Optimizer to use. Defaults to None. If none uses AdamW with a cosine schedule. + with adjustable parameters as further kwargs: + warmup_fraction (float): Fraction of the training steps to warm up the learning rate. + Defaults to 0.05. + cold_fraction (float): Fraction of the training steps at the cold learning rate. + Defaults to 0.05. + cold_lr (float): The factor to reduce learning rate to use during the cold phase. Defaults to 1e-3. """ restart = kwargs.get("restart", False) batch_size = kwargs.get("batch_size", 1024) data_size = samples.shape[0] batches_per_epoch = data_size // batch_size self.ndims = samples.shape[-1] + kwargs["epochs"] = epochs if (not self.state) | restart: + kwargs["target_batches_per_epoch"] = batches_per_epoch self._init_state(**kwargs) labels = jnp.array(labels, dtype=int) samples = jnp.array(samples, dtype=jnp.float32) @@ -168,7 +181,6 @@ def predict(self, samples): Args: samples (np.ndarray): Samples to predict on. - log (bool): If True, return the log-probabilities. Defaults to True. """ return self._predict_weight(samples) @@ -198,27 +210,36 @@ def loss(self, params, batch_stats, batch, labels, rng): loss = self.loss_fn(output.squeeze(), labels).mean() return loss, updates - def fit(self, samples_a, samples_b, **kwargs): + def fit(self, samples_a, samples_b, epochs=10, **kwargs): """Fit the classifier on provided samples. Args: samples (np.ndarray): Samples to train on. labels (np.array): integer class labels corresponding to each sample. + Keyword Args: restart (bool): If True, reinitialise the model before training. Defaults to False. batch_size (int): Size of the training batches. Defaults to 1024. epochs (int): Number of training epochs. Defaults to 10. lr (float): Learning rate. Defaults to 1e-2. - transition_steps (int): Number of steps to transition the learning rate. - Defaults to 100. + + optimizer (optax): Optimizer to use. Defaults to None. If none uses AdamW with a cosine schedule. + with adjustable parameters as further kwargs: + warmup_fraction (float): Fraction of the training steps to warm up the learning rate. + Defaults to 0.05. + cold_fraction (float): Fraction of the training steps at the cold learning rate. + Defaults to 0.05. + cold_lr (float): The factor to reduce learning rate to use during the cold phase. Defaults to 1e-3. """ restart = kwargs.get("restart", False) batch_size = kwargs.get("batch_size", 1024) self.ndims = kwargs.get("ndims", samples_a.shape[-1]) data_size = samples_a.shape[0] batches_per_epoch = data_size // batch_size + kwargs["epochs"] = epochs if (not self.state) | restart: + kwargs["target_batches_per_epoch"] = batches_per_epoch self._init_state(**kwargs) self._train(samples_a, samples_b, batches_per_epoch, **kwargs) self._predict_weight = lambda x: self.state.apply_fn( @@ -230,15 +251,6 @@ def fit(self, samples_a, samples_b, **kwargs): train=False, ) - def predict(self, samples): - """Predict the class (log) - probabilities for the provided samples. - - Args: - samples (np.ndarray): Samples to predict on. - log (bool): If True, return the log-probabilities. Defaults to True. - """ - return self._predict_weight(samples) - class Regressor(Classifier): """Regressor class wrapping a basic jax multiclass regressor."""