-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
merged with upstream on Aug 27 6:18PM
- Loading branch information
Showing
10 changed files
with
349 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 2 additions & 1 deletion
3
rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_disable_async_collectives=allreduce,allgather,reducescatter,collectivebroadcast,alltoall,collectivepermute ${XLA_FLAGS}" | ||
# These XLA flags are meant to be used with the JAX version in the imagen container | ||
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_async_all_gather=false --xla_gpu_enable_async_reduce_scatter=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_async_all_reduce=false ${XLA_FLAGS}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
78 changes: 78 additions & 0 deletions
78
rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Imagen sampling pipeline: extends the 256x256 sampling config with the SR1024 (256x256 -> 1024x1024) super-resolution stage | ||
include "rosetta/projects/imagen/configs/imagen_256_sample_2b.gin" | ||
|
||
from __gin__ import dynamic_registration | ||
import __main__ as sample_script | ||
from t5x import gin_utils | ||
from t5x import utils | ||
from t5x import partitioning | ||
|
||
from rosetta.projects.imagen import network_sr | ||
from rosetta.projects.diffusion import models | ||
from rosetta.projects.diffusion import denoisers | ||
from rosetta.projects.diffusion import samplers | ||
from rosetta.projects.diffusion import losses | ||
from rosetta.projects.diffusion import augmentations | ||
|
||
#---------------- SR1024 Model ------------------------------------------------- | ||
|
||
# ------------------- Model ---------------------------------------------------- | ||
SR1024 = @sr1024/models.DenoisingDiffusionModel() | ||
SIGMA_DATA = 0.5 | ||
sr1024/models.DenoisingDiffusionModel: | ||
denoiser= @sr1024/denoisers.EDMTextConditionedSuperResDenoiser() | ||
diffusion_loss= None | ||
diffusion_sampler= @sr1024/samplers.EDMSampler() | ||
optimizer_def = None | ||
|
||
# |--- Denoiser | ||
sr1024/denoisers.EDMTextConditionedSuperResDenoiser: | ||
raw_model= @sr1024/network_sr.ImagenEfficientUNet() | ||
|
||
sr1024/samplers.EDMSampler: | ||
dim_noise_scalar = 4. | ||
|
||
# ------------------- Network specification ------------------------------------ | ||
sr1024/network_sr.ImagenEfficientUNet.config = @sr1024/network_sr.ImagenEfficientUNetConfig() | ||
sr1024/network_sr.ImagenEfficientUNetConfig: | ||
dtype = %DTYPE | ||
model_dim = 128 | ||
cond_dim = 1024 | ||
resblocks_per_level = (2, 4, 8, 8, 8) | ||
width_multipliers = (1, 2, 4, 6, 6) | ||
attn_resolutions_divs = {16: 'cross'} | ||
mha_head_dim = 64 | ||
attn_heads = 8 | ||
resblock_activation = 'silu' | ||
resblock_zero_out = True | ||
resblock_scale_skip = True | ||
dropout_rate = %DROPOUT_RATE | ||
cond_strategy = 'shift_scale' | ||
norm_32 = True | ||
scale_attn_logits = True | ||
float32_attention_logits=False | ||
text_conditionable = True | ||
|
||
sr1024/samplers.CFGSamplingConfig: | ||
num_steps=30 | ||
cf_guidance_weight=0.0 | ||
cf_guidance_nulls={'text': None, 'text_mask': None} | ||
|
||
sr1024/partitioning.PjitPartitioner: | ||
num_partitions = 1 | ||
logical_axis_rules = @partitioning.standard_logical_axis_rules() | ||
|
||
sr1024/utils.RestoreCheckpointConfig: | ||
mode = 'specific' | ||
dtype = 'bfloat16' | ||
|
||
sr1024/sample_script.DiffusionModelSetupData: | ||
model = %SR1024 | ||
sampling_cfg = @sr1024/samplers.CFGSamplingConfig() | ||
restore_checkpoint_cfg = @sr1024/utils.RestoreCheckpointConfig() | ||
partitioner = @partitioning.PjitPartitioner() | ||
input_shapes = {'samples': (1, 1024, 1024, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN, 'low_res_images': (1, 256, 256, 3)} | ||
input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int', 'low_res_images': 'float32'} | ||
|
||
sample_script.sample: | ||
sr1024_setupdata = @sr1024/sample_script.DiffusionModelSetupData() |
220 changes: 220 additions & 0 deletions
220
rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
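# Imagen sampling pipeline: T5 text encoder, 64x64 base model, and SR256 (64x64 -> 256x256) super-resolution stage; SR1024 is disabled here | ||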
from __gin__ import dynamic_registration | ||
|
||
import __main__ as sample_script | ||
from t5x import gin_utils | ||
from t5x import utils | ||
from t5x import partitioning | ||
|
||
SAVE_DIR='generations' | ||
PROMPT_TEXT_FILE='custom_text.txt' | ||
GLOBAL_BATCH_SIZE=32 | ||
MAX_GENERATE=50000000 | ||
GEN_PER_PROMPT=2 | ||
NOISE_COND_AUG=0.002 | ||
|
||
TXT_SHAPE=(1, 128, 4096) # T5 XXL text embeddings: batch x seqlen x embed_dim | ||
TXT_SEQLEN=(1, 128, ) | ||
TXT_SEQLEN_SINGLE=128 | ||
DTYPE='bfloat16' | ||
DROPOUT_RATE=0 | ||
RESUME_FROM=0 #Sampling count to resume from | ||
#---------------- Base Model ------------------------------------------------- | ||
from rosetta.projects.imagen import network | ||
from rosetta.projects.imagen import network_sr | ||
from rosetta.projects.diffusion import models | ||
from rosetta.projects.diffusion import denoisers | ||
from rosetta.projects.diffusion import samplers | ||
from rosetta.projects.diffusion import losses | ||
from rosetta.projects.diffusion import augmentations | ||
|
||
# ------------------- Model ---------------------------------------------------- | ||
BASE = @base_model/models.DenoisingDiffusionModel() | ||
base_model/models.DenoisingDiffusionModel: | ||
denoiser= @base_model/denoisers.EDMTextConditionedDenoiser() | ||
diffusion_loss = None | ||
diffusion_sampler= @base_model/samplers.EDMSampler() | ||
optimizer_def = None | ||
|
||
# |--- Denoiser | ||
base_model/denoisers.EDMTextConditionedDenoiser: | ||
raw_model= @base_model/network.ImagenUNet() | ||
|
||
# ------------------- Network specification ------------------------------------ | ||
base_model/network.ImagenUNet.config = @base_model/network.DiffusionConfig() | ||
base_model/network.DiffusionConfig: | ||
dtype = %DTYPE | ||
model_dim = 512 | ||
attn_cond_dim = 2048 | ||
cond_dim = 2048 | ||
resblocks_per_level = 3 | ||
width_multipliers = (1, 2, 3, 4) | ||
attn_resolutions = (32, 16, 8) | ||
mha_head_dim = 64 | ||
attn_heads = 8 | ||
resblock_activation = 'silu' | ||
dropout_rate = %DROPOUT_RATE | ||
upsample_mode = 'shuffle' | ||
downsample_mode = 'shuffle' | ||
spatial_skip = False | ||
cond_strategy = 'shift_scale' | ||
norm_32 = True | ||
scale_attn_logits = True | ||
float32_attention_logits = False | ||
text_conditionable = True | ||
|
||
|
||
BASE_SAMPLING_CONFIG = @base_model/samplers.CFGSamplingConfig() | ||
base_model/samplers.CFGSamplingConfig: | ||
num_steps=50 | ||
cf_guidance_weight=5.00 | ||
cf_guidance_nulls=None | ||
|
||
base_model/partitioning.PjitPartitioner: | ||
num_partitions = 1 | ||
logical_axis_rules = @partitioning.standard_logical_axis_rules() | ||
|
||
base_model/utils.RestoreCheckpointConfig: | ||
mode = 'specific' | ||
dtype = 'bfloat16' | ||
|
||
base_model/sample_script.DiffusionModelSetupData: | ||
model = %BASE | ||
sampling_cfg = @base_model/samplers.CFGSamplingConfig() | ||
restore_checkpoint_cfg = @base_model/utils.RestoreCheckpointConfig() | ||
partitioner = @partitioning.PjitPartitioner() | ||
input_shapes = {'samples': (1, 64, 64, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN} | ||
input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int'} | ||
|
||
#---------------- SR256 Model ------------------------------------------------- | ||
|
||
# ------------------- Model ---------------------------------------------------- | ||
SR256 = @sr256/models.DenoisingDiffusionModel() | ||
SIGMA_DATA = 0.5 | ||
sr256/models.DenoisingDiffusionModel: | ||
denoiser= @sr256/denoisers.EDMTextConditionedSuperResDenoiser() | ||
diffusion_loss= None | ||
diffusion_sampler= @sr256/samplers.EDMSampler() | ||
optimizer_def = None | ||
|
||
# |--- Denoiser | ||
sr256/denoisers.EDMTextConditionedSuperResDenoiser: | ||
raw_model= @sr256/network_sr.ImagenEfficientUNet() | ||
|
||
sr256/samplers.EDMSampler: | ||
dim_noise_scalar = 4. | ||
|
||
# ------------------- Network specification ------------------------------------ | ||
sr256/network_sr.ImagenEfficientUNet.config = @sr256/network_sr.ImagenEfficientUNetConfig() | ||
sr256/network_sr.ImagenEfficientUNetConfig: | ||
dtype = %DTYPE | ||
model_dim = 128 | ||
cond_dim = 512 | ||
attn_cond_dim = 1024 | ||
resblocks_per_level = (2, 4, 8, 8, 2) | ||
width_multipliers = (1, 2, 4, 8, 8) | ||
attn_resolutions_divs = {8: 'fused', 16: 'fused'} | ||
mha_head_dim = 64 | ||
attn_heads = 8 | ||
resblock_activation = 'silu' | ||
resblock_zero_out = True | ||
resblock_scale_skip = True | ||
dropout_rate = %DROPOUT_RATE | ||
cond_strategy = 'shift_scale' | ||
norm_32 = True | ||
scale_attn_logits = True | ||
float32_attention_logits=False | ||
text_conditionable = True | ||
|
||
sr256/samplers.CFGSamplingConfig: | ||
num_steps=50 | ||
cf_guidance_weight=4 | ||
cf_guidance_nulls={'text': None, 'text_mask': None} | ||
|
||
sr256/partitioning.PjitPartitioner: | ||
num_partitions = 1 | ||
logical_axis_rules = @partitioning.standard_logical_axis_rules() | ||
|
||
sr256/utils.RestoreCheckpointConfig: | ||
mode = 'specific' | ||
dtype = 'bfloat16' | ||
|
||
sr256/sample_script.DiffusionModelSetupData: | ||
model = %SR256 | ||
sampling_cfg = @sr256/samplers.CFGSamplingConfig() | ||
restore_checkpoint_cfg = @sr256/utils.RestoreCheckpointConfig() | ||
partitioner = @partitioning.PjitPartitioner() | ||
input_shapes = {'samples': (1, 256, 256, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN, 'low_res_images': (1, 64, 64, 3)} | ||
input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int', 'low_res_images': 'float32'} | ||
|
||
#---------------- Text Model ------------------------------------------------- | ||
import seqio | ||
from rosetta.projects.inference_serving.t5 import network as t5x_network | ||
from rosetta.projects.inference_serving.t5 import models as t5x_models | ||
|
||
# ===================================== | ||
# === T5 Encoder only configuration === | ||
# ===================================== | ||
T5_CHECKPOINT_PATH = "/opt/rosetta/rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl" | ||
BATCH_SIZE = 256 # Will be overridden | ||
SEQ_LEN = 128 # MAX seqlen | ||
|
||
# Vocabulary | ||
VOCABULARY = @seqio.SentencePieceVocabulary() | ||
seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" | ||
TASK_FEATURE_LENGTHS = None # None auto-computes the maximum feature lengths to use. | ||
|
||
# --------------- Model ------------------ | ||
TEXT_ENC = @text_enc/t5x_models.EncoderOnlyModel() | ||
text_enc/t5x_models.EncoderOnlyModel: | ||
module = @t5x_network.TransformerEncoderOnly() | ||
input_vocabulary = %VOCABULARY | ||
output_vocabulary = %VOCABULARY | ||
optimizer_def = None | ||
z_loss = 0.0001 | ||
label_smoothing = 0.0 | ||
loss_normalizing_factor = None | ||
|
||
# -------- Network specification --------- | ||
t5x_network.TransformerEncoderOnly.config = @t5x_network.T5Config() | ||
t5x_network.T5Config: | ||
vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency | ||
dtype = 'bfloat16' | ||
emb_dim = 4096 | ||
num_heads = 64 | ||
num_encoder_layers = 24 | ||
num_decoder_layers = 0 | ||
head_dim = 64 | ||
mlp_dim = 10240 | ||
mlp_activations = ('gelu', 'linear') | ||
dropout_rate = 0.0 | ||
|
||
text_enc/partitioning.PjitPartitioner: | ||
num_partitions = 1 | ||
logical_axis_rules = @partitioning.standard_logical_axis_rules() | ||
|
||
text_enc/utils.RestoreCheckpointConfig: | ||
path = %T5_CHECKPOINT_PATH | ||
mode = 'specific' | ||
dtype = 'bfloat16' | ||
|
||
text_enc/sample_script.setup_text_enc: | ||
model=%TEXT_ENC | ||
restore_checkpoint_cfg=@text_enc/utils.RestoreCheckpointConfig() | ||
partitioner=@text_enc/partitioning.PjitPartitioner() | ||
batch_size=1 | ||
seq_len=%TXT_SEQLEN_SINGLE | ||
vocab = %VOCABULARY | ||
|
||
sample_script.sample: | ||
base_setupdata = @base_model/sample_script.DiffusionModelSetupData() | ||
sr256_setupdata = @sr256/sample_script.DiffusionModelSetupData() | ||
sr1024_setupdata = None | ||
out_dir = %SAVE_DIR | ||
gen_per_prompt = %GEN_PER_PROMPT | ||
prompt_file = %PROMPT_TEXT_FILE | ||
batch_size = %GLOBAL_BATCH_SIZE | ||
max_images = %MAX_GENERATE | ||
text_enc_infer = @text_enc/sample_script.setup_text_enc() | ||
noise_conditioning_aug = %NOISE_COND_AUG | ||
resume_from = %RESUME_FROM |
Oops, something went wrong.