Skip to content

Commit

Permalink
add negative prompt for flux inference script
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Sep 9, 2024
1 parent ce14447 commit d29af14
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 86 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ The command to install PyTorch is as follows:

### Recent Updates

Sep 9, 2024:
Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used.

Sep 5, 2024 (update 1):

Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`.
Expand Down
289 changes: 203 additions & 86 deletions flux_minimal_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,22 +71,57 @@ def denoise(
timesteps: list[float],
guidance: float = 4.0,
t5_attn_mask: Optional[torch.Tensor] = None,
neg_txt: Optional[torch.Tensor] = None,
neg_vec: Optional[torch.Tensor] = None,
neg_t5_attn_mask: Optional[torch.Tensor] = None,
cfg_scale: Optional[float] = None,
):
# this is ignored for schnell
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

# prepare classifier free guidance
if neg_txt is not None and neg_vec is not None:
b_img_ids = torch.cat([img_ids, img_ids], dim=0)
b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
b_txt = torch.cat([neg_txt, txt], dim=0)
b_vec = torch.cat([neg_vec, vec], dim=0)
if t5_attn_mask is not None and neg_t5_attn_mask is not None:
b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
else:
b_t5_attn_mask = None
else:
b_img_ids = img_ids
b_txt_ids = txt_ids
b_txt = txt
b_vec = vec
b_t5_attn_mask = t5_attn_mask

for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)

# classifier free guidance
if neg_txt is not None and neg_vec is not None:
b_img = torch.cat([img, img], dim=0)
else:
b_img = img

pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
img=b_img,
img_ids=b_img_ids,
txt=b_txt,
txt_ids=b_txt_ids,
y=b_vec,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
txt_attention_mask=b_t5_attn_mask,
)

# classifier free guidance
if neg_txt is not None and neg_vec is not None:
pred_uncond, pred = torch.chunk(pred, 2, dim=0)
pred = pred_uncond + cfg_scale * (pred - pred_uncond)

img = img + (t_prev - t_curr) * pred

return img
Expand All @@ -106,19 +141,48 @@ def do_sample(
is_schnell: bool,
device: torch.device,
flux_dtype: torch.dtype,
neg_l_pooled: Optional[torch.Tensor] = None,
neg_t5_out: Optional[torch.Tensor] = None,
neg_t5_attn_mask: Optional[torch.Tensor] = None,
cfg_scale: Optional[float] = None,
):
logger.info(f"num_steps: {num_steps}")
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)

# denoise initial noise
if accelerator:
with accelerator.autocast(), torch.no_grad():
x = denoise(
model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask
model,
img,
img_ids,
t5_out,
txt_ids,
l_pooled,
timesteps,
guidance,
t5_attn_mask,
neg_t5_out,
neg_l_pooled,
neg_t5_attn_mask,
cfg_scale,
)
else:
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
x = denoise(
model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask
model,
img,
img_ids,
t5_out,
txt_ids,
l_pooled,
timesteps,
guidance,
t5_attn_mask,
neg_t5_out,
neg_l_pooled,
neg_t5_attn_mask,
cfg_scale,
)

return x
Expand All @@ -135,6 +199,8 @@ def generate_image(
image_height: int,
steps: Optional[int],
guidance: float,
negative_prompt: Optional[str],
cfg_scale: float,
):
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
logger.info(f"Seed: {seed}")
Expand Down Expand Up @@ -162,65 +228,73 @@ def generate_image(
# txt2img only needs img_ids
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)

# prepare fp8 models
if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
clip_l.to(clip_l_dtype) # fp8
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
clip_l.fp8_prepared = True

if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")

def prepare_fp8(text_encoder, target_dtype):
def forward_hook(module):
def forward(hidden_states):
hidden_gelu = module.act(module.wi_0(hidden_states))
hidden_linear = module.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = module.dropout(hidden_states)

hidden_states = module.wo(hidden_states)
return hidden_states

return forward

for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)

t5xxl.to(t5xxl_dtype)
prepare_fp8(t5xxl.encoder, torch.bfloat16)
t5xxl.fp8_prepared = True

# prepare embeddings
logger.info("Encoding prompts...")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
clip_l = clip_l.to(device)
t5xxl = t5xxl.to(device)
with torch.no_grad():
if is_fp8(clip_l_dtype):
param_itr = clip_l.parameters()
param_itr.__next__() # skip first
param_2nd = param_itr.__next__()
if param_2nd.dtype != clip_l_dtype:
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
clip_l.to(clip_l_dtype) # fp8
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)

with accelerator.autocast():
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)

if is_fp8(t5xxl_dtype):
if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"):
logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")

def prepare_fp8(text_encoder, target_dtype):
def forward_hook(module):
def forward(hidden_states):
hidden_gelu = module.act(module.wi_0(hidden_states))
hidden_linear = module.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = module.dropout(hidden_states)

hidden_states = module.wo(hidden_states)
return hidden_states

return forward

for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)

text_encoder.fp8_prepared = True

t5xxl.to(t5xxl_dtype)
prepare_fp8(t5xxl.encoder, torch.bfloat16)

with accelerator.autocast():
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
else:
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
def encode(prpt: str):
tokens_and_masks = tokenize_strategy.tokenize(prpt)
with torch.no_grad():
if is_fp8(clip_l_dtype):
with accelerator.autocast():
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
else:
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)

if is_fp8(t5xxl_dtype):
with accelerator.autocast():
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
else:
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
return l_pooled, t5_out, txt_ids, t5_attn_mask

l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
if negative_prompt:
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
else:
neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None

# NaN check
if torch.isnan(l_pooled).any():
Expand All @@ -244,7 +318,23 @@ def forward(hidden_states):
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None

x = do_sample(
accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, t5_attn_mask, is_schnell, device, flux_dtype
accelerator,
model,
noise,
img_ids,
l_pooled,
t5_out,
txt_ids,
steps,
guidance,
t5_attn_mask,
is_schnell,
device,
flux_dtype,
neg_l_pooled,
neg_t5_out,
neg_t5_attn_mask,
cfg_scale,
)
if args.offload:
model = model.cpu()
Expand Down Expand Up @@ -307,6 +397,8 @@ def forward(hidden_states):
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
parser.add_argument("--guidance", type=float, default=3.5)
parser.add_argument("--negative_prompt", type=str, default=None)
parser.add_argument("--cfg_scale", type=float, default=1.0)
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
parser.add_argument(
"--lora_weights",
Expand Down Expand Up @@ -403,19 +495,34 @@ def is_fp8(dt):
lora_model.to(device)

lora_models.append(lora_model)

if not args.interactive:
generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance)
generate_image(
model,
clip_l,
t5xxl,
ae,
args.prompt,
args.seed,
args.width,
args.height,
args.steps,
args.guidance,
args.negative_prompt,
args.cfg_scale,
)
else:
# loop for interactive
width = target_width
height = target_height
steps = None
guidance = args.guidance
cfg_scale = args.cfg_scale

while True:
print(
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
" --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
)
prompt = input()
if prompt == "":
Expand All @@ -425,26 +532,36 @@ def is_fp8(dt):
options = prompt.split("--")
prompt = options[0].strip()
seed = None
negative_prompt = None
for opt in options[1:]:
opt = opt.strip()
if opt.startswith("w"):
width = int(opt[1:].strip())
elif opt.startswith("h"):
height = int(opt[1:].strip())
elif opt.startswith("s"):
steps = int(opt[1:].strip())
elif opt.startswith("d"):
seed = int(opt[1:].strip())
elif opt.startswith("g"):
guidance = float(opt[1:].strip())
elif opt.startswith("m"):
mutipliers = opt[1:].strip().split(",")
if len(mutipliers) != len(lora_models):
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
continue
for i, lora_model in enumerate(lora_models):
lora_model.set_multiplier(float(mutipliers[i]))

generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance)
try:
opt = opt.strip()
if opt.startswith("w"):
width = int(opt[1:].strip())
elif opt.startswith("h"):
height = int(opt[1:].strip())
elif opt.startswith("s"):
steps = int(opt[1:].strip())
elif opt.startswith("d"):
seed = int(opt[1:].strip())
elif opt.startswith("g"):
guidance = float(opt[1:].strip())
elif opt.startswith("m"):
mutipliers = opt[1:].strip().split(",")
if len(mutipliers) != len(lora_models):
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
continue
for i, lora_model in enumerate(lora_models):
lora_model.set_multiplier(float(mutipliers[i]))
elif opt.startswith("n"):
negative_prompt = opt[1:].strip()
if negative_prompt == "-":
negative_prompt = ""
elif opt.startswith("c"):
cfg_scale = float(opt[1:].strip())
except ValueError as e:
logger.error(f"Invalid option: {opt}, {e}")

generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)

logger.info("Done!")

0 comments on commit d29af14

Please sign in to comment.