Skip to content

Commit

Permalink
merged with upstream on Aug 27 6:18PM
Browse files Browse the repository at this point in the history
  • Loading branch information
kocchop committed Aug 28, 2024
2 parents c4d0f78 + 401f80c commit 3194649
Show file tree
Hide file tree
Showing 10 changed files with 349 additions and 25 deletions.
4 changes: 2 additions & 2 deletions .github/container/test-maxtext.sh
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION}
export CUDA_DEVICE_MAX_CONNECTIONS=1

export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_graph_level=0
--xla_gpu_graph_level=0
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
--xla_gpu_all_gather_combine_threshold_bytes=1073741824
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728
Expand Down Expand Up @@ -266,4 +266,4 @@ fi
echo "Command: python3 $RUN_SETTINGS"
python3 $RUN_SETTINGS

echo "Output at ${OUTPUT}"
echo "Output at ${OUTPUT}"
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ We currently support the following frameworks and models. More details about eac
| :--- | :---: | :---: | :---: |
| [Paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` |
| [T5X](./rosetta/rosetta/projects/t5x) | T5, ViT | pre-training, fine-tuning | `ghcr.io/nvidia/jax:t5x` |
| [T5X](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02` |
| [T5X](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` |
| [Big Vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` |
| levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter` |
| maxtext| LLaMA, Gemma | pretraining | `ghcr.io/nvidia/jax:maxtext` |
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_disable_async_collectives=allreduce,allgather,reducescatter,collectivebroadcast,alltoall,collectivepermute ${XLA_FLAGS}"
# These XLA flags are meant to be used with the JAX version in the imagen container
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_async_all_gather=false --xla_gpu_enable_async_reduce_scatter=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_async_all_reduce=false ${XLA_FLAGS}"
24 changes: 18 additions & 6 deletions rosetta/rosetta/projects/imagen/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ For maximum flexibility and low disk requirements, this repo supports a **distri
We provide [scripts](scripts) to run [interactively](scripts/singlenode_inf_train.sh) or on [SLURM](scripts/example_slurm_inf_train.sub).

### Container
We provide a fully built and ready-to-use container here: `ghcr.io/nvidia/t5x:imagen-2023-10-02`.
We provide a fully built and ready-to-use container here: `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3`.

We do not currently have custom-built container workflows, but are actively working on supporting this, stay tuned for updates!
Imagen will also be available in our T5x container in future releases.
Expand All @@ -37,7 +37,7 @@ You will need to acquire the LLM checkpoint for T5 (for multimodal training) fro
**Note**: this should only be done with singlenode jobs

```bash
CONTAINER=ghcr.io/nvidia/t5x:imagen-2023-10-02
CONTAINER=ghcr.io/nvidia/t5x:imagen-2023-10-02.v3
docker run --rm --gpus=all -it --net=host --ipc=host -v ${PWD}:/opt/rosetta -v ${DATASET_PATH}:/mnt/datasets --privileged $CONTAINER bash
```

Expand Down Expand Up @@ -99,15 +99,27 @@ sbatch -N 14 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \
You can find example sampling scripts that use the 500M base model and EfficientUnet SR models in [scripts](scripts). Prompts should be specified as in [example](../diffusion/tests/custom_eval_prompts/custom_eval_prompts.txt)

#### Sampling 256x256 images
Defaults to [imagen_256_sample.gin](configs/imagen_256_sample.gin) config (can be adjusted in script)
Defaults to [imagen_256_sample.gin](configs/imagen_256_sample.gin) config (can be adjusted in script, e.g., [imagen_256_sample_2b.gin](configs/imagen_256_sample_2b.gin)).
```
CUDA_VISIBLE_DEVICES=<DEVICES> CFG=5.0 BASE_PATH=<BASE_CKPT> SR1_PATH=<SR1_CKPT> PROMPT_TEXT_FILES=<FILE> ./rosetta/projects/imagen/scripts/sample_imagen_256.sh
CUDA_VISIBLE_DEVICES=<DEVICES> CFG=5.0 GLOBAL_BATCH_SIZE=<GBS> GEN_PER_PROMPT=1 BASE_PATH=<BASE_CKPT> SR1_PATH=<SR1_CKPT> PROMPT_TEXT_FILES=<FILE> ./rosetta/projects/imagen/scripts/sample_imagen_256.sh
```

Here is an example:
```
# Note:
# - the quoting of double quotes wrapping single quotes is necessary.
# - BASE_PATH/SR1_PATH are checkpoint dirs, and are expected to contain a `checkpoint` file, e.g., the file $BASE_PATH/checkpoint should exist
# - GLOBAL_BATCH_SIZE should be set with number of GPUs in mind. For instance GLOBAL_BATCH_SIZE >= num gpus,
# to ensure at least one example is sent to each GPU.
# - Currently there is a limitation where the number of lines in PROMPT_TEXT_FILES should be divisible by the number of GPUs.
# The easiest way to ensure that is just to pad the files with dummy prompts until it is divisible
CUDA_VISIBLE_DEVICES=0,1 CFG=5.0 GLOBAL_BATCH_SIZE=4 GEN_PER_PROMPT=1 BASE_PATH='"/mnt/imagen_ckpt/checkpoint_585000"' SR1_PATH='"/mnt/sr1_ckpt/checkpoint_5000"' PROMPT_TEXT_FILES='"./rosetta/projects/diffusion/tests/custom_eval_prompts/custom_eval_prompts.txt"' ./rosetta/projects/imagen/scripts/sample_imagen_256.sh
```

#### Sampling 1024x1024 images
Defaults to [imagen_1024_sample.gin](configs/imagen_1024_sample.gin) config (can be adjusted in script).
Defaults to [imagen_1024_sample.gin](configs/imagen_1024_sample.gin) config (can be adjusted in script, e.g., [imagen_1024_sample_2b.gin](configs/imagen_1024_sample_2b.gin)).
```
CUDA_VISIBLE_DEVICES=<DEVICES> CFG=5.0 BASE_PATH=<BASE_CKPT> SR1_PATH=<SR1_CKPT> SR2_PATH=<SR2_CKPT> PROMPT_TEXT_FILES=<FILE> ./rosetta/projects/imagen/scripts/sample_imagen_1024.sh
CUDA_VISIBLE_DEVICES=<DEVICES> CFG=5.0 GLOBAL_BATCH_SIZE=<GBS> GEN_PER_PROMPT=1 BASE_PATH=<BASE_CKPT> SR1_PATH=<SR1_CKPT> SR2_PATH=<SR2_CKPT> PROMPT_TEXT_FILES=<FILE> ./rosetta/projects/imagen/scripts/sample_imagen_1024.sh
```


Expand Down
78 changes: 78 additions & 0 deletions rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Imagen Sampling pipeline
include "rosetta/projects/imagen/configs/imagen_256_sample_2b.gin"

from __gin__ import dynamic_registration
import __main__ as sample_script
from t5x import gin_utils
from t5x import utils
from t5x import partitioning

from rosetta.projects.imagen import network_sr
from rosetta.projects.diffusion import models
from rosetta.projects.diffusion import denoisers
from rosetta.projects.diffusion import samplers
from rosetta.projects.diffusion import losses
from rosetta.projects.diffusion import augmentations

#---------------- SR1024 Model -------------------------------------------------

# ------------------- Model ----------------------------------------------------
SR1024 = @sr1024/models.DenoisingDiffusionModel()
SIGMA_DATA = 0.5
sr1024/models.DenoisingDiffusionModel:
denoiser= @sr1024/denoisers.EDMTextConditionedSuperResDenoiser()
diffusion_loss= None
diffusion_sampler= @sr1024/samplers.EDMSampler()
optimizer_def = None

# |--- Denoiser
sr1024/denoisers.EDMTextConditionedSuperResDenoiser:
raw_model= @sr1024/network_sr.ImagenEfficientUNet()

sr1024/samplers.EDMSampler:
dim_noise_scalar = 4.

# ------------------- Network specification ------------------------------------
sr1024/network_sr.ImagenEfficientUNet.config = @sr1024/network_sr.ImagenEfficientUNetConfig()
sr1024/network_sr.ImagenEfficientUNetConfig:
dtype = %DTYPE
model_dim = 128
cond_dim = 1024
resblocks_per_level = (2, 4, 8, 8, 8)
width_multipliers = (1, 2, 4, 6, 6)
attn_resolutions_divs = {16: 'cross'}
mha_head_dim = 64
attn_heads = 8
resblock_activation = 'silu'
resblock_zero_out = True
resblock_scale_skip = True
dropout_rate = %DROPOUT_RATE
cond_strategy = 'shift_scale'
norm_32 = True
scale_attn_logits = True
float32_attention_logits=False
text_conditionable = True

sr1024/samplers.CFGSamplingConfig:
num_steps=30
cf_guidance_weight=0.0
cf_guidance_nulls={'text': None, 'text_mask': None}

sr1024/partitioning.PjitPartitioner:
num_partitions = 1
logical_axis_rules = @partitioning.standard_logical_axis_rules()

sr1024/utils.RestoreCheckpointConfig:
mode = 'specific'
dtype = 'bfloat16'

sr1024/sample_script.DiffusionModelSetupData:
model = %SR1024
sampling_cfg = @sr1024/samplers.CFGSamplingConfig()
restore_checkpoint_cfg = @sr1024/utils.RestoreCheckpointConfig()
partitioner = @partitioning.PjitPartitioner()
input_shapes = {'samples': (1, 1024, 1024, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN, 'low_res_images': (1, 256, 256, 3)}
input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int', 'low_res_images': 'float32'}

sample_script.sample:
sr1024_setupdata = @sr1024/sample_script.DiffusionModelSetupData()
220 changes: 220 additions & 0 deletions rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Imagen Sampling pipeline
from __gin__ import dynamic_registration

import __main__ as sample_script
from t5x import gin_utils
from t5x import utils
from t5x import partitioning

SAVE_DIR='generations'
PROMPT_TEXT_FILE='custom_text.txt'
GLOBAL_BATCH_SIZE=32
MAX_GENERATE=50000000
GEN_PER_PROMPT=2
NOISE_COND_AUG=0.002

TXT_SHAPE=(1, 128, 4096) #T5 xxl, seqlen x embed_dim
TXT_SEQLEN=(1, 128, )
TXT_SEQLEN_SINGLE=128
DTYPE='bfloat16'
DROPOUT_RATE=0
RESUME_FROM=0 #Sampling count to resume from
#---------------- Base Model -------------------------------------------------
from rosetta.projects.imagen import network
from rosetta.projects.imagen import network_sr
from rosetta.projects.diffusion import models
from rosetta.projects.diffusion import denoisers
from rosetta.projects.diffusion import samplers
from rosetta.projects.diffusion import losses
from rosetta.projects.diffusion import augmentations

# ------------------- Model ----------------------------------------------------
BASE = @base_model/models.DenoisingDiffusionModel()
base_model/models.DenoisingDiffusionModel:
denoiser= @base_model/denoisers.EDMTextConditionedDenoiser()
diffusion_loss = None
diffusion_sampler= @base_model/samplers.EDMSampler()
optimizer_def = None

# |--- Denoiser
base_model/denoisers.EDMTextConditionedDenoiser:
raw_model= @base_model/network.ImagenUNet()

# ------------------- Network specification ------------------------------------
base_model/network.ImagenUNet.config = @base_model/network.DiffusionConfig()
base_model/network.DiffusionConfig:
dtype = %DTYPE
model_dim = 512
attn_cond_dim = 2048
cond_dim = 2048
resblocks_per_level = 3
width_multipliers = (1, 2, 3, 4)
attn_resolutions = (32, 16, 8)
mha_head_dim = 64
attn_heads = 8
resblock_activation = 'silu'
dropout_rate = %DROPOUT_RATE
upsample_mode = 'shuffle'
downsample_mode = 'shuffle'
spatial_skip = False
cond_strategy = 'shift_scale'
norm_32 = True
scale_attn_logits = True
float32_attention_logits = False
text_conditionable = True


BASE_SAMPLING_CONFIG = @base_model/samplers.CFGSamplingConfig()
base_model/samplers.CFGSamplingConfig:
num_steps=50
cf_guidance_weight=5.00
cf_guidance_nulls=None

base_model/partitioning.PjitPartitioner:
num_partitions = 1
logical_axis_rules = @partitioning.standard_logical_axis_rules()

base_model/utils.RestoreCheckpointConfig:
mode = 'specific'
dtype = 'bfloat16'

base_model/sample_script.DiffusionModelSetupData:
model = %BASE
sampling_cfg = @base_model/samplers.CFGSamplingConfig()
restore_checkpoint_cfg = @base_model/utils.RestoreCheckpointConfig()
partitioner = @partitioning.PjitPartitioner()
input_shapes = {'samples': (1, 64, 64, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN}
input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int'}

#---------------- SR256 Model -------------------------------------------------

# ------------------- Model ----------------------------------------------------
SR256 = @sr256/models.DenoisingDiffusionModel()
SIGMA_DATA = 0.5
sr256/models.DenoisingDiffusionModel:
denoiser= @sr256/denoisers.EDMTextConditionedSuperResDenoiser()
diffusion_loss= None
diffusion_sampler= @sr256/samplers.EDMSampler()
optimizer_def = None

# |--- Denoiser
sr256/denoisers.EDMTextConditionedSuperResDenoiser:
raw_model= @sr256/network_sr.ImagenEfficientUNet()

sr256/samplers.EDMSampler:
dim_noise_scalar = 4.

# ------------------- Network specification ------------------------------------
sr256/network_sr.ImagenEfficientUNet.config = @sr256/network_sr.ImagenEfficientUNetConfig()
sr256/network_sr.ImagenEfficientUNetConfig:
dtype = %DTYPE
model_dim = 128
cond_dim = 512
attn_cond_dim = 1024
resblocks_per_level = (2, 4, 8, 8, 2)
width_multipliers = (1, 2, 4, 8, 8)
attn_resolutions_divs = {8: 'fused', 16: 'fused'}
mha_head_dim = 64
attn_heads = 8
resblock_activation = 'silu'
resblock_zero_out = True
resblock_scale_skip = True
dropout_rate = %DROPOUT_RATE
cond_strategy = 'shift_scale'
norm_32 = True
scale_attn_logits = True
float32_attention_logits=False
text_conditionable = True

sr256/samplers.CFGSamplingConfig:
num_steps=50
cf_guidance_weight=4
cf_guidance_nulls={'text': None, 'text_mask': None}

sr256/partitioning.PjitPartitioner:
num_partitions = 1
logical_axis_rules = @partitioning.standard_logical_axis_rules()

sr256/utils.RestoreCheckpointConfig:
mode = 'specific'
dtype = 'bfloat16'

sr256/sample_script.DiffusionModelSetupData:
model = %SR256
sampling_cfg = @sr256/samplers.CFGSamplingConfig()
restore_checkpoint_cfg = @sr256/utils.RestoreCheckpointConfig()
partitioner = @partitioning.PjitPartitioner()
input_shapes = {'samples': (1, 256, 256, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN, 'low_res_images': (1, 64, 64, 3)}
input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int', 'low_res_images': 'float32'}

#---------------- Text Model -------------------------------------------------
import seqio
from rosetta.projects.inference_serving.t5 import network as t5x_network
from rosetta.projects.inference_serving.t5 import models as t5x_models

# =====================================
# === T5 Encoder only configuration ===
# =====================================
T5_CHECKPOINT_PATH = "/opt/rosetta/rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl"
BATCH_SIZE = 256 # Will be overridden
SEQ_LEN = 128 # MAX seqlen

# Vocabulary
VOCABULARY = @seqio.SentencePieceVocabulary()
seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"
TASK_FEATURE_LENGTHS = None # auto-computes the maximum features length to use.

# --------------- Model ------------------
TEXT_ENC = @text_enc/t5x_models.EncoderOnlyModel()
text_enc/t5x_models.EncoderOnlyModel:
module = @t5x_network.TransformerEncoderOnly()
input_vocabulary = %VOCABULARY
output_vocabulary = %VOCABULARY
optimizer_def = None
z_loss = 0.0001
label_smoothing = 0.0
loss_normalizing_factor = None

# -------- Network specification ---------
t5x_network.TransformerEncoderOnly.config = @t5x_network.T5Config()
t5x_network.T5Config:
vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency
dtype = 'bfloat16'
emb_dim = 4096
num_heads = 64
num_encoder_layers = 24
num_decoder_layers = 0
head_dim = 64
mlp_dim = 10240
mlp_activations = ('gelu', 'linear')
dropout_rate = 0.0

text_enc/partitioning.PjitPartitioner:
num_partitions = 1
logical_axis_rules = @partitioning.standard_logical_axis_rules()

text_enc/utils.RestoreCheckpointConfig:
path = %T5_CHECKPOINT_PATH
mode = 'specific'
dtype = 'bfloat16'

text_enc/sample_script.setup_text_enc:
model=%TEXT_ENC
restore_checkpoint_cfg=@text_enc/utils.RestoreCheckpointConfig()
partitioner=@text_enc/partitioning.PjitPartitioner()
batch_size=1
seq_len=%TXT_SEQLEN_SINGLE
vocab = %VOCABULARY

sample_script.sample:
base_setupdata = @base_model/sample_script.DiffusionModelSetupData()
sr256_setupdata = @sr256/sample_script.DiffusionModelSetupData()
sr1024_setupdata = None
out_dir = %SAVE_DIR
gen_per_prompt = %GEN_PER_PROMPT
prompt_file = %PROMPT_TEXT_FILE
batch_size = %GLOBAL_BATCH_SIZE
max_images = %MAX_GENERATE
text_enc_infer = @text_enc/sample_script.setup_text_enc()
noise_conditioning_aug = %NOISE_COND_AUG
resume_from = %RESUME_FROM
Loading

0 comments on commit 3194649

Please sign in to comment.