
Commit

fine tune
yallup committed Jul 24, 2024
1 parent b870e55 commit d46bb51
Showing 2 changed files with 48 additions and 0 deletions.
47 changes: 47 additions & 0 deletions clax/clax.py
@@ -5,6 +5,7 @@
import jax.random as random
import optax
from jax import jit
from jaxopt import LBFGS
from tqdm import tqdm

from clax.network import DataLoader, Network, TrainState
@@ -180,6 +181,28 @@ def fit(self, samples, labels, epochs=10, batch_size=1024, **kwargs):
train=False,
)

    def fine_tune(self, samples, labels, **kwargs):
        """Fine-tune the trained classifier with a full-batch quasi-Newton (L-BFGS) solver."""
        labels = jnp.array(labels, dtype=int)
        samples = jnp.array(samples, dtype=jnp.float32)
        if not self.state:
            raise ValueError("Classifier not trained yet.")

        # Full-batch loss as a function of the parameters only; batch stats
        # and the RNG key are closed over from the current train state.
        def fn(x):
            return self.loss(x, self.state.batch_stats, samples, labels, self.rng)[0]

        tuner = LBFGS(fn)
        res = tuner.run(self.state.params)
        # res is a jaxopt OptStep; its first field holds the optimized parameters.
        self.state = self.state.replace(params=res[0])
        self._predict_weight = lambda x: self.state.apply_fn(
            {
                "params": self.state.params,
                "batch_stats": self.state.batch_stats,
            },
            x,
            train=False,
        )

def predict(self, samples):
"""Predict the class (log) - probabilities for the provided samples.
@@ -257,6 +280,30 @@ def fit(self, samples_a, samples_b, epochs=10, batch_size=1024, **kwargs):
train=False,
)

    def fine_tune(self, samples_a, samples_b, **kwargs):
        """Fine-tune the trained two-sample classifier with a full-batch quasi-Newton (L-BFGS) solver."""
        samples_a = jnp.array(samples_a, dtype=jnp.float32)
        samples_b = jnp.array(samples_b, dtype=jnp.float32)
        if not self.state:
            raise ValueError("Classifier not trained yet.")

        # Full-batch loss as a function of the parameters only; batch stats
        # and the RNG key are closed over from the current train state.
        def fn(x):
            return self.loss(
                x, self.state.batch_stats, samples_a, samples_b, self.rng
            )[0]

        tuner = LBFGS(fn)
        res = tuner.run(self.state.params)
        # res is a jaxopt OptStep; its first field holds the optimized parameters.
        self.state = self.state.replace(params=res[0])
        self._predict_weight = lambda x: self.state.apply_fn(
            {
                "params": self.state.params,
                "batch_stats": self.state.batch_stats,
            },
            x,
            train=False,
        )


class Regressor(Classifier):
"""Regressor class wrapping a basic jax multiclass regressor."""
1 change: 1 addition & 0 deletions pyproject.toml
@@ -16,6 +16,7 @@ dependencies = [
"flax >= 0.8.2",
"tqdm >= 4.62.0",
"optax >= 0.2.2",
"jaxopt",
]

[options.extras_require]
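Since the diff declares jaxopt as a project dependency, a standard editable reinstall picks it up; this is generic pip usage, nothing repo-specific is assumed:

    pip install -e .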
