Merge branch 'dev' into sd3
kohya-ss committed Aug 24, 2024
2 parents 5639c2a + d5c076c commit ea92426
Showing 2 changed files with 7 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
```diff
@@ -461,6 +461,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 
 ### Working in progress
 
+- `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened!
 - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
 - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with the optimizer step. The training results are the same as before, but if you have plenty of memory, the speed will be slower.
 - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only Adafactor is supported. Gradient accumulation is not available.
```
8 changes: 6 additions & 2 deletions sdxl_train.py
```diff
@@ -700,7 +700,11 @@ def optimizer_hook(parameter: torch.Tensor):
             with accelerator.autocast():
                 noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
 
-            target = noise
+            if args.v_parameterization:
+                # v-parameterization training
+                target = noise_scheduler.get_velocity(latents, noise, timesteps)
+            else:
+                target = noise
 
             if (
                 args.min_snr_gamma
```
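The `get_velocity` call above switches the regression target from the noise ε to the velocity of Salimans & Ho (2022), v_t = sqrt(ᾱ_t)·ε − sqrt(1−ᾱ_t)·x₀. A minimal standalone sketch of that formula (the hypothetical `get_velocity` below mimics the diffusers scheduler method of the same name, but is not the library implementation):

```python
import numpy as np

def get_velocity(sample, noise, alphas_cumprod, timesteps):
    """v-prediction target: v_t = sqrt(abar_t) * eps - sqrt(1 - abar_t) * x_0."""
    a = np.sqrt(alphas_cumprod[timesteps])[:, None]        # sqrt(alpha_bar_t)
    s = np.sqrt(1.0 - alphas_cumprod[timesteps])[:, None]  # sqrt(1 - alpha_bar_t)
    return a * noise - s * sample

# Example: at abar_t = 0.25 the target mixes half the noise
# with sqrt(0.75) of the (negated) clean sample.
abar = np.array([1.0, 0.25])
x0 = np.array([[1.0, 1.0]])
eps = np.array([[2.0, 2.0]])
v = get_velocity(x0, eps, abar, np.array([1]))
```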
```diff
@@ -718,7 +722,7 @@ def optimizer_hook(parameter: torch.Tensor):
                     loss = loss.mean([1, 2, 3])
 
                 if args.min_snr_gamma:
-                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
                 if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
```
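The new `args.v_parameterization` argument passed to `apply_snr_weight` matters because Min-SNR-gamma weighting normalizes differently for ε-prediction and v-prediction: a common formulation uses min(SNR, γ)/SNR for ε targets and min(SNR, γ)/(SNR + 1) for v targets. A simplified sketch under that assumption (the real helper in the repository takes `timesteps` and a scheduler rather than precomputed SNR values):

```python
import numpy as np

def apply_snr_weight(loss, snr, gamma, v_parameterization):
    """Min-SNR-gamma weighting (Hang et al., 2023), simplified sketch.

    Clips each timestep's SNR at gamma, then normalizes by the scale
    appropriate to the prediction target.
    """
    denom = snr + 1.0 if v_parameterization else snr
    weight = np.minimum(snr, gamma) / denom
    return loss * weight

# Low-SNR timesteps keep full weight; high-SNR timesteps are damped.
snr = np.array([0.5, 10.0])
per_sample_loss = np.ones(2)
eps_weighted = apply_snr_weight(per_sample_loss, snr, 5.0, False)
v_weighted = apply_snr_weight(per_sample_loss, snr, 5.0, True)
```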
