-
Notifications
You must be signed in to change notification settings - Fork 843
Commit
shift
option to --timestep_sampling
in FLUX.1 fine-tuni…
…ng and LoRA training
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -380,9 +380,19 @@ def get_noisy_model_input_and_timesteps( | |
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
kohya-ss
Author
Owner
|
||
else: | ||
t = torch.rand((bsz,), device=device) | ||
|
||
timesteps = t * 1000.0 | ||
t = t.view(-1, 1, 1, 1) | ||
noisy_model_input = (1 - t) * latents + t * noise | ||
elif args.timestep_sampling == "shift": | ||
shift = args.discrete_flow_shift | ||
logits_norm = torch.randn(bsz, device=device) | ||
timesteps = logits_norm.sigmoid() | ||
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) | ||
|
||
t = timesteps.view(-1, 1, 1, 1) | ||
timesteps = timesteps * 1000.0 | ||
noisy_model_input = (1 - t) * latents + t * noise | ||
else: | ||
# Sample a random timestep for each image | ||
# for weighting schemes where we sample timesteps non-uniformly | ||
|
@@ -559,9 +569,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): | |
|
||
parser.add_argument( | ||
"--timestep_sampling", | ||
choices=["sigma", "uniform", "sigmoid"], | ||
choices=["sigma", "uniform", "sigmoid", "shift"], | ||
default="sigma", | ||
help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", | ||
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." | ||
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", | ||
) | ||
parser.add_argument( | ||
"--sigmoid_scale", | ||
|
fwiw x-labs got their reference from cloneofsimo/minRF the gold standard flow-matching diffusion implementation