Skip to content

Commit

Permalink
make guidance_scale keep float in args
Browse files Browse the repository at this point in the history
  • Loading branch information
Akegarasu committed Aug 29, 2024
1 parent a61cf73 commit 6c0e8a5
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion flux_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@ def get_noise_pred_and_target(
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)

# get guidance
guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device)
# ensure guidance_scale in args is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)

# ensure the hidden state will require grad
if args.gradient_checkpointing:
Expand Down

0 comments on commit 6c0e8a5

Please sign in to comment.