Add initial patience & delta mode to early_stopping #14

Open · wants to merge 2 commits into base: main
Changes from 1 commit
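For orientation, here is a minimal usage sketch of the options this PR adds to ciclo.early_stopping (initial_patience, delta_mode, and the mode argument renamed to optimization_mode), modeled on the new tests further down; the toy metric "x", the dataset, and the particular parameter values are illustrative only:

import jax.numpy as jnp

import ciclo

# Toy metric stream that improves for a few steps and then plateaus.
dataset = jnp.minimum(jnp.arange(10), 5)

def train_step(state, batch):
    logs = ciclo.logs()
    logs.add_metric("x", batch)
    return logs, state

_, history, _ = ciclo.loop(
    None,
    dataset,
    {
        ciclo.every(1): [
            train_step,
            ciclo.early_stopping(
                "x",
                optimization_mode="max",  # maximize the monitored metric
                patience=3,               # stop after 3 steps without improvement
                initial_patience=2,       # but never stop before 2 steps have elapsed
                min_delta=0.5,            # require at least this much improvement
                delta_mode="absolute",    # treat min_delta as an absolute gap
            ),
        ],
    },
)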
93 changes: 69 additions & 24 deletions ciclo/callbacks.py
@@ -55,6 +55,11 @@
max = auto()


class DeltaMode(str, Enum):
absolute = auto()
relative = auto()

Contributor Author commented on the class DeltaMode definition:

I used the same convention as the existing mode argument (now called optimization_mode), i.e., using an Enum to represent the options. For consistency, perhaps it makes sense to define a similar Enum for the InnerLoopAggregation options introduced in #9? (Or we could use that PR's approach of defining a string Literal type for the different options?)
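As an illustration of the alternative floated in the comment above (not part of this PR), the same set of options could be expressed as a string Literal type instead of an Enum; the names below are hypothetical, and this is only a sketch of that approach:

from typing import Literal, get_args

DeltaModeLiteral = Literal["absolute", "relative"]

def _validate_delta_mode(delta_mode: str) -> str:
    # Reject anything outside the allowed literal values.
    if delta_mode not in get_args(DeltaModeLiteral):
        raise ValueError(
            f"Invalid delta_mode: {delta_mode}, expected one of {get_args(DeltaModeLiteral)}"
        )
    return delta_mode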


def _transpose_history(
log_history: History,
) -> Mapping[Collection, Mapping[Entry, List[Any]]]:
@@ -176,17 +181,17 @@
keep_every_n_steps: Optional[int] = None,
async_manager: Optional[flax_checkpoints.AsyncManager] = None,
monitor: Optional[str] = None,
- mode: Union[str, OptimizationMode] = "min",
+ optimization_mode: Union[str, OptimizationMode] = "min",
):
- if isinstance(mode, str):
- mode = OptimizationMode[mode]
+ if isinstance(optimization_mode, str):
+ optimization_mode = OptimizationMode[optimization_mode]

- if mode not in OptimizationMode:
+ if optimization_mode not in OptimizationMode:
raise ValueError(
- f"Invalid mode: {mode}, expected one of {list(OptimizationMode)}"
+ f"Invalid optimization_mode: {optimization_mode}, expected one of {list(OptimizationMode)}"
)
else:
- self.mode = mode
+ self.optimization_mode = optimization_mode

self.ckpt_dir = ckpt_dir
self.prefix = prefix
@@ -195,7 +200,7 @@
self.keep_every_n_steps = keep_every_n_steps
self.async_manager = async_manager
self.monitor = monitor
- self.minimize = self.mode == OptimizationMode.min
+ self.minimize = self.optimization_mode == OptimizationMode.min
self._best: Optional[float] = None

def __call__(
@@ -227,7 +232,9 @@
):
self._best = value
step_or_metric = (
- value if self.mode == OptimizationMode.max else -value
+ value
+ if self.optimization_mode == OptimizationMode.max
+ else -value
)
else:
save_checkpoint = False
@@ -264,30 +271,69 @@
self,
monitor: str,
patience: Union[int, Period],
- min_delta: float = 0,
- mode: Union[str, OptimizationMode] = "min",
+ initial_patience: Optional[Union[int, Period]] = None,
+ min_delta: Optional[float] = None,
+ delta_mode: Union[str, DeltaMode] = "absolute",
+ optimization_mode: Union[str, OptimizationMode] = "min",
baseline: Optional[float] = None,
restore_best_weights: bool = False,
):
- if isinstance(mode, str):
- mode = OptimizationMode[mode]
+ if initial_patience is None:
+ initial_patience = 1

+ if min_delta is None:
+ min_delta = 0.0

- if mode not in OptimizationMode:
+ if isinstance(optimization_mode, str):
+ optimization_mode = OptimizationMode[optimization_mode]

+ if optimization_mode not in OptimizationMode:
raise ValueError(
- f"Invalid mode: {mode}, expected one of {list(OptimizationMode)}"
+ f"Invalid mode: {optimization_mode}, expected one of {list(OptimizationMode)}"
)
- else:
- self.mode = mode

if isinstance(delta_mode, str):
delta_mode = DeltaMode[delta_mode]

if delta_mode not in DeltaMode:
raise ValueError(
f"Invalid mode: {delta_mode}, expected one of {list(DeltaMode)}"
)

if (
optimization_mode == OptimizationMode.min
and delta_mode == DeltaMode.absolute
):
self.improvement_fn = lambda current, best: current < best - min_delta
elif (
optimization_mode == OptimizationMode.min
and delta_mode == DeltaMode.relative
):
self.improvement_fn = lambda current, best: current < best * (1 - min_delta)
elif (
optimization_mode == OptimizationMode.max
and delta_mode == DeltaMode.absolute
):
self.improvement_fn = lambda current, best: current > best + min_delta
elif (
optimization_mode == OptimizationMode.max
and delta_mode == DeltaMode.relative
):
self.improvement_fn = lambda current, best: current > best * (1 + min_delta)

self.monitor = monitor
self.patience = (
patience if isinstance(patience, Period) else Period.create(patience)
)
self.initial_patience = (
initial_patience
if isinstance(initial_patience, Period)
else Period.create(initial_patience)
)
self.min_delta = min_delta
- self.mode = mode
self.baseline = baseline
self.restore_best_weights = restore_best_weights
- self.minimize = self.mode == OptimizationMode.min
+ self.minimize = optimization_mode == OptimizationMode.min
self._best = baseline
self._best_state = None
self._elapsed_start: Optional[Elapsed] = None
@@ -306,16 +352,15 @@
except KeyError:
raise ValueError(f"Monitored value '{self.monitor}' not found in logs")

- if (
- self._best is None
- or (self.minimize and value < self._best)
- or (not self.minimize and value > self._best)
- ):
+ if self._best is None or self.improvement_fn(value, self._best):
self._best = value
self._best_state = state
self._elapsed_start = elapsed

- if elapsed - self._elapsed_start >= self.patience:
+ if (
+ elapsed - self._elapsed_start >= self.patience
+ and elapsed >= self.initial_patience
+ ):
if self.restore_best_weights and self._best_state is not None:
state = self._best_state
stop_iteration = True
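A note on the improvement criterion added above: the four improvement_fn lambdas combine optimization_mode (min or max) with delta_mode (absolute or relative), so min_delta is either a fixed gap or a fraction of the current best. The helper below is not part of the PR, just a standalone sketch of the same rule with hypothetical names:

def improves(current: float, best: float, min_delta: float = 0.0,
             minimize: bool = False, relative: bool = False) -> bool:
    # Relative mode scales the required margin by the best value seen so far;
    # absolute mode applies min_delta as a fixed gap.
    if minimize:
        threshold = best * (1 - min_delta) if relative else best - min_delta
        return current < threshold
    threshold = best * (1 + min_delta) if relative else best + min_delta
    return current > threshold

# With delta_mode="relative" and min_delta=0.5 in max mode, a best of 0.2 is
# only beaten by values above 0.2 * 1.5 = 0.3.
assert improves(0.35, 0.2, min_delta=0.5, relative=True)
assert not improves(0.25, 0.2, min_delta=0.5, relative=True)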
2 changes: 1 addition & 1 deletion ciclo/loops/loop.py
@@ -165,7 +165,7 @@ def loop(
)
for schedule, callbacks in tasks.items()
]
- # prone empty tasks
+ # prune empty tasks
schedule_callbacks = [x for x in schedule_callbacks if len(x[1]) > 0]

try:
185 changes: 184 additions & 1 deletion tests/test_callbacks.py
@@ -35,7 +35,7 @@ def dummy_inner_loop_fn(_):
return None, log_history, None


- class TestCallbacks:
+ class TestInnerLoop:
def test_inner_loop_default_aggregation(self):
inner_loop = ciclo.callbacks.inner_loop(
"test",
@@ -133,3 +133,186 @@ def test_inner_loop_aggregation_dict(self):
"D_test": jnp.array(0.0, dtype=jnp.float32),
},
}


class TestEarlyStopping:
def test_patience(self):
dataset = jnp.minimum(jnp.arange(10), 5)

def train_step(state, batch):
logs = ciclo.logs()
logs.add_metric("x", batch)
return logs, state

_, history, _ = ciclo.loop(
None,
dataset,
{
ciclo.every(1): [
train_step,
ciclo.early_stopping("x", optimization_mode="max", patience=1),
],
},
)

assert len(history) == 7

_, history, _ = ciclo.loop(
None,
dataset,
{
ciclo.every(1): [
train_step,
ciclo.early_stopping("x", optimization_mode="max", patience=3),
],
},
)

assert len(history) == 9

def test_initial_patience(self):
dataset = jnp.maximum(jnp.minimum(jnp.arange(10), 5), 2)

def train_step(state, batch):
logs = ciclo.logs()
logs.add_metric("x", batch)
return logs, state

_, history, _ = ciclo.loop(
None,
dataset,
{
ciclo.every(1): [
train_step,
ciclo.early_stopping(
"x", optimization_mode="max", patience=1, initial_patience=1
),
],
},
)

assert len(history) == 2

_, history, _ = ciclo.loop(
None,
dataset,
{
ciclo.every(1): [
train_step,
ciclo.early_stopping(
"x", optimization_mode="max", patience=1, initial_patience=3
),
],
},
)

assert len(history) == 7

def test_min_optimization_mode(self):
dataset = jnp.maximum(jnp.minimum(jnp.arange(9, 0, -1), 6), 3)

def train_step(state, batch):
logs = ciclo.logs()
logs.add_metric("x", batch)
return logs, state

_, history, _ = ciclo.loop(
None,
dataset,
{
ciclo.every(1): [
train_step,
ciclo.early_stopping(
"x", optimization_mode="min", patience=1, initial_patience=4
),
],
},
)

assert len(history) == 8

def test_min_delta(self):
dataset = jnp.arange(0, 1, 0.1)

def train_step(state, batch):
logs = ciclo.logs()
logs.add_metric("x", batch)
return logs, state

_, history, _ = ciclo.loop(
None,
dataset,
{
ciclo.every(1): [
train_step,
ciclo.early_stopping(
"x",
optimization_mode="max",
patience=1,
min_delta=0.01,
),
],
},
)

assert len(history) == 10

_, history, _ = ciclo.loop(
None,
dataset,
{
ciclo.every(1): [
train_step,
ciclo.early_stopping(
"x", optimization_mode="max", patience=1, min_delta=0.1
),
],
},
)

assert len(history) == 2

_, history, _ = ciclo.loop(
None,
dataset,
{
ciclo.every(1): [
train_step,
ciclo.early_stopping(
"x",
optimization_mode="max",
patience=3,
min_delta=0.05,
),
],
},
)

assert len(history) == 10

def test_min_relative_delta(self):
dataset = jnp.arange(0, 1, 0.1)

def train_step(state, batch):
logs = ciclo.logs()
logs.add_metric("x", batch)
return logs, state

_, history, _ = ciclo.loop(
None,
dataset,
{
ciclo.every(1): [
train_step,
ciclo.early_stopping(
"x",
optimization_mode="max",
patience=1,
min_delta=0.5,
delta_mode="relative",
),
],
},
)

assert len(history) == 4
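For reviewers puzzling over the asserted history lengths in TestEarlyStopping: stopping now requires both that patience steps pass without improvement and that the overall elapsed period reaches initial_patience. The snippet below re-derives the two test_initial_patience assertions with a rough standalone model of that rule (max mode, absolute delta); it assumes elapsed is counted in completed steps starting at zero, which is inferred from the asserted lengths rather than stated in the diff:

def history_length_at_stop(values, patience, initial_patience=1, min_delta=0.0):
    # Simplified max-mode, absolute-delta trace of the early-stopping rule.
    best = None
    best_step = 0
    for step, value in enumerate(values):  # `step` stands in for `elapsed`
        if best is None or value > best + min_delta:
            best, best_step = value, step
        if step - best_step >= patience and step >= initial_patience:
            return step + 1  # entries logged when iteration stops
    return len(values)

values = [2, 2, 2, 3, 4, 5, 5, 5, 5, 5]  # the metric stream in test_initial_patience
assert history_length_at_stop(values, patience=1, initial_patience=1) == 2
assert history_length_at_stop(values, patience=1, initial_patience=3) == 7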
6 changes: 3 additions & 3 deletions tests/test_integration.py
@@ -177,11 +177,11 @@ def create_state():
ciclo.checkpoint(
f"{logdir}/model",
monitor="accuracy_valid",
mode="max",
optimization_mode="max",
),
ciclo.early_stopping(
monitor="accuracy_valid",
mode="max",
optimization_mode="max",
patience=100,
),
],
@@ -230,7 +230,7 @@ def __call__(self, x):
ciclo.checkpoint(
f"logdir/{Path(__file__).stem}/{int(time())}",
monitor="accuracy_test",
mode="max",
optimization_mode="max",
),
],
test_dataset=lambda: get_tuple_dataset(batch_size),