-
Notifications
You must be signed in to change notification settings - Fork 7.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- fix figure rendering #2 - fix figure rendering - remove GIFs - fix figure names - revise readme - add default input configurations - improve documentations and code quanlity - add ipynb link for generating tflite models - add tflite generative ai example PiperOrigin-RevId: 530782147
- Loading branch information
1 parent
ef4dc5c
commit eb5c721
Showing
61 changed files
with
3,694 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
# Generative AI | ||
|
||
## Introduction | ||
Large language models (LLMs) are types of machine learning models that are created based on large bodies of text data to generate various outputs for natural language processing (NLP) tasks, including text generation, question answering, and machine translation. They are based on Transformer architecture and are trained on massive amounts of text data, often involving billions of words. Even LLMs of a smaller scale, such as GPT-2, can perform impressively. Converting TensorFlow models to a lighter, faster, and low-power model allows for us to run generative AI models on-device, with benefits of better user security because data will never leave your device. | ||
|
||
This example shows you how to build an Android app with TensorFlow Lite to run a Keras LLM and provides suggestions for model optimization using quantizing techniques, which otherwise would require a much larger amount of memory and greater computational power to run. | ||
|
||
This example open sourced an Android app framework that any compatible TFLite LLMs can plug into. Here are two demos: | ||
* In Figure 1, we used a Keras GPT-2 model to perform text completion tasks on device. | ||
* In Figure 2, we converted a version of instruction-tuned [PaLM model](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html) (1.5 billion parameters) to TFLite and executed through TFLite runtime. | ||
|
||
<p align="center"> | ||
<img src="figures/fig1.gif" width="300"> | ||
</p> | ||
Figure 1: Example of running the Keras GPT-2 model (converted from this Codelab) on device to perform text completion on Pixel 7. Demo shows the real latency with no speedup. | ||
<p align="center"> | ||
<img src="figures/fig2.gif" width="300"> | ||
</p> | ||
Figure 2: Example of running a version of [PaLM model](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html) with 1.5 billion parameters. Demo is recorded on Pixel 7 Pro without playback speedup. | ||
|
||
|
||
## Guides | ||
### Step 1. Train a language model using Keras | ||
|
||
For this demonstration, we will use KerasNLP to get the GPT-2 model. KerasNLP is a library that contains state-of-the-art pretrained models for natural language processing tasks, and can support users through their entire development cycle. You can see the list of models available in the [KerasNLP repository](https://github.com/keras-team/keras-nlp/tree/master/keras_nlp/models). The workflows are built from modular components that have state-of-the-art preset weights and architectures when used out-of-the-box and are easily customizable when more control is needed. Creating the GPT-2 model can be done with the following steps: | ||
|
||
```python | ||
gpt2_tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset("gpt2_base_en") | ||
|
||
gpt2_preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset( | ||
"gpt2_base_en", | ||
sequence_length=256, | ||
add_end_token=True, | ||
) | ||
|
||
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset( | ||
"gpt2_base_en", | ||
preprocessor=gpt2_preprocessor, | ||
) | ||
``` | ||
|
||
You can check out the full GPT-2 model implementation [on GitHub](https://github.com/keras-team/keras-nlp/tree/master/keras_nlp/models/gpt2). | ||
|
||
|
||
### Step 2. Convert a Keras model to a TFLite model | ||
|
||
Start with the `generate()` function from GPT2CausalLM that performs the conversion. Wrap the `generate()` function to create a concrete TensorFlow function: | ||
|
||
```python | ||
@tf.function | ||
def generate(prompt, max_length): | ||
# prompt: input prompt to the LLM in string format | ||
# max_length: the max length of the generated tokens | ||
return gpt2_lm.generate(prompt, max_length) | ||
concrete_func = generate.get_concrete_function(tf.TensorSpec([], tf.string), 100) | ||
``` | ||
|
||
Now define a helper function that will run inference with an input and a TFLite model. TensorFlow text ops are not built-in ops in the TFLite runtime, so you will need to add these custom ops in order for the interpreter to make inference on this model. This helper function accepts an input and a function that performs the conversion, namely the `generator()` function defined above. | ||
|
||
```python | ||
def run_inference(input, generate_tflite): | ||
interp = interpreter.InterpreterWithCustomOps( | ||
model_content=generate_tflite, | ||
custom_op_registerers=tf_text.tflite_registrar.SELECT_TFTEXT_OPS) | ||
interp.get_signature_list() | ||
|
||
generator = interp.get_signature_runner('serving_default') | ||
output = generator(prompt=np.array([input])) | ||
``` | ||
|
||
You can convert the model now: | ||
|
||
```python | ||
gpt2_lm.jit_compile = False | ||
converter = tf.lite.TFLiteConverter.from_concrete_functions( | ||
[concrete_func], | ||
gpt2_lm) | ||
|
||
converter.target_spec.supported_ops = [ | ||
tf.lite.OpsSet.TFLITE_BUILTINS, # enable TFLite ops | ||
tf.lite.OpsSet.SELECT_TF_OPS, # enable TF ops | ||
] | ||
converter.allow_custom_ops = True | ||
converter.target_spec.experimental_select_user_tf_ops = [ | ||
"UnsortedSegmentJoin", | ||
"UpperBound" | ||
] | ||
converter._experimental_guarantee_all_funcs_one_use = True | ||
generate_tflite = converter.convert() | ||
run_inference("I'm enjoying a", generate_tflite) | ||
``` | ||
|
||
### Step 3. Quantization | ||
TensorFlow Lite has implemented an optimization technique called quantization which can reduce model size and accelerate inference. Through the quantization process, 32-bit floats are mapped to smaller 8-bit integers, therefore reducing the model size by a factor of 4 for more efficient execution on modern hardwares. There are several ways to do quantization in TensorFlow. You can visit the [TFLite Model optimization](https://www.tensorflow.org/lite/performance/model_optimization) and [TensorFlow Model Optimization Toolkit](https://www.tensorflow.org/model_optimization) pages for more information. The types of quantizations are explained briefly below. | ||
|
||
Here, you will use the post-training dynamic range quantization on the GPT-2 model by setting the converter optimization flag to tf.lite.Optimize.DEFAULT, and the rest of the conversion process is the same as detailed before. We tested that with this quantization technique the latency is around 6.7 seconds on Pixel 7 with max output length set to 100. | ||
|
||
```python | ||
gpt2_lm.jit_compile = False | ||
converter = tf.lite.TFLiteConverter.from_concrete_functions( | ||
[concrete_func], | ||
gpt2_lm) | ||
|
||
converter.target_spec.supported_ops = [ | ||
tf.lite.OpsSet.TFLITE_BUILTINS, # enable TFLite ops | ||
tf.lite.OpsSet.SELECT_TF_OPS, # enable TF ops | ||
] | ||
converter.allow_custom_ops = True | ||
converter.optimizations = [tf.lite.Optimize.DEFAULT] | ||
converter.target_spec.experimental_select_user_tf_ops = [ | ||
"UnsortedSegmentJoin", | ||
"UpperBound" | ||
] | ||
converter._experimental_guarantee_all_funcs_one_use = True | ||
quant_generate_tflite = converter.convert() | ||
run_inference("I'm enjoying a", quant_generate_tflite) | ||
|
||
with open('quantized_gpt2.tflite', 'wb') as f: | ||
f.write(quant_generate_tflite) | ||
``` | ||
|
||
|
||
|
||
### Step 4. Android App integration | ||
|
||
You can clone this repo and substitute `android/app/src/main/assets/autocomplete.tflite` with your converted `quant_generate_tflite` file. Please refer to [how-to-build.md](https://github.com/tensorflow/examples/blob/master/lite/examples/generative_ai/android/how-to-build.md) to build this Android App. | ||
|
||
## Safety and Responsible AI | ||
As noted in the original [OpenAI GPT-2 announcement](https://openai.com/research/better-language-models), there are [notable caveats and limitations](https://github.com/openai/gpt-2#some-caveats) with the GPT-2 model. In fact, LLMs today generally have some well-known challenges such as hallucinations, fairness, and bias; this is because these models are trained on real-world data, which make them reflect real world issues. | ||
This codelab is created only to demonstrate how to create an app powered by LLMs with TensorFlow tooling. The model produced in this codelab is for educational purposes only and not intended for production usage. | ||
LLM production usage requires thoughtful selection of training datasets and comprehensive safety mitigations. One such functionality offered in this Android app is the profanity filter, which rejects bad user inputs or model outputs. If any inappropriate language is detected, the app will in return reject that action. To learn more about Responsible AI in the context of LLMs, make sure to watch the Safe and Responsible Development with Generative Language Models technical session at Google I/O 2023 and check out the [Responsible AI Toolkit](https://www.tensorflow.org/responsible_ai). |
102 changes: 102 additions & 0 deletions
102
lite/examples/generative_ai/android/app/build.gradle.kts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
@file:Suppress("UnstableApiUsage") | ||
|
||
plugins { | ||
kotlin("android") | ||
id("com.android.application") | ||
id("de.undercouch.download") | ||
} | ||
|
||
ext { | ||
set("AAR_URL", "https://storage.googleapis.com/download.tensorflow.org/models/tflite/generativeai/tensorflow-lite-select-tf-ops.aar") | ||
set("AAR_PATH", "$projectDir/libs/tensorflow-lite-select-tf-ops.aar") | ||
} | ||
|
||
apply { | ||
from("download.gradle") | ||
} | ||
|
||
android { | ||
namespace = "com.google.tensorflowdemo" | ||
compileSdk = 33 | ||
|
||
defaultConfig { | ||
applicationId = "com.google.tensorflowdemo" | ||
minSdk = 24 | ||
targetSdk = 33 | ||
versionCode = 1 | ||
versionName = "1.0" | ||
} | ||
buildFeatures { | ||
compose = true | ||
buildConfig = true | ||
viewBinding = true | ||
} | ||
composeOptions { | ||
kotlinCompilerExtensionVersion = "1.3.2" | ||
} | ||
packagingOptions { | ||
resources { | ||
excludes += "/META-INF/{AL2.0,LGPL2.1}" | ||
} | ||
} | ||
buildTypes { | ||
getByName("release") { | ||
isMinifyEnabled = true | ||
proguardFiles(getDefaultProguardFile("proguard-android-optimize.txt"), "proguard-rules.pro") | ||
isDebuggable = false | ||
} | ||
getByName("debug") { | ||
applicationIdSuffix = ".debug" | ||
} | ||
} | ||
compileOptions { | ||
sourceCompatibility = JavaVersion.VERSION_1_8 | ||
targetCompatibility = JavaVersion.VERSION_1_8 | ||
} | ||
kotlinOptions { | ||
jvmTarget = "1.8" | ||
freeCompilerArgs = listOf( | ||
"-P", | ||
"plugin:androidx.compose.compiler.plugins.kotlin:suppressKotlinVersionCompatibilityCheck=1.8.10" | ||
) | ||
} | ||
} | ||
|
||
dependencies { | ||
implementation(fileTree(mapOf("dir" to "libs", "include" to listOf("*.aar")))) | ||
|
||
// Compose | ||
implementation(libraries.compose.ui) | ||
implementation(libraries.compose.ui.tooling) | ||
implementation(libraries.compose.ui.tooling.preview) | ||
implementation(libraries.compose.foundation) | ||
implementation(libraries.compose.material) | ||
implementation(libraries.compose.material.icons) | ||
implementation(libraries.compose.activity) | ||
|
||
// Accompanist for Compose | ||
implementation(libraries.accompanist.systemuicontroller) | ||
|
||
// Koin | ||
implementation(libraries.koin.core) | ||
implementation(libraries.koin.android) | ||
implementation(libraries.koin.compose) | ||
|
||
// Lifecycle | ||
implementation(libraries.lifecycle.viewmodel) | ||
implementation(libraries.lifecycle.viewmodel.compose) | ||
implementation(libraries.lifecycle.viewmodel.ktx) | ||
implementation(libraries.lifecycle.runtime.compose) | ||
|
||
// Logging | ||
implementation(libraries.napier) | ||
|
||
// Profanity filter | ||
implementation(libraries.wordfilter) | ||
|
||
// TensorFlow Lite | ||
implementation(libraries.tflite) | ||
|
||
// Unit tests | ||
testImplementation(libraries.junit) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
task downloadAAR { | ||
download { | ||
src project.ext.AAR_URL | ||
dest project.ext.AAR_PATH | ||
overwrite false | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
tensorflow-lite-select-tf-ops.aar |
16 changes: 16 additions & 0 deletions
16
lite/examples/generative_ai/android/app/libs/build_aar/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Build your own aar | ||
|
||
By default the app automatically downloads the needed aar files. But if you want | ||
to build your own, just go ahead and run `./build_aar.sh`. This script will pull | ||
in the necessary ops from [TensorFlow Text](https://www.tensorflow.org/text) and | ||
build the aar for [Select TF operators](https://www.tensorflow.org/lite/guide/ops_select). | ||
|
||
After compilation, a new file `tftext_tflite_flex.aar` is generated. Replace the | ||
one in app/libs/ folder and re-build the app. | ||
|
||
By default, the script builds only for `android_x86_64`. You can change it to | ||
`android_x86`, `android_arm` or `android_arm64`. | ||
|
||
Note that you still need to include the standard `tensorflow-lite` aar in your | ||
gradle file. | ||
|
28 changes: 28 additions & 0 deletions
28
lite/examples/generative_ai/android/app/libs/build_aar/build_aar.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#! /bin/bash | ||
|
||
set -e | ||
|
||
# Clone TensorFlow Text repo | ||
git clone https://github.com/tensorflow/text.git tensorflow_text | ||
|
||
cd tensorflow_text/ | ||
echo 'exports_files(["LICENSE"])' > BUILD | ||
|
||
# Checkout 2.12 branch | ||
git checkout 2.12 | ||
|
||
# Apply tftext-2.12.patch | ||
git apply ../tftext-2.12.patch | ||
|
||
# Run config | ||
./oss_scripts/configure.sh | ||
|
||
# Run bazel build | ||
bazel build -c opt --cxxopt='--std=c++14' --config=monolithic --config=android_x86_64 --experimental_repo_remote_exec //tensorflow_text:tftext_tflite_flex | ||
|
||
if [ $? -eq 0 ]; then | ||
# Print a message | ||
echo "Please find the aar file: tensorflow_text/bazel-bin/tensorflow_text/tftext_tflite_flex.aar" | ||
else | ||
echo "build_aar.sh has failed. Please find the error message above and address it before proceeding." | ||
fi |
61 changes: 61 additions & 0 deletions
61
lite/examples/generative_ai/android/app/libs/build_aar/tftext-2.12.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
diff --git a/WORKSPACE b/WORKSPACE | ||
index 28b7ee5..5ad0b55 100644 | ||
--- a/WORKSPACE | ||
+++ b/WORKSPACE | ||
@@ -116,3 +116,10 @@ load("@org_tensorflow//third_party/android:android_configure.bzl", "android_conf | ||
android_configure(name="local_config_android") | ||
load("@local_config_android//:android.bzl", "android_workspace") | ||
android_workspace() | ||
+ | ||
+android_sdk_repository(name = "androidsdk") | ||
+ | ||
+android_ndk_repository( | ||
+ name = "androidndk", | ||
+ api_level = 21, | ||
+) | ||
diff --git a/tensorflow_text/BUILD b/tensorflow_text/BUILD | ||
index 9b5ee5b..880c7c5 100644 | ||
--- a/tensorflow_text/BUILD | ||
+++ b/tensorflow_text/BUILD | ||
@@ -2,6 +2,8 @@ load("//tensorflow_text:tftext.bzl", "py_tf_text_library") | ||
|
||
# [internal] load build_test.bzl | ||
load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_cc_shared_object") | ||
+load("@org_tensorflow//tensorflow/lite/delegates/flex:build_def.bzl", "tflite_flex_android_library") | ||
+load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") | ||
|
||
# Visibility rules | ||
package( | ||
@@ -61,6 +63,20 @@ tflite_cc_shared_object( | ||
deps = [":ops_lib"], | ||
) | ||
|
||
+tflite_flex_android_library( | ||
+ name = "tftext_ops", | ||
+ additional_deps = [ | ||
+ "@org_tensorflow//tensorflow/lite/delegates/flex:delegate", | ||
+ ":ops_lib", | ||
+ ], | ||
+ visibility = ["//visibility:public"], | ||
+) | ||
+ | ||
+aar_with_jni( | ||
+ name = "tftext_tflite_flex", | ||
+ android_library = ":tftext_ops", | ||
+) | ||
+ | ||
py_library( | ||
name = "ops", | ||
srcs = [ | ||
diff --git a/tensorflow_text/tftext.bzl b/tensorflow_text/tftext.bzl | ||
index 65430ca..04f68d8 100644 | ||
--- a/tensorflow_text/tftext.bzl | ||
+++ b/tensorflow_text/tftext.bzl | ||
@@ -140,6 +140,7 @@ def tf_cc_library( | ||
deps += select({ | ||
"@org_tensorflow//tensorflow:mobile": [ | ||
"@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite", | ||
+ "@org_tensorflow//tensorflow/lite/kernels/shim:tf_op_shim", | ||
], | ||
"//conditions:default": [ | ||
"@local_config_tf//:libtensorflow_framework", |
21 changes: 21 additions & 0 deletions
21
lite/examples/generative_ai/android/app/proguard-rules.pro
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Add project specific ProGuard rules here. | ||
# You can control the set of applied configuration files using the | ||
# proguardFiles setting in build.gradle. | ||
# | ||
# For more details, see | ||
# http://developer.android.com/guide/developing/tools/proguard.html | ||
|
||
# If your project uses WebView with JS, uncomment the following | ||
# and specify the fully qualified class name to the JavaScript interface | ||
# class: | ||
#-keepclassmembers class fqcn.of.javascript.interface.for.webview { | ||
# public *; | ||
#} | ||
|
||
# Uncomment this to preserve the line number information for | ||
# debugging stack traces. | ||
#-keepattributes SourceFile,LineNumberTable | ||
|
||
# If you keep the line number information, uncomment this to | ||
# hide the original source file name. | ||
#-renamesourcefileattribute SourceFile |
Oops, something went wrong.