skip progressbar if needed #516

Status: Open. Wants to merge 3 commits into base: develop.
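Summary of the change: this PR makes the Optimizer's progress bar optional. The main loop is factored out of `_minimize` into a new `_optimization_loop` method, and a `ProgressBar` is created and stepped only when the `settings.PROGRESSBAR` flag is enabled; otherwise the loop runs with `progress_bar=None` and produces no output.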
mrmustard/training/optimizer.py (67 changes: 45 additions & 22 deletions)

@@ -18,7 +18,7 @@

 from itertools import chain, groupby
 from typing import List, Callable, Sequence, Union, Mapping, Dict
-from mrmustard import math
+from mrmustard import math, settings
 from mrmustard.math.parameters import Constant, Variable
 from mrmustard.training.callbacks import Callback
 from mrmustard.training.progress_bar import ProgressBar
@@ -98,35 +98,58 @@ def minimize(
     def _minimize(self, cost_fn, by_optimizing, max_steps, callbacks):
         # finding out which parameters are trainable from the ops
         trainable_params = self._get_trainable_params(by_optimizing)
+        if settings.PROGRESSBAR:
+            bar = ProgressBar(max_steps)
+            with bar:
+                self._optimization_loop(cost_fn, trainable_params, max_steps, callbacks, bar)
+        else:
+            self._optimization_loop(cost_fn, trainable_params, max_steps, callbacks)
+
+    def _optimization_loop(
+        self, cost_fn, trainable_params, max_steps, callbacks, progress_bar=None
+    ):
+        """Internal method that performs the main optimization loop.
+
+        Args:
+            cost_fn (Callable): The cost function to minimize
+            trainable_params (dict): Dictionary of trainable parameters
+            max_steps (int): Maximum number of optimization steps
+            callbacks (dict): Dictionary of callback functions to execute during optimization
+            progress_bar (ProgressBar, optional): Progress bar instance for displaying optimization progress.
+                If None, no progress will be displayed. Defaults to None.
+
+        Note:
+            This method maintains internal state in self.opt_history and self.callback_history,
+            tracking the optimization progress and callback results respectively.
+        """
         cost_fn_modified = False
         orig_cost_fn = cost_fn
 
-        bar = ProgressBar(max_steps)
-        with bar:
-            while not self.should_stop(max_steps):
-                cost, grads = self.compute_loss_and_gradients(cost_fn, trainable_params.values())
+        while not self.should_stop(max_steps):
+            cost, grads = self.compute_loss_and_gradients(cost_fn, trainable_params.values())
 
-                trainables = {tag: (x, dx) for (tag, x), dx in zip(trainable_params.items(), grads)}
+            trainables = {tag: (x, dx) for (tag, x), dx in zip(trainable_params.items(), grads)}
 
-                if cost_fn_modified:
-                    self.callback_history["orig_cost"].append(orig_cost_fn())
+            if cost_fn_modified:
+                self.callback_history["orig_cost"].append(orig_cost_fn())
 
-                new_cost_fn, new_grads = self._run_callbacks(
-                    callbacks=callbacks,
-                    cost_fn=cost_fn,
-                    cost=cost,
-                    trainables=trainables,
-                )
+            new_cost_fn, new_grads = self._run_callbacks(
+                callbacks=callbacks,
+                cost_fn=cost_fn,
+                cost=cost,
+                trainables=trainables,
+            )
 
-                self.apply_gradients(trainable_params.values(), new_grads or grads)
-                self.opt_history.append(cost)
-                bar.step(math.asnumpy(cost))
+            self.apply_gradients(trainable_params.values(), new_grads or grads)
+            self.opt_history.append(cost)
+            if progress_bar is not None:
+                progress_bar.step(math.asnumpy(cost))
 
-                if callable(new_cost_fn):
-                    cost_fn = new_cost_fn
-                    if not cost_fn_modified:
-                        cost_fn_modified = True
-                        self.callback_history["orig_cost"] = self.opt_history.copy()
+            if callable(new_cost_fn):
+                cost_fn = new_cost_fn
+                if not cost_fn_modified:
+                    cost_fn_modified = True
+                    self.callback_history["orig_cost"] = self.opt_history.copy()
 
     def apply_gradients(self, trainable_params, grads):
         """Apply gradients to variables.
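For reference, here is how a user would opt out of the progress bar once this change is merged. This is a minimal sketch, not code from the PR: it assumes `settings.PROGRESSBAR` is exposed as a writable boolean on mrmustard's global settings object (its definition in settings.py is not shown in this diff), and `cost_fn` and `circuit` are hypothetical placeholders for a real cost function and trainable circuit.

from mrmustard import settings
from mrmustard.training import Optimizer

# Assumed flag: when False, _minimize never constructs a ProgressBar and
# _optimization_loop receives progress_bar=None, so no step output is printed.
settings.PROGRESSBAR = False

# Hypothetical setup; constructor arguments and the definitions of
# cost_fn and circuit are omitted for brevity.
opt = Optimizer()
opt.minimize(cost_fn, by_optimizing=[circuit], max_steps=100)

Passing the bar into `_optimization_loop` as an optional argument, rather than re-checking the settings flag inside the loop, keeps the per-step branch to a simple `is not None` test and leaves the loop testable without any progress bar at all.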