Skip to content

Commit

Permalink
[FIX] feed generator instead of seed to pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
markkua authored Apr 24, 2024
1 parent ecf8a9e commit 0a96f78
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
21 changes: 9 additions & 12 deletions marigold/marigold_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __call__(
match_input_res: bool = True,
resample_method: str = "bilinear",
batch_size: int = 0,
seed: Union[int, None] = None,
generator: Union[torch.Generator, None] = None,
color_map: str = "Spectral",
show_progress_bar: bool = True,
ensemble_kwargs: Dict = None,
Expand All @@ -146,8 +146,8 @@ def __call__(
batch_size (`int`, *optional*, defaults to `0`):
Inference batch size, no bigger than `num_ensemble`.
If set to 0, the script will automatically decide the proper batch size.
seed (`int`, *optional*, defaults to `None`)
Reproducibility seed.
generator (`torch.Generator`, *optional*, defaults to `None`)
Random generator for initial noise generation.
show_progress_bar (`bool`, *optional*, defaults to `True`):
Display a progress bar of diffusion denoising.
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
Expand Down Expand Up @@ -228,7 +228,7 @@ def __call__(
rgb_in=batched_img,
num_inference_steps=denoising_steps,
show_pbar=show_progress_bar,
seed=seed,
generator=generator,
)
depth_pred_ls.append(depth_pred_raw.detach())
depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
Expand Down Expand Up @@ -322,7 +322,7 @@ def single_infer(
self,
rgb_in: torch.Tensor,
num_inference_steps: int,
seed: Union[int, None],
generator: Union[torch.Generator, None],
show_pbar: bool,
) -> torch.Tensor:
"""
Expand All @@ -335,6 +335,8 @@ def single_infer(
Number of diffusion denoisign steps (DDIM) during inference.
show_pbar (`bool`):
Display a progress bar of diffusion denoising.
generator (`torch.Generator`)
Random generator for initial noise generation.
Returns:
`torch.Tensor`: Predicted depth map.
"""
Expand All @@ -349,16 +351,11 @@ def single_infer(
rgb_latent = self.encode_rgb(rgb_in)

# Initial depth map (noise)
if seed is None:
rand_num_generator = None
else:
rand_num_generator = torch.Generator(device=device)
rand_num_generator.manual_seed(seed)
depth_latent = torch.randn(
rgb_latent.shape,
device=device,
dtype=self.dtype,
generator=rand_num_generator,
generator=generator,
) # [B, 4, h, w]

# Batched empty text embedding
Expand Down Expand Up @@ -391,7 +388,7 @@ def single_infer(

# compute the previous noisy sample x_t -> x_t-1
depth_latent = self.scheduler.step(
noise_pred, t, depth_latent, generator=rand_num_generator
noise_pred, t, depth_latent, generator=generator
).prev_sample

depth = self.decode_depth(depth_latent)
Expand Down
9 changes: 8 additions & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,13 @@
# Read input image
input_image = Image.open(rgb_path)

# Random number generator
if seed is None:
generator = None
else:
generator = torch.Generator(device=device)
generator.manual_seed(seed)

# Predict depth
pipe_out = pipe(
input_image,
Expand All @@ -240,7 +247,7 @@
color_map=color_map,
show_progress_bar=True,
resample_method=resample_method,
seed=seed,
generator=generator,
)

depth_pred: np.ndarray = pipe_out.depth_np
Expand Down

0 comments on commit 0a96f78

Please sign in to comment.