
Commit

feat: Add shift option to --timestep_sampling in FLUX.1 fine-tuning and LoRA training
kohya-ss committed Aug 25, 2024
1 parent ea92426 commit 72287d3
Showing 2 changed files with 17 additions and 2 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -9,6 +9,10 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`

Aug 25, 2024:
Added a `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. It shifts the timesteps according to the value of `--discrete_flow_shift` (the sigmoid of a normally distributed random value is shifted). A value of 3.1582 (= e^1.15) may be a good starting point for `--discrete_flow_shift`.
Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0`
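As a rough sketch (not part of the commit; the helper name `sample_shifted_timesteps` is hypothetical), the option draws t as the sigmoid of a standard normal sample and remaps it to t' = (shift * t) / (1 + (shift - 1) * t):

```python
import torch

def sample_shifted_timesteps(bsz: int, shift: float, device: str = "cpu") -> torch.Tensor:
    # Draw t in (0, 1) as the sigmoid of a standard normal sample,
    # then remap with t' = (shift * t) / (1 + (shift - 1) * t).
    # For shift > 1 this biases timesteps toward 1, the noisier end.
    t = torch.randn(bsz, device=device).sigmoid()
    return (t * shift) / (1 + (shift - 1) * t)

# With shift = 3.1582 (= e^1.15), the median timestep moves from 0.5 to ~0.76.
print(sample_shifted_timesteps(10_000, 3.1582).median())
```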

Aug 24, 2024 (update 2):

__Experimental__ Added an option to split the q/k/v/txt projection layers in the attention blocks and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args, e.g. `--network_args "split_qkv=True"` (`train_blocks` is also available).
15 changes: 13 additions & 2 deletions library/flux_train_utils.py
@@ -380,9 +380,19 @@ def get_noisy_model_input_and_timesteps(
            t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))

@bghira commented on Aug 27, 2024:

FWIW, x-labs got their reference from cloneofsimo/minRF, the gold-standard flow-matching diffusion implementation.

@kohya-ss (Author, Owner) replied on Aug 27, 2024:

Thank you, I didn't understand the reason for the sigmoid, so this clears things up.

@bghira replied on Aug 27, 2024:

Someone on my tracker pointed out that the shift value should be resolution-dependent: https://github.com/black-forest-labs/flux/blob/b4f689aaccd40de93429865793e84a734f4a6254/src/flux/sampling.py#L66

@kohya-ss (Author, Owner) replied on Aug 27, 2024:

Thanks, I was also told this by another person. 3.1582 seems to be the value for a resolution of 1024x1024. I don't know how much of an impact it will have, but it may be a good idea to set the value according to the resolution of the batch.
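(For context, a minimal sketch of the resolution-dependent shift from the linked sampling.py, adapted here; the helper name `resolution_to_shift` is ours. In that file, mu ramps linearly from 0.5 at 256 image tokens to 1.15 at 4096 tokens, and the effective shift is e^mu, which reduces to the same remapping used in this commit.)

```python
import math

def resolution_to_shift(height: int, width: int) -> float:
    # Image tokens after 8x VAE downsampling and 2x2 patchification.
    image_seq_len = (height // 16) * (width // 16)
    # Linear ramp of mu from 0.5 (256 tokens) to 1.15 (4096 tokens),
    # following the constants in the linked sampling.py.
    m = (1.15 - 0.5) / (4096 - 256)
    mu = m * (image_seq_len - 256) + 0.5
    # shift = e^mu; with sigma = 1, time_shift(mu, 1, t) equals
    # (shift * t) / (1 + (shift - 1) * t), the formula in this commit.
    return math.exp(mu)

print(resolution_to_shift(1024, 1024))  # ~3.1582, the value discussed above
```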

        else:
            t = torch.rand((bsz,), device=device)

        timesteps = t * 1000.0
        t = t.view(-1, 1, 1, 1)
        noisy_model_input = (1 - t) * latents + t * noise
    elif args.timestep_sampling == "shift":
        shift = args.discrete_flow_shift
        logits_norm = torch.randn(bsz, device=device)
        timesteps = logits_norm.sigmoid()
        timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)

        t = timesteps.view(-1, 1, 1, 1)
        timesteps = timesteps * 1000.0
        noisy_model_input = (1 - t) * latents + t * noise
    else:
        # Sample a random timestep for each image
        # for weighting schemes where we sample timesteps non-uniformly
@@ -559,9 +569,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):

    parser.add_argument(
        "--timestep_sampling",
        choices=["sigma", "uniform", "sigmoid"],
        choices=["sigma", "uniform", "sigmoid", "shift"],
        default="sigma",
        help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。",
        help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid."
        " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。",
    )
    parser.add_argument(
        "--sigmoid_scale",
