Add pipeline variables #65

Merged · 6 commits · May 25, 2024
README.md (8 changes: 3 additions & 5 deletions)
@@ -125,12 +125,10 @@ Activate the environment again after restarting the terminal session.

### 🚀 Run inference with LCM (faster)

The [LCM checkpoint](https://huggingface.co/prs-eth/marigold-lcm-v1-0) is distilled from our original checkpoint for faster inference (by reducing the number of denoising steps). Inference can use as few as 1 to 4 steps:
The [LCM checkpoint](https://huggingface.co/prs-eth/marigold-lcm-v1-0) is distilled from our original checkpoint for faster inference (by reducing the number of denoising steps). Inference can use as few as 1 (default) to 4 steps. Run with the default LCM setting:

```bash
python run.py \
--denoise_steps 4 \
--ensemble_size 5 \
--input_rgb_dir input/in-the-wild_example \
--output_dir output/in-the-wild_example_lcm
```
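
The same defaults apply when calling the pipeline from Python. A minimal sketch, assuming the `MarigoldPipeline` export and the `depth_np` output field of this repo; the input path is illustrative:

```python
import torch
from PIL import Image

from marigold import MarigoldPipeline

pipe = MarigoldPipeline.from_pretrained("prs-eth/marigold-lcm-v1-0")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

image = Image.open("input/in-the-wild_example/example_0.jpg")  # illustrative path
# denoising_steps and processing_res are omitted (None), so the
# model-config defaults introduced in this PR are used.
output = pipe(image)
depth = output.depth_np  # (H, W) numpy array, values in [0, 1]
```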
@@ -156,11 +154,11 @@ The default settings are optimized for the best result. However, the behavior of

- Trade-offs between **accuracy** and **speed** (for both options, larger values result in better accuracy at the cost of slower inference):
- `--ensemble_size`: Number of inference passes in the ensemble. For LCM, `ensemble_size` matters more than `denoise_steps`. Default: ~~10~~ 5 (for LCM).
- `--denoise_steps`: Number of denoising steps of each inference pass. For the original (DDIM) version, it's recommended to use 10-50 steps, while for LCM 1-4 steps. Default: ~~10~~ 4 (for LCM).
- `--denoise_steps`: Number of denoising steps of each inference pass. For the original (DDIM) version, 10-50 steps are recommended; for LCM, 1-4 steps. When unassigned (`None`), the default is read from the model config (see the sketch after this list). Default: ~~10~~ ~~4 (for LCM)~~ `None`.

- By default, the inference script resizes input images to the *processing resolution*, and then resizes the prediction back to the original resolution. This gives the best quality, as Stable Diffusion, from which Marigold is derived, performs best at 768x768 resolution.

- `--processing_res`: the processing resolution; set 0 to process the input resolution directly. Default: 768.
- `--processing_res`: the processing resolution; set to 0 to process at the input resolution directly. When unassigned (`None`), the default is read from the model config. Default: ~~768~~ `None`.
- `--output_processing_res`: produce output at the processing resolution instead of upsampling it to the input resolution. Default: False.
- `--resample_method`: resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`. Default: `bilinear`.
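
To make the `None` fallback concrete, here is the logic this PR adds to `MarigoldPipeline.__call__` (shown in full later in this diff), reduced to its essence:

```python
# Inside MarigoldPipeline.__call__ (simplified):
if denoising_steps is None:
    denoising_steps = self.default_denoising_steps        # from the model config
if processing_res is None:
    processing_res = self.default_processing_resolution   # from the model config
```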

infer.py (5 changes: 4 additions & 1 deletion)
@@ -1,4 +1,4 @@
# Last modified: 2024-04-15
# Last modified: 2024-05-24
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -213,6 +213,9 @@ def check_directory(directory):
logging.debug("run without xformers")

pipe = pipe.to(device)
logging.info(
f"scale_invariant: {pipe.scale_invariant}, shift_invariant: {pipe.shift_invariant}"
)

# -------------------- Inference and saving --------------------
with torch.no_grad():
marigold/marigold_pipeline.py (109 changes: 79 additions & 30 deletions)
@@ -1,4 +1,5 @@
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
# Last modified: 2024-05-24
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,7 +20,7 @@


import logging
from typing import Dict, Union
from typing import Dict, Optional, Union

import numpy as np
import torch
@@ -33,13 +34,13 @@
from diffusers.utils import BaseOutput
from PIL import Image
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms.functional import resize, pil_to_tensor
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import pil_to_tensor, resize
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from .util.batchsize import find_batch_size
from .util.ensemble import ensemble_depths
from .util.ensemble import ensemble_depth
from .util.image_util import (
chw2hwc,
colorize_depth_maps,
@@ -85,6 +86,25 @@ class MarigoldPipeline(DiffusionPipeline):
Text-encoder, for empty text embedding.
tokenizer (`CLIPTokenizer`):
CLIP tokenizer.
scale_invariant (`bool`, *optional*):
A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
the model config. When used together with the `shift_invariant=True` flag, the model is also called
"affine-invariant". NB: overriding this value is not supported.
shift_invariant (`bool`, *optional*):
A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
the model config. When used together with the `scale_invariant=True` flag, the model is also called
"affine-invariant". NB: overriding this value is not supported.
default_denoising_steps (`int`, *optional*):
The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
quality with the given model. This value must be set in the model config. When the pipeline is called
without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
reasonable results with various model flavors compatible with the pipeline, such as those relying on very
short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
default_processing_resolution (`int`, *optional*):
The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
default value is used. This is required to ensure reasonable results with various model flavors trained
with varying optimal processing resolution values.
"""

rgb_latent_scale_factor = 0.18215
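
Note: because these properties are passed to `register_to_config` (see the `__init__` hunk below), they are persisted in the checkpoint's `model_index.json` and survive `save_pretrained`/`from_pretrained` round trips. A sketch of inspecting them; the printed values are checkpoint-dependent and shown only as examples:

```python
from marigold import MarigoldPipeline

pipe = MarigoldPipeline.from_pretrained("prs-eth/marigold-lcm-v1-0")
print(pipe.scale_invariant)                # True for affine-invariant models
print(pipe.shift_invariant)                # True for affine-invariant models
print(pipe.default_denoising_steps)        # checkpoint-dependent, e.g. 1-4 for LCM
print(pipe.default_processing_resolution)  # checkpoint-dependent, e.g. 768
```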
@@ -97,26 +117,40 @@ def __init__(
scheduler: Union[DDIMScheduler, LCMScheduler],
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
scale_invariant: Optional[bool] = True,
shift_invariant: Optional[bool] = True,
default_denoising_steps: Optional[int] = None,
default_processing_resolution: Optional[int] = None,
):
super().__init__()

self.register_modules(
unet=unet,
vae=vae,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
self.register_to_config(
scale_invariant=scale_invariant,
shift_invariant=shift_invariant,
default_denoising_steps=default_denoising_steps,
default_processing_resolution=default_processing_resolution,
)

self.scale_invariant = scale_invariant
self.shift_invariant = shift_invariant
self.default_denoising_steps = default_denoising_steps
self.default_processing_resolution = default_processing_resolution

self.empty_text_embed = None

@torch.no_grad()
def __call__(
self,
input_image: Union[Image.Image, torch.Tensor],
denoising_steps: int = 10,
ensemble_size: int = 10,
processing_res: int = 768,
denoising_steps: Optional[int] = None,
ensemble_size: int = 5,
processing_res: Optional[int] = None,
match_input_res: bool = True,
resample_method: str = "bilinear",
batch_size: int = 0,
@@ -131,18 +165,21 @@ def __call__(
Args:
input_image (`Image`):
Input RGB (or gray-scale) image.
processing_res (`int`, *optional*, defaults to `768`):
Maximum resolution of processing.
If set to 0: will not resize at all.
denoising_steps (`int`, *optional*, defaults to `None`):
Number of denoising diffusion steps during inference. The default value `None` results in automatic
selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
for Marigold-LCM models.
ensemble_size (`int`, *optional*, defaults to `5`):
Number of predictions to be ensembled.
processing_res (`int`, *optional*, defaults to `None`):
Effective processing resolution. When set to `0`, processes at the original image resolution. This
produces crisper predictions, but may also lead to the overall loss of global context. The default
value `None` resolves to the optimal value from the model config.
match_input_res (`bool`, *optional*, defaults to `True`):
Resize depth prediction to match input resolution.
Only valid if `processing_res` > 0.
resample_method (`str`, *optional*, defaults to `bilinear`):
Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic`, or `nearest`.
denoising_steps (`int`, *optional*, defaults to `10`):
Number of diffusion denoising steps (DDIM) during inference.
ensemble_size (`int`, *optional*, defaults to `10`):
Number of predictions to be ensembled.
batch_size (`int`, *optional*, defaults to `0`):
Inference batch size, no bigger than `ensemble_size`.
If set to 0, the batch size is decided automatically.
@@ -152,6 +189,10 @@ def __call__(
Display a progress bar of diffusion denoising.
color_map (`str`, *optional*, defaults to `"Spectral"`):
Colormap used to colorize the depth map; pass `None` to skip colorized depth map generation.
scale_invariant (`bool`, *optional*, defaults to `True`):
Flag of scale-invariant prediction; if True, the scale is adjusted from the raw prediction.
shift_invariant (`bool`, *optional*, defaults to `True`):
Flag of shift-invariant prediction; if True, the shift is adjusted from the raw prediction; if False, the near plane is fixed at 0 m.
ensemble_kwargs (`dict`, *optional*, defaults to `None`):
Arguments for detailed ensembling settings.
Returns:
@@ -161,6 +202,12 @@ def __call__(
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
coming from ensembling. None if `ensemble_size = 1`
"""
# Model-specific optimal default values leading to fast and reasonable results.
if denoising_steps is None:
denoising_steps = self.default_denoising_steps
if processing_res is None:
processing_res = self.default_processing_resolution

assert processing_res >= 0
assert ensemble_size >= 1
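
Note on the invariance flags used below: with both `scale_invariant` and `shift_invariant` set to `True`, raw predictions relate to true depth only up to an unknown scale and shift (d ≈ s · d_raw + t), so the ensemble must align its members before aggregating. A simplified least-squares sketch of such an alignment; the actual `ensemble_depth` in `util/ensemble.py` optimizes over all members jointly and differs in detail:

```python
import torch

def align_affine(pred: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    # Fit scale s and shift t minimizing ||s * pred + t - ref||^2,
    # then map pred into the reference's affine frame.
    x, y = pred.flatten(), ref.flatten()
    A = torch.stack([x, torch.ones_like(x)], dim=1)  # [N, 2]
    s, t = torch.linalg.lstsq(A, y.unsqueeze(1)).solution.squeeze(1)
    return s * pred + t
```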

@@ -175,14 +222,15 @@ def __call__(
input_image = input_image.convert("RGB")
# convert to torch tensor [H, W, rgb] -> [rgb, H, W]
rgb = pil_to_tensor(input_image)
rgb = rgb.unsqueeze(0) # [1, rgb, H, W]
elif isinstance(input_image, torch.Tensor):
rgb = input_image.squeeze()
rgb = input_image
else:
raise TypeError(f"Unknown input type: {type(input_image) = }")
input_size = rgb.shape
assert (
3 == rgb.dim() and 3 == input_size[0]
), f"Wrong input shape {input_size}, expected [rgb, H, W]"
4 == rgb.dim() and 3 == input_size[-3]
), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"

# Resize image
if processing_res > 0:
@@ -199,7 +247,7 @@ def __call__(

# ----------------- Predicting depth -----------------
# Batch repeated input image
duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
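# NB: expand() returns a broadcast view of the single normalized image, avoiding
# the ensemble_size-fold memory copy that the previous torch.stack() made.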
single_rgb_dataset = TensorDataset(duplicated_rgb)
if batch_size > 0:
_bs = batch_size
@@ -231,35 +279,36 @@ def __call__(
generator=generator,
)
depth_pred_ls.append(depth_pred_raw.detach())
depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
depth_preds = torch.concat(depth_pred_ls, dim=0)
torch.cuda.empty_cache() # clear vram cache for ensembling

# ----------------- Test-time ensembling -----------------
if ensemble_size > 1:
depth_pred, pred_uncert = ensemble_depths(
depth_preds, **(ensemble_kwargs or {})
depth_pred, pred_uncert = ensemble_depth(
depth_preds,
scale_invariant=self.scale_invariant,
shift_invariant=self.shift_invariant,
max_res=50,
**(ensemble_kwargs or {}),
)
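# max_res=50 presumably caps the resolution at which the ensemble alignment is
# optimized, keeping test-time ensembling cheap (assumption; see util/ensemble.py
# for the exact semantics).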
else:
depth_pred = depth_preds
pred_uncert = None

# ----------------- Post processing -----------------
# Scale prediction to [0, 1]
min_d = torch.min(depth_pred)
max_d = torch.max(depth_pred)
depth_pred = (depth_pred - min_d) / (max_d - min_d)

# Resize back to original resolution
if match_input_res:
depth_pred = resize(
depth_pred.unsqueeze(0),
input_size[1:],
depth_pred,
input_size[-2:],
interpolation=resample_method,
antialias=True,
).squeeze()
)

# Convert to numpy
depth_pred = depth_pred.squeeze()
depth_pred = depth_pred.cpu().numpy()
if pred_uncert is not None:
pred_uncert = pred_uncert.squeeze().cpu().numpy()

# Clip output range
depth_pred = depth_pred.clip(0, 1)
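
Taken together, the new defaults and affine-invariant ensembling support the following end-to-end usage. A sketch; the input path is illustrative, and the `depth_np`/`uncertainty` fields are assumed from this repo's `MarigoldDepthOutput`:

```python
from PIL import Image

from marigold import MarigoldPipeline

pipe = MarigoldPipeline.from_pretrained("prs-eth/marigold-lcm-v1-0").to("cuda")
image = Image.open("input/in-the-wild_example/example_0.jpg")  # illustrative path

out = pipe(image, ensemble_size=10)  # steps/resolution fall back to the config
print(out.depth_np.shape)            # (H, W), values in [0, 1]
if out.uncertainty is not None:      # None when ensemble_size == 1
    print("mean MAD:", float(out.uncertainty.mean()))
```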