diff --git a/tools/inference/inference_pipeline.ipynb b/tools/inference/inference_pipeline.ipynb
index 43455fe43..56a3bb187 100644
--- a/tools/inference/inference_pipeline.ipynb
+++ b/tools/inference/inference_pipeline.ipynb
@@ -1,561 +1,561 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "view-in-github",
- "colab_type": "text"
- },
- "source": [
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "118UKH5bWCGa"
- },
- "source": [
- "# DALL·E mini - Inference pipeline\n",
- "\n",
- "*Generate images from a text prompt*\n",
- "\n",
- "\n",
- "\n",
- "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
- "\n",
- "Just want to play? Use directly [the app](https://www.craiyon.com/).\n",
- "\n",
- "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "dS8LbaonYm3a"
- },
- "source": [
- "## 🛠️ Installation and set-up"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "uzjAM2GBYpZX"
- },
- "outputs": [],
- "source": [
- "# Required only for colab environments + GPU\n",
- "!pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
- "\n",
- "# Install required libraries\n",
- "!pip install -q dalle-mini orbax==0.0.23\n",
- "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ozHzTkyv8cqU"
- },
- "source": [
- "We load required models:\n",
- "* DALL·E mini for text to encoded images\n",
- "* VQGAN for decoding images\n",
- "* CLIP for scoring predictions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "K6CxW2o42f-w"
- },
- "outputs": [],
- "source": [
- "# Model references\n",
- "\n",
- "# dalle-mega\n",
- "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
- "DALLE_COMMIT_ID = None\n",
- "\n",
- "# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n",
- "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n",
- "\n",
- "# VQGAN model\n",
- "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
- "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Yv-aR3t4Oe5v"
- },
- "outputs": [],
- "source": [
- "import jax\n",
- "import jax.numpy as jnp\n",
- "\n",
- "# check how many devices are available\n",
- "jax.local_device_count()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "92zYmvsQ38vL"
- },
- "outputs": [],
- "source": [
- "# Load models & tokenizer\n",
- "from dalle_mini import DalleBart, DalleBartProcessor\n",
- "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
- "from transformers import CLIPProcessor, FlaxCLIPModel\n",
- "\n",
- "# Load dalle-mini\n",
- "model, params = DalleBart.from_pretrained(\n",
- " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
- ")\n",
- "\n",
- "# Load VQGAN\n",
- "vqgan, vqgan_params = VQModel.from_pretrained(\n",
- " VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "o_vH2X1tDtzA"
- },
- "source": [
- "Model parameters are replicated on each device for faster inference."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "wtvLoM48EeVw"
- },
- "outputs": [],
- "source": [
- "from flax.jax_utils import replicate\n",
- "\n",
- "params = replicate(params)\n",
- "vqgan_params = replicate(vqgan_params)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "0A9AHQIgZ_qw"
- },
- "source": [
- "Model functions are compiled and parallelized to take advantage of multiple devices."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "sOtoOmYsSYPz"
- },
- "outputs": [],
- "source": [
- "from functools import partial\n",
- "\n",
- "\n",
- "# model inference\n",
- "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
- "def p_generate(\n",
- " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
- "):\n",
- " return model.generate(\n",
- " **tokenized_prompt,\n",
- " prng_key=key,\n",
- " params=params,\n",
- " top_k=top_k,\n",
- " top_p=top_p,\n",
- " temperature=temperature,\n",
- " condition_scale=condition_scale,\n",
- " )\n",
- "\n",
- "\n",
- "# decode image\n",
- "@partial(jax.pmap, axis_name=\"batch\")\n",
- "def p_decode(indices, params):\n",
- " return vqgan.decode_code(indices, params=params)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "HmVN6IBwapBA"
- },
- "source": [
- "Keys are passed to the model on each device to generate unique inference per device."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "4CTXmlUkThhX"
- },
- "outputs": [],
- "source": [
- "import random\n",
- "\n",
- "# create a random key\n",
- "seed = random.randint(0, 2**32 - 1)\n",
- "key = jax.random.PRNGKey(seed)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "BrnVyCo81pij"
- },
- "source": [
- "## 🖍 Text Prompt"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "rsmj0Aj5OQox"
- },
- "source": [
- "Our model requires processing prompts."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "YjjhUychOVxm"
- },
- "outputs": [],
- "source": [
- "from dalle_mini import DalleBartProcessor\n",
- "\n",
- "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "BQ7fymSPyvF_"
- },
- "source": [
- "Let's define some text prompts."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "x_0vI9ge1oKr"
- },
- "outputs": [],
- "source": [
- "prompts = [\n",
- " \"sunset over a lake in the mountains\",\n",
- " \"the Eiffel tower landing on the moon\",\n",
- "]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "XlZUG3SCLnGE"
- },
- "source": [
- "Note: we could use the same prompt multiple times for faster inference."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "VKjEZGjtO49k"
- },
- "outputs": [],
- "source": [
- "tokenized_prompts = processor(prompts)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "-CEJBnuJOe5z"
- },
- "source": [
- "Finally we replicate the prompts onto each device."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "lQePgju5Oe5z"
- },
- "outputs": [],
- "source": [
- "tokenized_prompt = replicate(tokenized_prompts)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "phQ9bhjRkgAZ"
- },
- "source": [
- "## 🎨 Generate images\n",
- "\n",
- "We generate images using dalle-mini model and decode them with the VQGAN."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "d0wVkXpKqnHA"
- },
- "outputs": [],
- "source": [
- "# number of predictions per prompt\n",
- "n_predictions = 8\n",
- "\n",
- "# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)\n",
- "gen_top_k = None\n",
- "gen_top_p = None\n",
- "temperature = None\n",
- "cond_scale = 10.0"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "SDjEx9JxR3v8"
- },
- "outputs": [],
- "source": [
- "from flax.training.common_utils import shard_prng_key\n",
- "import numpy as np\n",
- "from PIL import Image\n",
- "from tqdm.notebook import trange\n",
- "\n",
- "print(f\"Prompts: {prompts}\\n\")\n",
- "# generate images\n",
- "images = []\n",
- "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
- " # get a new key\n",
- " key, subkey = jax.random.split(key)\n",
- " # generate images\n",
- " encoded_images = p_generate(\n",
- " tokenized_prompt,\n",
- " shard_prng_key(subkey),\n",
- " params,\n",
- " gen_top_k,\n",
- " gen_top_p,\n",
- " temperature,\n",
- " cond_scale,\n",
- " )\n",
- " # remove BOS\n",
- " encoded_images = encoded_images.sequences[..., 1:]\n",
- " # decode images\n",
- " decoded_images = p_decode(encoded_images, vqgan_params)\n",
- " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
- " for decoded_img in decoded_images:\n",
- " img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n",
- " images.append(img)\n",
- " display(img)\n",
- " print()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "tw02wG9zGmyB"
- },
- "source": [
- "## 🏅 Optional: Rank images by CLIP score\n",
- "\n",
- "We can rank images according to CLIP.\n",
- "\n",
- "**Note: your session may crash if you don't have a subscription to Colab Pro.**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "RGjlIW_f6GA0"
- },
- "outputs": [],
- "source": [
- "# CLIP model\n",
- "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
- "CLIP_COMMIT_ID = None\n",
- "\n",
- "# Load CLIP\n",
- "clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
- " CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
- ")\n",
- "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
- "clip_params = replicate(clip_params)\n",
- "\n",
- "\n",
- "# score images\n",
- "@partial(jax.pmap, axis_name=\"batch\")\n",
- "def p_clip(inputs, params):\n",
- " logits = clip(params=params, **inputs).logits_per_image\n",
- " return logits"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "FoLXpjCmGpju"
- },
- "outputs": [],
- "source": [
- "from flax.training.common_utils import shard\n",
- "\n",
- "# get clip scores\n",
- "clip_inputs = clip_processor(\n",
- " text=prompts * jax.device_count(),\n",
- " images=images,\n",
- " return_tensors=\"np\",\n",
- " padding=\"max_length\",\n",
- " max_length=77,\n",
- " truncation=True,\n",
- ").data\n",
- "logits = p_clip(shard(clip_inputs), clip_params)\n",
- "\n",
- "# organize scores per prompt\n",
- "p = len(prompts)\n",
- "logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "4AAWRm70LgED"
- },
- "source": [
- "Let's now display images ranked by CLIP score."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "zsgxxubLLkIu"
- },
- "outputs": [],
- "source": [
- "for i, prompt in enumerate(prompts):\n",
- " print(f\"Prompt: {prompt}\\n\")\n",
- " for idx in logits[i].argsort()[::-1]:\n",
- " display(images[idx * p + i])\n",
- " print(f\"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\\n\")\n",
- " print()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "oZT9i3jCjir0"
- },
- "source": [
- "## 🪄 Optional: Save your Generated Images as W&B Tables\n",
- "\n",
- "W&B Tables is an interactive 2D grid with support to rich media logging. Use this to save the generated images on W&B dashboard and share with the world."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "-pSiv6Vwjkn0"
- },
- "outputs": [],
- "source": [
- "import wandb\n",
- "\n",
- "# Initialize a W&B run.\n",
- "project = \"dalle-mini-tables-colab\"\n",
- "run = wandb.init(project=project)\n",
- "\n",
- "# Initialize an empty W&B Tables.\n",
- "columns = [\"captions\"] + [f\"image_{i+1}\" for i in range(n_predictions)]\n",
- "gen_table = wandb.Table(columns=columns)\n",
- "\n",
- "# Add data to the table.\n",
- "for i, prompt in enumerate(prompts):\n",
- " # If CLIP scores exist, sort the Images\n",
- " if logits is not None:\n",
- " idxs = logits[i].argsort()[::-1]\n",
- " tmp_imgs = images[i :: len(prompts)]\n",
- " tmp_imgs = [tmp_imgs[idx] for idx in idxs]\n",
- " else:\n",
- " tmp_imgs = images[i :: len(prompts)]\n",
- "\n",
- " # Add the data to the table.\n",
- " gen_table.add_data(prompt, *[wandb.Image(img) for img in tmp_imgs])\n",
- "\n",
- "# Log the Table to W&B dashboard.\n",
- "wandb.log({\"Generated Images\": gen_table})\n",
- "\n",
- "# Close the W&B run.\n",
- "run.finish()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Ck2ZnHwVjnRd"
- },
- "source": [
- "Click on the link above to check out your generated images."
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "machine_shape": "hm",
- "name": "DALL·E mini - Inference pipeline.ipynb",
- "provenance": [],
- "gpuType": "A100",
- "include_colab_link": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.7"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "118UKH5bWCGa"
+ },
+ "source": [
+ "# DALL·E mini - Inference pipeline\n",
+ "\n",
+ "*Generate images from a text prompt*\n",
+ "\n",
+ "\n",
+ "\n",
+ "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
+ "\n",
+ "Just want to play? Use directly [the app](https://www.craiyon.com/).\n",
+ "\n",
+ "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dS8LbaonYm3a"
+ },
+ "source": [
+ "## 🛠️ Installation and set-up"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "uzjAM2GBYpZX"
+ },
+ "outputs": [],
+ "source": [
+ "# Required only for colab environments + GPU\n",
+ "!pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
+ "\n",
+ "# Install required libraries\n",
+ "!pip install -q dalle-mini\n",
+ "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ozHzTkyv8cqU"
+ },
+ "source": [
+ "We load required models:\n",
+ "* DALL·E mini for text to encoded images\n",
+ "* VQGAN for decoding images\n",
+ "* CLIP for scoring predictions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "K6CxW2o42f-w"
+ },
+ "outputs": [],
+ "source": [
+ "# Model references\n",
+ "\n",
+ "# dalle-mega\n",
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
+ "DALLE_COMMIT_ID = None\n",
+ "\n",
+ "# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n",
+ "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n",
+ "\n",
+ "# VQGAN model\n",
+ "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
+ "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Yv-aR3t4Oe5v"
+ },
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "# check how many devices are available\n",
+ "jax.local_device_count()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "92zYmvsQ38vL"
+ },
+ "outputs": [],
+ "source": [
+ "# Load models & tokenizer\n",
+ "from dalle_mini import DalleBart, DalleBartProcessor\n",
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
+ "from transformers import CLIPProcessor, FlaxCLIPModel\n",
+ "\n",
+ "# Load dalle-mini\n",
+ "model, params = DalleBart.from_pretrained(\n",
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
+ ")\n",
+ "\n",
+ "# Load VQGAN\n",
+ "vqgan, vqgan_params = VQModel.from_pretrained(\n",
+ " VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "o_vH2X1tDtzA"
+ },
+ "source": [
+ "Model parameters are replicated on each device for faster inference."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "wtvLoM48EeVw"
+ },
+ "outputs": [],
+ "source": [
+ "from flax.jax_utils import replicate\n",
+ "\n",
+ "params = replicate(params)\n",
+ "vqgan_params = replicate(vqgan_params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0A9AHQIgZ_qw"
+ },
+ "source": [
+ "Model functions are compiled and parallelized to take advantage of multiple devices."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "sOtoOmYsSYPz"
+ },
+ "outputs": [],
+ "source": [
+ "from functools import partial\n",
+ "\n",
+ "\n",
+ "# model inference\n",
+ "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
+ "def p_generate(\n",
+ " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
+ "):\n",
+ " return model.generate(\n",
+ " **tokenized_prompt,\n",
+ " prng_key=key,\n",
+ " params=params,\n",
+ " top_k=top_k,\n",
+ " top_p=top_p,\n",
+ " temperature=temperature,\n",
+ " condition_scale=condition_scale,\n",
+ " )\n",
+ "\n",
+ "\n",
+ "# decode image\n",
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
+ "def p_decode(indices, params):\n",
+ " return vqgan.decode_code(indices, params=params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HmVN6IBwapBA"
+ },
+ "source": [
+ "Keys are passed to the model on each device to generate unique inference per device."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4CTXmlUkThhX"
+ },
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "\n",
+ "# create a random key\n",
+ "seed = random.randint(0, 2**32 - 1)\n",
+ "key = jax.random.PRNGKey(seed)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BrnVyCo81pij"
+ },
+ "source": [
+ "## 🖍 Text Prompt"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rsmj0Aj5OQox"
+ },
+ "source": [
+ "Our model requires processing prompts."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "YjjhUychOVxm"
+ },
+ "outputs": [],
+ "source": [
+ "from dalle_mini import DalleBartProcessor\n",
+ "\n",
+ "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BQ7fymSPyvF_"
+ },
+ "source": [
+ "Let's define some text prompts."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "x_0vI9ge1oKr"
+ },
+ "outputs": [],
+ "source": [
+ "prompts = [\n",
+ " \"sunset over a lake in the mountains\",\n",
+ " \"the Eiffel tower landing on the moon\",\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XlZUG3SCLnGE"
+ },
+ "source": [
+ "Note: we could use the same prompt multiple times for faster inference."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VKjEZGjtO49k"
+ },
+ "outputs": [],
+ "source": [
+ "tokenized_prompts = processor(prompts)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-CEJBnuJOe5z"
+ },
+ "source": [
+ "Finally we replicate the prompts onto each device."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lQePgju5Oe5z"
+ },
+ "outputs": [],
+ "source": [
+ "tokenized_prompt = replicate(tokenized_prompts)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "phQ9bhjRkgAZ"
+ },
+ "source": [
+ "## 🎨 Generate images\n",
+ "\n",
+ "We generate images using dalle-mini model and decode them with the VQGAN."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "d0wVkXpKqnHA"
+ },
+ "outputs": [],
+ "source": [
+ "# number of predictions per prompt\n",
+ "n_predictions = 8\n",
+ "\n",
+ "# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)\n",
+ "gen_top_k = None\n",
+ "gen_top_p = None\n",
+ "temperature = None\n",
+ "cond_scale = 10.0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "SDjEx9JxR3v8"
+ },
+ "outputs": [],
+ "source": [
+ "from flax.training.common_utils import shard_prng_key\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "from tqdm.notebook import trange\n",
+ "\n",
+ "print(f\"Prompts: {prompts}\\n\")\n",
+ "# generate images\n",
+ "images = []\n",
+ "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
+ " # get a new key\n",
+ " key, subkey = jax.random.split(key)\n",
+ " # generate images\n",
+ " encoded_images = p_generate(\n",
+ " tokenized_prompt,\n",
+ " shard_prng_key(subkey),\n",
+ " params,\n",
+ " gen_top_k,\n",
+ " gen_top_p,\n",
+ " temperature,\n",
+ " cond_scale,\n",
+ " )\n",
+ " # remove BOS\n",
+ " encoded_images = encoded_images.sequences[..., 1:]\n",
+ " # decode images\n",
+ " decoded_images = p_decode(encoded_images, vqgan_params)\n",
+ " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
+ " for decoded_img in decoded_images:\n",
+ " img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n",
+ " images.append(img)\n",
+ " display(img)\n",
+ " print()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tw02wG9zGmyB"
+ },
+ "source": [
+ "## 🏅 Optional: Rank images by CLIP score\n",
+ "\n",
+ "We can rank images according to CLIP.\n",
+ "\n",
+ "**Note: your session may crash if you don't have a subscription to Colab Pro.**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "RGjlIW_f6GA0"
+ },
+ "outputs": [],
+ "source": [
+ "# CLIP model\n",
+ "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
+ "CLIP_COMMIT_ID = None\n",
+ "\n",
+ "# Load CLIP\n",
+ "clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
+ " CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
+ ")\n",
+ "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
+ "clip_params = replicate(clip_params)\n",
+ "\n",
+ "\n",
+ "# score images\n",
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
+ "def p_clip(inputs, params):\n",
+ " logits = clip(params=params, **inputs).logits_per_image\n",
+ " return logits"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FoLXpjCmGpju"
+ },
+ "outputs": [],
+ "source": [
+ "from flax.training.common_utils import shard\n",
+ "\n",
+ "# get clip scores\n",
+ "clip_inputs = clip_processor(\n",
+ " text=prompts * jax.device_count(),\n",
+ " images=images,\n",
+ " return_tensors=\"np\",\n",
+ " padding=\"max_length\",\n",
+ " max_length=77,\n",
+ " truncation=True,\n",
+ ").data\n",
+ "logits = p_clip(shard(clip_inputs), clip_params)\n",
+ "\n",
+ "# organize scores per prompt\n",
+ "p = len(prompts)\n",
+ "logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4AAWRm70LgED"
+ },
+ "source": [
+ "Let's now display images ranked by CLIP score."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zsgxxubLLkIu"
+ },
+ "outputs": [],
+ "source": [
+ "for i, prompt in enumerate(prompts):\n",
+ " print(f\"Prompt: {prompt}\\n\")\n",
+ " for idx in logits[i].argsort()[::-1]:\n",
+ " display(images[idx * p + i])\n",
+ " print(f\"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\\n\")\n",
+ " print()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "oZT9i3jCjir0"
+ },
+ "source": [
+ "## 🪄 Optional: Save your Generated Images as W&B Tables\n",
+ "\n",
+ "W&B Tables is an interactive 2D grid with support to rich media logging. Use this to save the generated images on W&B dashboard and share with the world."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-pSiv6Vwjkn0"
+ },
+ "outputs": [],
+ "source": [
+ "import wandb\n",
+ "\n",
+ "# Initialize a W&B run.\n",
+ "project = \"dalle-mini-tables-colab\"\n",
+ "run = wandb.init(project=project)\n",
+ "\n",
+ "# Initialize an empty W&B Tables.\n",
+ "columns = [\"captions\"] + [f\"image_{i+1}\" for i in range(n_predictions)]\n",
+ "gen_table = wandb.Table(columns=columns)\n",
+ "\n",
+ "# Add data to the table.\n",
+ "for i, prompt in enumerate(prompts):\n",
+ " # If CLIP scores exist, sort the Images\n",
+ " if logits is not None:\n",
+ " idxs = logits[i].argsort()[::-1]\n",
+ " tmp_imgs = images[i :: len(prompts)]\n",
+ " tmp_imgs = [tmp_imgs[idx] for idx in idxs]\n",
+ " else:\n",
+ " tmp_imgs = images[i :: len(prompts)]\n",
+ "\n",
+ " # Add the data to the table.\n",
+ " gen_table.add_data(prompt, *[wandb.Image(img) for img in tmp_imgs])\n",
+ "\n",
+ "# Log the Table to W&B dashboard.\n",
+ "wandb.log({\"Generated Images\": gen_table})\n",
+ "\n",
+ "# Close the W&B run.\n",
+ "run.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Ck2ZnHwVjnRd"
+ },
+ "source": [
+ "Click on the link above to check out your generated images."
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "machine_shape": "hm",
+ "name": "DALL·E mini - Inference pipeline.ipynb",
+ "provenance": [],
+ "gpuType": "A100",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
}
\ No newline at end of file