From de12e177f0ad20cf04c28da9fe52777e3d790547 Mon Sep 17 00:00:00 2001
From: Terry Kong
Date: Mon, 26 Aug 2024 14:19:27 -0700
Subject: [PATCH 1/2] Fixes Imagen sampling example and updates container (#868)

1. Fix the imagen sampling loop when prompt_ct is a multiple of `batch_size // gen_per_prompt`
2. Add a comment explaining what the sampling-script invocation expects, given the implicit checkpoint-dir requirements and quoting
3. Add 2B base model generation gin configs
4. Parametrize the imagen sampling scripts
5. Update the imagen image with the fix in (1); built as follows:

```
docker buildx build --push -t ghcr.io/nvidia/t5x:imagen-2023-10-02.v3 ./JAX-Toolbox -f - <
```
---
 README.md                                    |   2 +-
 .../diffusion/common/set_gpu_xla_flags.sh    |   3 +-
 rosetta/rosetta/projects/imagen/README.md    |  24 +-
 .../imagen/configs/imagen_1024_sample_2b.gin |  78 +++++++
 .../imagen/configs/imagen_256_sample_2b.gin  | 220 ++++++++++++++++++
 .../rosetta/projects/imagen/imagen_pipe.py   |  15 +-
 .../scripts/example_slurm_inf_train.sub      |   4 +-
 .../imagen/scripts/sample_imagen_1024.sh     |  12 +-
 .../imagen/scripts/sample_imagen_256.sh      |  12 +-
 9 files changed, 347 insertions(+), 23 deletions(-)
 create mode 100644 rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin
 create mode 100644 rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin

diff --git a/README.md b/README.md
index fe83553df..66d9b2a4e 100644
--- a/README.md
+++ b/README.md
@@ -277,7 +277,7 @@ We currently support the following frameworks and models. More details about eac
 | :--- | :---: | :---: | :---: |
 | [Paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` |
 | [T5X](./rosetta/rosetta/projects/t5x) | T5, ViT | pre-training, fine-tuning | `ghcr.io/nvidia/jax:t5x` |
-| [T5X](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02` |
+| [T5X](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` |
 | [Big Vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` |
 | levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter` |
 | maxtext| LLaMA, Gemma | pretraining | `ghcr.io/nvidia/jax:maxtext` |

diff --git a/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh b/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh
index fce89244e..a5eaf9aa0 100644
--- a/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh
+++ b/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh
@@ -1 +1,2 @@
-export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_disable_async_collectives=allreduce,allgather,reducescatter,collectivebroadcast,alltoall,collectivepermute ${XLA_FLAGS}"
+# These XLA flags are meant to be used with the JAX version in the imagen container
+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_async_all_gather=false --xla_gpu_enable_async_reduce_scatter=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_async_all_reduce=false ${XLA_FLAGS}"

diff --git a/rosetta/rosetta/projects/imagen/README.md b/rosetta/rosetta/projects/imagen/README.md
index 4959a4118..136f913d6 100644
--- a/rosetta/rosetta/projects/imagen/README.md
+++ b/rosetta/rosetta/projects/imagen/README.md
@@ -17,7 +17,7 @@ For maximum flexibility and low disk requirements, this repo supports a **distri
 We provide [scripts](scripts) to run [interactively](scripts/singlenode_inf_train.sh) or on [SLURM](scripts/example_slurm_inf_train.sub).
 
 ### Container
-We provide a fully built and ready-to-use container here: `ghcr.io/nvidia/t5x:imagen-2023-10-02`.
+We provide a fully built and ready-to-use container here: `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3`.
 
 We do not currently have custom-built container workflows, but are actively working on supporting this, stay tuned for updates! Imagen will also be available in our T5x container in future releases.

@@ -37,7 +37,7 @@ You will need to acquire the LLM checkpoint for T5 (for multimodal training) fro
 **Note**: this should only be done with singlenode jobs
 
 ```bash
-CONTAINER=ghcr.io/nvidia/t5x:imagen-2023-10-02
+CONTAINER=ghcr.io/nvidia/t5x:imagen-2023-10-02.v3
 docker run --rm --gpus=all -it --net=host --ipc=host -v ${PWD}:/opt/rosetta -v ${DATASET_PATH}:/mnt/datasets --privileged $CONTAINER bash
 ```

@@ -99,15 +99,27 @@ sbatch -N 14 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \
 You can find example sampling scripts that use the 500M base model and EfficientUnet SR models in [scripts](scripts). Prompts should be specified as in [example](../diffusion/tests/custom_eval_prompts/custom_eval_prompts.txt)
 
 #### Sampling 256x256 images
-Defaults to [imagen_256_sample.gin](configs/imagen_256_sample.gin) config (can be adjusted in script)
+Defaults to [imagen_256_sample.gin](configs/imagen_256_sample.gin) config (can be adjusted in script, e.g., [imagen_256_sample_2b.gin](configs/imagen_256_sample_2b.gin)).
 ```
-CUDA_VISIBLE_DEVICES= CFG=5.0 BASE_PATH= SR1_PATH= PROMPT_TEXT_FILES= ./rosetta/projects/imagen/scripts/sample_imagen_256.sh
+CUDA_VISIBLE_DEVICES= CFG=5.0 GLOBAL_BATCH_SIZE= GEN_PER_PROMPT=1 BASE_PATH= SR1_PATH= PROMPT_TEXT_FILE= ./rosetta/projects/imagen/scripts/sample_imagen_256.sh
+```
+
+Here is an example:
+```
+# Note:
+# - The quoting, with single quotes wrapping double quotes, is necessary.
+# - BASE_PATH/SR1_PATH are checkpoint dirs and are expected to contain a `checkpoint` file, i.e., $BASE_PATH/checkpoint should exist.
+# - GLOBAL_BATCH_SIZE should be set with the number of GPUs in mind; for instance, GLOBAL_BATCH_SIZE >= the number of GPUs
+#   ensures that at least one example is sent to each GPU.
+# - Currently, the number of lines in PROMPT_TEXT_FILE must be divisible by the number of GPUs.
+#   The easiest way to ensure that is to pad the file with dummy prompts until the line count is divisible (a padding sketch appears at the end of this patch).
+CUDA_VISIBLE_DEVICES=0,1 CFG=5.0 GLOBAL_BATCH_SIZE=4 GEN_PER_PROMPT=1 BASE_PATH='"/mnt/imagen_ckpt/checkpoint_585000"' SR1_PATH='"/mnt/sr1_ckpt/checkpoint_5000"' PROMPT_TEXT_FILE='"./rosetta/projects/diffusion/tests/custom_eval_prompts/custom_eval_prompts.txt"' ./rosetta/projects/imagen/scripts/sample_imagen_256.sh
 ```
 
 #### Sampling 1024x1024 images
-Defaults to [imagen_1024_sample.gin](configs/imagen_1024_sample.gin) config (can be adjusted in script).
+Defaults to [imagen_1024_sample.gin](configs/imagen_1024_sample.gin) config (can be adjusted in script, e.g., [imagen_1024_sample_2b.gin](configs/imagen_1024_sample_2b.gin)).
 ```
-CUDA_VISIBLE_DEVICES= CFG=5.0 BASE_PATH= SR1_PATH= SR2_PATH= PROMPT_TEXT_FILES= ./rosetta/projects/imagen/scripts/sample_imagen_1024.sh
+CUDA_VISIBLE_DEVICES= CFG=5.0 GLOBAL_BATCH_SIZE= GEN_PER_PROMPT=1 BASE_PATH= SR1_PATH= SR2_PATH= PROMPT_TEXT_FILE= ./rosetta/projects/imagen/scripts/sample_imagen_1024.sh
 ```

diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin b/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin
new file mode 100644
index 000000000..101e2773e
--- /dev/null
+++ b/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin
@@ -0,0 +1,78 @@
+# Imagen Sampling pipeline
+include "rosetta/projects/imagen/configs/imagen_256_sample_2b.gin"
+
+from __gin__ import dynamic_registration
+import __main__ as sample_script
+from t5x import gin_utils
+from t5x import utils
+from t5x import partitioning
+
+from rosetta.projects.imagen import network_sr
+from rosetta.projects.diffusion import models
+from rosetta.projects.diffusion import denoisers
+from rosetta.projects.diffusion import samplers
+from rosetta.projects.diffusion import losses
+from rosetta.projects.diffusion import augmentations
+
+#---------------- SR1024 Model -------------------------------------------------
+
+# ------------------- Model ----------------------------------------------------
+SR1024 = @sr1024/models.DenoisingDiffusionModel()
+SIGMA_DATA = 0.5
+sr1024/models.DenoisingDiffusionModel:
+  denoiser= @sr1024/denoisers.EDMTextConditionedSuperResDenoiser()
+  diffusion_loss= None
+  diffusion_sampler= @sr1024/samplers.EDMSampler()
+  optimizer_def = None
+
+# |--- Denoiser
+sr1024/denoisers.EDMTextConditionedSuperResDenoiser:
+  raw_model= @sr1024/network_sr.ImagenEfficientUNet()
+
+sr1024/samplers.EDMSampler:
+  dim_noise_scalar = 4.
+
+# ------------------- Network specification ------------------------------------
+sr1024/network_sr.ImagenEfficientUNet.config = @sr1024/network_sr.ImagenEfficientUNetConfig()
+sr1024/network_sr.ImagenEfficientUNetConfig:
+  dtype = %DTYPE
+  model_dim = 128
+  cond_dim = 1024
+  resblocks_per_level = (2, 4, 8, 8, 8)
+  width_multipliers = (1, 2, 4, 6, 6)
+  attn_resolutions_divs = {16: 'cross'}
+  mha_head_dim = 64
+  attn_heads = 8
+  resblock_activation = 'silu'
+  resblock_zero_out = True
+  resblock_scale_skip = True
+  dropout_rate = %DROPOUT_RATE
+  cond_strategy = 'shift_scale'
+  norm_32 = True
+  scale_attn_logits = True
+  float32_attention_logits=False
+  text_conditionable = True
+
+sr1024/samplers.CFGSamplingConfig:
+  num_steps=30
+  cf_guidance_weight=0.0
+  cf_guidance_nulls={'text': None, 'text_mask': None}
+
+sr1024/partitioning.PjitPartitioner:
+  num_partitions = 1
+  logical_axis_rules = @partitioning.standard_logical_axis_rules()
+
+sr1024/utils.RestoreCheckpointConfig:
+  mode = 'specific'
+  dtype = 'bfloat16'
+
+sr1024/sample_script.DiffusionModelSetupData:
+  model = %SR1024
+  sampling_cfg = @sr1024/samplers.CFGSamplingConfig()
+  restore_checkpoint_cfg = @sr1024/utils.RestoreCheckpointConfig()
+  partitioner = @partitioning.PjitPartitioner()
+  input_shapes = {'samples': (1, 1024, 1024, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN, 'low_res_images': (1, 256, 256, 3)}
+  input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int', 'low_res_images': 'float32'}
+
+sample_script.sample:
+  sr1024_setupdata = @sr1024/sample_script.DiffusionModelSetupData()

diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin b/rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin
new file mode 100644
index 000000000..13272ebbe
--- /dev/null
+++ b/rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin
@@ -0,0 +1,220 @@
+# Imagen Sampling pipeline
+from __gin__ import dynamic_registration
+
+import __main__ as sample_script
+from t5x import gin_utils
+from t5x import utils
+from t5x import partitioning
+
+SAVE_DIR='generations'
+PROMPT_TEXT_FILE='custom_text.txt'
+GLOBAL_BATCH_SIZE=32
+MAX_GENERATE=50000000
+GEN_PER_PROMPT=2
+NOISE_COND_AUG=0.002
+
+TXT_SHAPE=(1, 128, 4096)  # T5 xxl, seqlen x embed_dim
+TXT_SEQLEN=(1, 128, )
+TXT_SEQLEN_SINGLE=128
+DTYPE='bfloat16'
+DROPOUT_RATE=0
+RESUME_FROM=0  # Sampling count to resume from
+
+#---------------- Base Model -------------------------------------------------
+from rosetta.projects.imagen import network
+from rosetta.projects.imagen import network_sr
+from rosetta.projects.diffusion import models
+from rosetta.projects.diffusion import denoisers
+from rosetta.projects.diffusion import samplers
+from rosetta.projects.diffusion import losses
+from rosetta.projects.diffusion import augmentations
+
+# ------------------- Model ----------------------------------------------------
+BASE = @base_model/models.DenoisingDiffusionModel()
+base_model/models.DenoisingDiffusionModel:
+  denoiser= @base_model/denoisers.EDMTextConditionedDenoiser()
+  diffusion_loss = None
+  diffusion_sampler= @base_model/samplers.EDMSampler()
+  optimizer_def = None
+
+# |--- Denoiser
+base_model/denoisers.EDMTextConditionedDenoiser:
+  raw_model= @base_model/network.ImagenUNet()
+
+# ------------------- Network specification ------------------------------------
+base_model/network.ImagenUNet.config = @base_model/network.DiffusionConfig()
+base_model/network.DiffusionConfig:
+  dtype = %DTYPE
+  model_dim = 512
+  attn_cond_dim = 2048
+  cond_dim = 2048
+  resblocks_per_level = 3
+  width_multipliers = (1, 2, 3, 4)
+  attn_resolutions = (32, 16, 8)
+  mha_head_dim = 64
+  attn_heads = 8
+  resblock_activation = 'silu'
+  dropout_rate = %DROPOUT_RATE
+  upsample_mode = 'shuffle'
+  downsample_mode = 'shuffle'
+  spatial_skip = False
+  cond_strategy = 'shift_scale'
+  norm_32 = True
+  scale_attn_logits = True
+  float32_attention_logits = False
+  text_conditionable = True
+
+
+BASE_SAMPLING_CONFIG = @base_model/samplers.CFGSamplingConfig()
+base_model/samplers.CFGSamplingConfig:
+  num_steps=50
+  cf_guidance_weight=5.00
+  cf_guidance_nulls=None
+
+base_model/partitioning.PjitPartitioner:
+  num_partitions = 1
+  logical_axis_rules = @partitioning.standard_logical_axis_rules()
+
+base_model/utils.RestoreCheckpointConfig:
+  mode = 'specific'
+  dtype = 'bfloat16'
+
+base_model/sample_script.DiffusionModelSetupData:
+  model = %BASE
+  sampling_cfg = @base_model/samplers.CFGSamplingConfig()
+  restore_checkpoint_cfg = @base_model/utils.RestoreCheckpointConfig()
+  partitioner = @partitioning.PjitPartitioner()
+  input_shapes = {'samples': (1, 64, 64, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN}
+  input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int'}
+
+#---------------- SR256 Model -------------------------------------------------
+
+# ------------------- Model ----------------------------------------------------
+SR256 = @sr256/models.DenoisingDiffusionModel()
+SIGMA_DATA = 0.5
+sr256/models.DenoisingDiffusionModel:
+  denoiser= @sr256/denoisers.EDMTextConditionedSuperResDenoiser()
+  diffusion_loss= None
+  diffusion_sampler= @sr256/samplers.EDMSampler()
+  optimizer_def = None
+
+# |--- Denoiser
+sr256/denoisers.EDMTextConditionedSuperResDenoiser:
+  raw_model= @sr256/network_sr.ImagenEfficientUNet()
+
+sr256/samplers.EDMSampler:
+  dim_noise_scalar = 4.
+
+# ------------------- Network specification ------------------------------------
+sr256/network_sr.ImagenEfficientUNet.config = @sr256/network_sr.ImagenEfficientUNetConfig()
+sr256/network_sr.ImagenEfficientUNetConfig:
+  dtype = %DTYPE
+  model_dim = 128
+  cond_dim = 512
+  attn_cond_dim = 1024
+  resblocks_per_level = (2, 4, 8, 8, 2)
+  width_multipliers = (1, 2, 4, 8, 8)
+  attn_resolutions_divs = {8: 'fused', 16: 'fused'}
+  mha_head_dim = 64
+  attn_heads = 8
+  resblock_activation = 'silu'
+  resblock_zero_out = True
+  resblock_scale_skip = True
+  dropout_rate = %DROPOUT_RATE
+  cond_strategy = 'shift_scale'
+  norm_32 = True
+  scale_attn_logits = True
+  float32_attention_logits=False
+  text_conditionable = True
+
+sr256/samplers.CFGSamplingConfig:
+  num_steps=50
+  cf_guidance_weight=4
+  cf_guidance_nulls={'text': None, 'text_mask': None}
+
+sr256/partitioning.PjitPartitioner:
+  num_partitions = 1
+  logical_axis_rules = @partitioning.standard_logical_axis_rules()
+
+sr256/utils.RestoreCheckpointConfig:
+  mode = 'specific'
+  dtype = 'bfloat16'
+
+sr256/sample_script.DiffusionModelSetupData:
+  model = %SR256
+  sampling_cfg = @sr256/samplers.CFGSamplingConfig()
+  restore_checkpoint_cfg = @sr256/utils.RestoreCheckpointConfig()
+  partitioner = @partitioning.PjitPartitioner()
+  input_shapes = {'samples': (1, 256, 256, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN, 'low_res_images': (1, 64, 64, 3)}
+  input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int', 'low_res_images': 'float32'}
+
+#---------------- Text Model -------------------------------------------------
+import seqio
+from rosetta.projects.inference_serving.t5 import network as t5x_network
+from rosetta.projects.inference_serving.t5 import models as t5x_models
+
+# =====================================
+# === T5 Encoder only configuration ===
+# =====================================
+T5_CHECKPOINT_PATH = "/opt/rosetta/rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl"
+BATCH_SIZE = 256  # Will be overridden
+SEQ_LEN = 128  # MAX seqlen
+
+# Vocabulary
+VOCABULARY = @seqio.SentencePieceVocabulary()
+seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"
+TASK_FEATURE_LENGTHS = None  # auto-computes the maximum features length to use.
+
+# --------------- Model ------------------
+TEXT_ENC = @text_enc/t5x_models.EncoderOnlyModel()
+text_enc/t5x_models.EncoderOnlyModel:
+  module = @t5x_network.TransformerEncoderOnly()
+  input_vocabulary = %VOCABULARY
+  output_vocabulary = %VOCABULARY
+  optimizer_def = None
+  z_loss = 0.0001
+  label_smoothing = 0.0
+  loss_normalizing_factor = None
+
+# -------- Network specification ---------
+t5x_network.TransformerEncoderOnly.config = @t5x_network.T5Config()
+t5x_network.T5Config:
+  vocab_size = 32128  # vocab size rounded to a multiple of 128 for TPU efficiency
+  dtype = 'bfloat16'
+  emb_dim = 4096
+  num_heads = 64
+  num_encoder_layers = 24
+  num_decoder_layers = 0
+  head_dim = 64
+  mlp_dim = 10240
+  mlp_activations = ('gelu', 'linear')
+  dropout_rate = 0.0
+
+text_enc/partitioning.PjitPartitioner:
+  num_partitions = 1
+  logical_axis_rules = @partitioning.standard_logical_axis_rules()
+
+text_enc/utils.RestoreCheckpointConfig:
+  path = %T5_CHECKPOINT_PATH
+  mode = 'specific'
+  dtype = 'bfloat16'
+
+text_enc/sample_script.setup_text_enc:
+  model=%TEXT_ENC
+  restore_checkpoint_cfg=@text_enc/utils.RestoreCheckpointConfig()
+  partitioner=@text_enc/partitioning.PjitPartitioner()
+  batch_size=1
+  seq_len=%TXT_SEQLEN_SINGLE
+  vocab = %VOCABULARY
+
+sample_script.sample:
+  base_setupdata = @base_model/sample_script.DiffusionModelSetupData()
+  sr256_setupdata = @sr256/sample_script.DiffusionModelSetupData()
+  sr1024_setupdata = None
+  out_dir = %SAVE_DIR
+  gen_per_prompt = %GEN_PER_PROMPT
+  prompt_file = %PROMPT_TEXT_FILE
+  batch_size = %GLOBAL_BATCH_SIZE
+  max_images = %MAX_GENERATE
+  text_enc_infer = @text_enc/sample_script.setup_text_enc()
+  noise_conditioning_aug = %NOISE_COND_AUG
+  resume_from = %RESUME_FROM

diff --git a/rosetta/rosetta/projects/imagen/imagen_pipe.py b/rosetta/rosetta/projects/imagen/imagen_pipe.py
index fb96d7ffd..8995234e4 100644
--- a/rosetta/rosetta/projects/imagen/imagen_pipe.py
+++ b/rosetta/rosetta/projects/imagen/imagen_pipe.py
@@ -19,6 +19,7 @@ import functools
 from typing import Mapping, Any, Optional, Callable, Sequence
 import logging
+import time
 
 import numpy as np
 import jax
@@ -37,6 +38,8 @@ _DEFAULT_GIN_SEARCH_PATHS = [
     os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 ]
+# This prevents issues where filenames exceed the maximum length allowed on Unix
+MAX_FILENAME_LENGTH = 150
 
 @dataclasses.dataclass
 class DiffusionModelSetupData:
@@ -194,8 +197,9 @@ def sample(
     sampled_ctr = 0
     rng = jax.random.PRNGKey(0)
+    start_time = time.time()
     for start_idx in range(resume_from, max_images, batch_size // gen_per_prompt):
-        if start_idx > prompt_ct:
+        if start_idx >= prompt_ct:
             break
         prompt_batch = prompts[start_idx: start_idx + (batch_size // gen_per_prompt)] * gen_per_prompt
         rng, rng_base, rng_sr, rng_sr2, rng_aug = jax.random.split(rng, 5)
@@ -213,7 +217,7 @@
         base_batch = {'samples': base_img_inputs, 'text': encoded_text, 'text_mask': text_mask}
         base_out = base_fn(base_params, base_batch, rng_base)
         for i in range(base_out.shape[0]):
-            matimg.imsave(os.path.join(base_dir, sanitize_filename(f'{prompt_batch[i]}_{sampled_ctr + i}.png')), np.clip(base_out[i], a_min=0, a_max=1))
+            matimg.imsave(os.path.join(base_dir, sanitize_filename(f'{prompt_batch[i][:MAX_FILENAME_LENGTH]}_{sampled_ctr + i}.png')), np.clip(base_out[i], a_min=0, a_max=1))
 
         # Stage 2: Super Resolution (64-> 256)
         base_aug = (base_out * 2 - 1)
@@ -222,7 +226,7 @@
         sr_out = sr256_fn(sr256_params, sr256_batch, rng_sr)
         sr_out = jnp.clip(sr_out, a_min = 0, a_max = 1)
         for i in range(sr_out.shape[0]):
-            matimg.imsave(os.path.join(sr_dir, sanitize_filename(f'{prompt_batch[i]}_{sampled_ctr + i}.png')), sr_out[i])
+            matimg.imsave(os.path.join(sr_dir, sanitize_filename(f'{prompt_batch[i][:MAX_FILENAME_LENGTH]}_{sampled_ctr + i}.png')), sr_out[i])
 
         # Stage 3: Super Resolution (256-> 1024)
         if sr1024_setupdata is not None:
@@ -232,9 +236,12 @@
             sr_out = sr1024_fn(sr1024_params, sr1024_batch, rng_sr2)
             sr_out = jnp.clip(sr_out, a_min = 0, a_max = 1)
             for i in range(sr_out.shape[0]):
-                matimg.imsave(os.path.join(sr2_dir, sanitize_filename(f'{prompt_batch[i]}_{sampled_ctr + i}.png')), sr_out[i])
+                matimg.imsave(os.path.join(sr2_dir, sanitize_filename(f'{prompt_batch[i][:MAX_FILENAME_LENGTH]}_{sampled_ctr + i}.png')), sr_out[i])
 
         sampled_ctr += sr_out.shape[0]
+        print(f'total samples generated={sampled_ctr}')
+        print(f'batch sec/sample={(time.time() - start_time) / sr_out.shape[0]}')
+        start_time = time.time()
 
 
 if __name__ == '__main__':

diff --git a/rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub b/rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub
index 6da0ccde1..4beb45afb 100755
--- a/rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub
+++ b/rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub
@@ -14,11 +14,9 @@ set -x
 # File system and volume glue code
 #-------------------------------------------------------------------------------
 # << CHANGE ! >>
-SLURM_ACCOUNT=
-USERID=
 
 # << CHANGE ! >>
-CONTAINER=
+CONTAINER=${CONTAINER:-ghcr.io#nvidia/t5x:imagen-2023-10-02.v3}
 
 # << CHANGE ! >>
 BASE_ROSETTA_DIR="/jax-toolbox-mirror/rosetta/" # path to your clone of the repo

diff --git a/rosetta/rosetta/projects/imagen/scripts/sample_imagen_1024.sh b/rosetta/rosetta/projects/imagen/scripts/sample_imagen_1024.sh
index d4d9ce63d..afd1f48b6 100755
--- a/rosetta/rosetta/projects/imagen/scripts/sample_imagen_1024.sh
+++ b/rosetta/rosetta/projects/imagen/scripts/sample_imagen_1024.sh
@@ -13,6 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 CFG=${CFG:=2}
+GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE:-4}
+GEN_PER_PROMPT=${GEN_PER_PROMPT:-1}
+SAMPLING_GIN=${SAMPLING_GIN:-/opt/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample.gin}
 BASE_PATH=${BASE_PATH:=\"/opt/rosetta/runs/imagen_base/checkpoint_5000\"}
 SR1_PATH=${SR1_PATH:=\"/opt/rosetta/runs/efficient_sr1/checkpoint_5000\"}
 SR2_PATH=${SR2_PATH:=\"/opt/rosetta/runs/efficient_sr2/checkpoint_5000\"}
@@ -20,14 +23,15 @@ PROMPT_TEXT_FILE=${PROMPT_TEXT_FILE:=\"/opt/rosetta/rosetta/projects/diffusion/t
 export DISABLE_TE=True
 
 python /opt/rosetta/rosetta/projects/imagen/imagen_pipe.py \
-  --gin_file="/opt/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample.gin" \
+  --gin_file="${SAMPLING_GIN}" \
   --gin.base_model/utils.RestoreCheckpointConfig.path="${BASE_PATH}" \
   --gin.sr256/utils.RestoreCheckpointConfig.path="${SR1_PATH}" \
   --gin.sr1024/utils.RestoreCheckpointConfig.path="${SR2_PATH}" \
   --gin.T5_CHECKPOINT_PATH="\"/opt/rosetta/rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl\"" \
   --gin.base_model/samplers.CFGSamplingConfig.cf_guidance_weight=${CFG} \
   --gin.PROMPT_TEXT_FILE=${PROMPT_TEXT_FILE} \
-  --gin.GLOBAL_BATCH_SIZE=4 \
+  --gin.GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE} \
   --gin.SAVE_DIR="\"generations/generations-${CFG}\"" \
-  --gin.GEN_PER_PROMPT=1 \
-  --gin.RESUME_FROM=0
\ No newline at end of file
+  --gin.GEN_PER_PROMPT=${GEN_PER_PROMPT} \
+  --gin.RESUME_FROM=0 \
+  "$@"

diff --git a/rosetta/rosetta/projects/imagen/scripts/sample_imagen_256.sh b/rosetta/rosetta/projects/imagen/scripts/sample_imagen_256.sh
index 19d19be73..4dc46fe87 100755
--- a/rosetta/rosetta/projects/imagen/scripts/sample_imagen_256.sh
+++ b/rosetta/rosetta/projects/imagen/scripts/sample_imagen_256.sh
@@ -13,19 +13,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 CFG=${CFG:=2}
+GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE:-4}
+GEN_PER_PROMPT=${GEN_PER_PROMPT:-1}
+SAMPLING_GIN=${SAMPLING_GIN:-/opt/rosetta/rosetta/projects/imagen/configs/imagen_256_sample.gin}
 BASE_PATH=${BASE_PATH:=\"/opt/rosetta/runs/imagen_base/checkpoint_5000\"}
 SR1_PATH=${SR1_PATH:=\"/opt/rosetta/runs/efficient_sr1/checkpoint_5000\"}
 PROMPT_TEXT_FILE=${PROMPT_TEXT_FILE:=\"/opt/rosetta/rosetta/projects/diffusion/tests/custom_eval_prompts/custom_eval_prompts.txt\"}
 
 export DISABLE_TE=True
 python /opt/rosetta/rosetta/projects/imagen/imagen_pipe.py \
-  --gin_file="/opt/rosetta/rosetta/projects/imagen/configs/imagen_256_sample.gin" \
+  --gin_file="${SAMPLING_GIN}" \
  --gin.base_model/utils.RestoreCheckpointConfig.path="${BASE_PATH}" \
  --gin.sr256/utils.RestoreCheckpointConfig.path="${SR1_PATH}" \
  --gin.T5_CHECKPOINT_PATH="\"/opt/rosetta/rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl\"" \
  --gin.base_model/samplers.CFGSamplingConfig.cf_guidance_weight=${CFG} \
  --gin.PROMPT_TEXT_FILE=${PROMPT_TEXT_FILE}\
-  --gin.GLOBAL_BATCH_SIZE=4 \
+  --gin.GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE} \
  --gin.SAVE_DIR="\"generations/generations-${CFG}\"" \
-  --gin.GEN_PER_PROMPT=1 \
-  --gin.RESUME_FROM=0
\ No newline at end of file
+  --gin.GEN_PER_PROMPT=${GEN_PER_PROMPT} \
+  --gin.RESUME_FROM=0 \
+  "$@"

From 401f80cd6d4a3a3451a60e38c1cb8e05b8f1698a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Bastien?=
Date: Tue, 27 Aug 2024 16:35:33 -0400
Subject: [PATCH 2/2] Remove the flag
 xla_gpu_enable_highest_priority_async_stream=true (#1006)

This flag has been enabled by default since April.
---
 .github/container/test-maxtext.sh                          | 3 +--
 rosetta/docs/NATIVE_FP8.md                                 | 1 -
 rosetta/docs/PGLE.md                                       | 1 -
 rosetta/rosetta/projects/maxtext/README.md                 | 1 -
 rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub | 1 -
 rosetta/rosetta/projects/pax/README.md                     | 3 +--
 6 files changed, 2 insertions(+), 8 deletions(-)

diff --git a/.github/container/test-maxtext.sh b/.github/container/test-maxtext.sh
index 21591c91c..164fa5912 100755
--- a/.github/container/test-maxtext.sh
+++ b/.github/container/test-maxtext.sh
@@ -226,7 +226,6 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
 export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
                 --xla_gpu_enable_triton_gemm=false
                 --xla_gpu_graph_level=0
-                --xla_gpu_enable_highest_priority_async_stream=true
                 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
                 --xla_gpu_all_gather_combine_threshold_bytes=1073741824
                 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728
@@ -268,4 +267,4 @@ fi
 echo "Command: python3 $RUN_SETTINGS"
 python3 $RUN_SETTINGS
-echo "Output at ${OUTPUT}"
\ No newline at end of file
+echo "Output at ${OUTPUT}"

diff --git a/rosetta/docs/NATIVE_FP8.md b/rosetta/docs/NATIVE_FP8.md
index cb26ea2df..dd3aa1bae 100644
--- a/rosetta/docs/NATIVE_FP8.md
+++ b/rosetta/docs/NATIVE_FP8.md
@@ -117,7 +117,6 @@ export XLA_FLAGS=" \
     --xla_gpu_enable_cudnn_layer_norm=true \
     --xla_gpu_enable_cublaslt=true \
     --xla_gpu_enable_latency_hiding_scheduler=true \
-    --xla_gpu_enable_highest_priority_async_stream=true \
     --xla_gpu_all_reduce_combine_threshold_bytes=51200 "
 export ENABLE_TE=0
 python -m paxml.main \

diff --git a/rosetta/docs/PGLE.md b/rosetta/docs/PGLE.md
index 86882dd37..02e5f5294 100644
--- a/rosetta/docs/PGLE.md
+++ b/rosetta/docs/PGLE.md
@@ -63,7 +63,6 @@ In order to get the best performance with PGLE, here is a list of all recommende
 export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
 --xla_gpu_enable_triton_gemm=false
 --xla_gpu_graph_level=0
---xla_gpu_enable_highest_priority_async_stream=true
 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
 --xla_gpu_all_gather_combine_threshold_bytes=1073741824
 --xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824

diff --git a/rosetta/rosetta/projects/maxtext/README.md b/rosetta/rosetta/projects/maxtext/README.md
index 8486ee566..fde5a9125 100644
--- a/rosetta/rosetta/projects/maxtext/README.md
+++ b/rosetta/rosetta/projects/maxtext/README.md
@@ -73,7 +73,6 @@ XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
 --xla_gpu_enable_triton_gemm=false
 --xla_gpu_graph_level=0
 --xla_gpu_enable_async_all_reduce=true
---xla_gpu_enable_highest_priority_async_stream=true
 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
 --xla_gpu_all_gather_combine_threshold_bytes=1073741824
 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728

diff --git a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub
index 45cb5da2c..e96eaa781 100644
--- a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub
+++ b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub
@@ -58,7 +58,6 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
 --xla_gpu_enable_triton_gemm=false
 --xla_gpu_graph_level=0
 --xla_gpu_enable_async_all_reduce=true
---xla_gpu_enable_highest_priority_async_stream=true
 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
 --xla_gpu_all_gather_combine_threshold_bytes=1073741824
 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728

diff --git a/rosetta/rosetta/projects/pax/README.md b/rosetta/rosetta/projects/pax/README.md
index a2dbd1cf0..6ac4dc150 100644
--- a/rosetta/rosetta/projects/pax/README.md
+++ b/rosetta/rosetta/projects/pax/README.md
@@ -139,8 +139,7 @@ For the 126M model, we recommend setting `--xla_gpu_all_reduce_combine_thres
 ```
 BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
                 --xla_gpu_enable_triton_gemm=false
-                --xla_gpu_enable_async_all_gather=true
-                --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_highest_priority_async_stream=true
+                --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true
                 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_all_reduce_combine_threshold_bytes=33554432
                 --xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true" bash run_pile_multinode.sh ...
 ```
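
The sampling-loop guard change in `imagen_pipe.py` above (item 1 of the first commit) fixes a boundary case. Below is a minimal, self-contained Python sketch, with illustrative stand-in values rather than the real pipeline state (the actual loop also threads RNG keys and runs the three diffusion stages), showing why the old `>` comparison admitted one extra, empty batch whenever the prompt count is an exact multiple of `batch_size // gen_per_prompt`:

```python
# Sketch of the loop boundary fixed in imagen_pipe.py; values are stand-ins.
prompts = [f"prompt {i}" for i in range(8)]   # prompt_ct == 8
prompt_ct = len(prompts)
batch_size, gen_per_prompt, max_images = 4, 1, 16
stride = batch_size // gen_per_prompt         # prompts consumed per iteration

for start_idx in range(0, max_images, stride):  # start_idx: 0, 4, 8, 12
    # Old guard: `if start_idx > prompt_ct: break`. With prompt_ct == 8 and
    # stride == 4, start_idx == 8 passed the old guard (8 > 8 is False) even
    # though the slice below is already empty, so an empty batch reached the
    # model. The `>=` comparison stops the loop at exactly the right step.
    if start_idx >= prompt_ct:
        break
    prompt_batch = prompts[start_idx:start_idx + stride] * gen_per_prompt
    print(start_idx, prompt_batch)
```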
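The README note above requires the number of lines in PROMPT_TEXT_FILE to be divisible by the number of GPUs and suggests padding with dummy prompts. Here is a small sketch of that workaround; the helper name `pad_prompt_file` and the dummy prompt text are illustrative, not part of the repo:

```python
# Hypothetical helper: pad a prompt file with dummy prompts so its line
# count is divisible by the number of GPUs (a current sampling limitation).
def pad_prompt_file(path: str, num_gpus: int, dummy: str = "a photo of a dog") -> None:
    with open(path) as f:
        prompts = [line.rstrip("\n") for line in f if line.strip()]
    remainder = len(prompts) % num_gpus
    if remainder:
        prompts.extend([dummy] * (num_gpus - remainder))
    with open(path, "w") as f:
        f.write("\n".join(prompts) + "\n")

pad_prompt_file("custom_eval_prompts.txt", num_gpus=8)
```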
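The same README note points out the implicit checkpoint-dir layout: BASE_PATH/SR1_PATH must be directories containing a `checkpoint` file. A hypothetical pre-flight check along those lines (the `check_ckpt_dir` helper and the paths are illustrative, not part of the repo):

```python
# Hypothetical pre-flight check: verify each checkpoint dir has the layout
# the sampling scripts implicitly require, i.e., that <dir>/checkpoint exists.
import os
import sys

def check_ckpt_dir(path: str) -> None:
    if not os.path.isfile(os.path.join(path, "checkpoint")):
        sys.exit(f"{path} does not look like a T5X checkpoint dir: "
                 f"expected {os.path.join(path, 'checkpoint')} to exist")

for ckpt in ("/mnt/imagen_ckpt/checkpoint_585000", "/mnt/sr1_ckpt/checkpoint_5000"):
    check_ckpt_dir(ckpt)
```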
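Finally, the `MAX_FILENAME_LENGTH = 150` truncation added to `imagen_pipe.py` keeps prompt-derived filenames under common filesystem limits (most Linux filesystems cap a single filename at 255 bytes). A sketch of the arithmetic, assuming an ASCII prompt and a short counter suffix; the prompt and counter below are made up:

```python
MAX_FILENAME_LENGTH = 150  # leaves headroom under the typical 255-byte limit
                           # for the `_{counter}.png` suffix; note the slice is
                           # by characters, so heavily multi-byte prompts could
                           # in principle still exceed the byte limit

prompt = "a watercolor painting of a fox " * 20  # pathologically long prompt
sampled_ctr = 123
filename = f"{prompt[:MAX_FILENAME_LENGTH]}_{sampled_ctr}.png"
assert len(filename.encode("utf-8")) <= 255  # 150 chars + suffix fits easily
print(filename)
```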