[go: nahoru, domu]

Skip to content

Commit

Permalink
Add an example to demonstrate Keras(with JAX backend) to TFLite.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631821076
  • Loading branch information
tensorflower-gardener committed May 8, 2024
1 parent 1b70cc4 commit 604adec
Show file tree
Hide file tree
Showing 2 changed files with 285 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tensorflow/lite/g3doc/examples/keras/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
load("//third_party/py/tensorflow_docs/google:tf_org.bzl", "tf_org_check_links", "tf_org_notebook_strict_test")

# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])

licenses(["notice"])

# NOTE(review): presumably validates the hyperlinks used by the docs in this
# package; the macro is defined in tf_org.bzl -- confirm there.
tf_org_check_links(name = "check_links")

# Notebook test target for keras_jax_backend_to_tfl.ipynb.
tf_org_notebook_strict_test(
    name = "jax_backend_resnet50_example",
    size = "medium",
    ipynb = "keras_jax_backend_to_tfl.ipynb",
    # The notebook fetches a test image over HTTP and loads ImageNet weights.
    tags = [
        "requires-net:external",
    ],
    # Python packages imported by the notebook.
    # NOTE(review): the notebook also imports tensorflow, which is not listed
    # here; presumably the test macro provides it -- confirm.
    deps = [
        "//third_party/py/PIL:pil",
        "//third_party/py/keras",
        "//third_party/py/numpy",
        "//third_party/py/requests",
    ],
)
263 changes: 263 additions & 0 deletions tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "eB-dvsMI09O4"
},
"source": [
"##### Copyright 2024 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "mwvC53CC1K3n"
},
"outputs": [],
"source": [
"# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w61LHyw7v_yx"
},
"source": [
"Converting Keras to TFLite (via the JAX backend)\n",
"=========="
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zQXWQ7y11eIR"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/examples/keras/keras_jax_backend_to_tfl\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2-t7bCE0lGsH"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Select the JAX backend for Keras. This must be set before `keras` is\n",
"# imported (in the Setup cell below), because the backend is chosen at import.\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SNUGBsILwSSs"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ht0UjgDxliW9"
},
"outputs": [],
"source": [
"import keras\n",
"import tensorflow as tf\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kZhJOer0wXWP"
},
"source": [
"## Get the test image data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FirUqiycez0X"
},
"outputs": [],
"source": [
"from io import BytesIO\n",
"\n",
"from PIL import Image\n",
"import requests\n",
"\n",
"url = \"https://storage.googleapis.com/download.tensorflow.org/example_images/astrid_l_shaped.jpg\"\n",
"\n",
"# Fail fast on network problems: a timeout avoids hanging forever, and\n",
"# raise_for_status() stops an HTTP error page from being handed to PIL.\n",
"response = requests.get(url, timeout=30)\n",
"response.raise_for_status()\n",
"\n",
"# Force three colour channels so the array is always (224, 224, 3),\n",
"# whatever mode the downloaded file uses.\n",
"image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
"image = image.resize((224, 224))\n",
"\n",
"# Add a leading batch axis so the model receives a batch of one image.\n",
"input_image = np.expand_dims(np.array(image), axis=0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CHMbP6ZWwcV_"
},
"source": [
"## Instantiate a ResNet50 model from the Keras models library"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CbJyUj1IoqF6"
},
"outputs": [],
"source": [
"# ResNet50 including its classification head, initialised with the\n",
"# pretrained ImageNet weights.\n",
"jax_model = keras.applications.resnet.ResNet50(\n",
"    include_top=True,\n",
"    weights=\"imagenet\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B5yqkEEXwo13"
},
"source": [
"## Run the Keras JAX model with the test input"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IBHBetUqfhDA"
},
"outputs": [],
"source": [
"# Apply the ResNet50 input preprocessing, then call the model directly.\n",
"input_data = keras.applications.resnet50.preprocess_input(input_image)\n",
"jax_model_output = jax_model(input_data)\n",
"\n",
"# Keep only the best guess for the first (and only) image in the batch.\n",
"top1_per_image = keras.applications.resnet.decode_predictions(jax_model_output, top=1)\n",
"decoded_preds = top1_per_image[0][0]\n",
"\n",
"# Entry 1 of the decoded prediction is the value reported as the class.\n",
"print(\"Predicted class:\", decoded_preds[1])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HaUJGHGpw0KD"
},
"source": [
"## Save the Keras JAX model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IZ4YqZLTrGc6"
},
"outputs": [],
"source": [
"# Export the model to this directory; the converter in the next section\n",
"# reads it back with `from_saved_model`.\n",
"saved_model_dir = \"resnet50_saved_model\"\n",
"jax_model.export(saved_model_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bNSGGXfhw4uO"
},
"source": [
"## Convert to a TFLite model file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MdJz2eKqsEhA"
},
"outputs": [],
"source": [
"# Build a converter from the exported SavedModel directory and run the\n",
"# conversion. The result is handed to the interpreter below as\n",
"# `model_content`.\n",
"converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
"tflite_model = converter.convert()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NZ4DjfGSWS7O"
},
"source": [
"## Run using TFLite Runtime"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CtMSYAkwWWVm"
},
"outputs": [],
"source": [
"# Load the converted model from memory and allocate its tensors.\n",
"interpreter = tf.lite.Interpreter(model_content=tflite_model)\n",
"interpreter.allocate_tensors()\n",
"\n",
"# `set_tensor` raises if the array dtype differs from the model input dtype,\n",
"# so cast explicitly instead of relying on the preprocessing output dtype.\n",
"input_details = interpreter.get_input_details()[0]\n",
"interpreter.set_tensor(\n",
"    input_details[\"index\"],\n",
"    np.asarray(input_data, dtype=input_details[\"dtype\"]),\n",
")\n",
"interpreter.invoke()\n",
"\n",
"output_details = interpreter.get_output_details()\n",
"output_data = interpreter.get_tensor(output_details[0][\"index\"])\n",
"\n",
"# Decode the same way as for the Keras/JAX output above so that the two\n",
"# printed classes can be compared directly.\n",
"tfl_decoded_preds = keras.applications.resnet.decode_predictions(\n",
"    output_data, top=1\n",
")[0][0]\n",
"print(\"Predicted class:\", tfl_decoded_preds[1])"
]
}
],
"metadata": {
"colab": {
"name": "keras_jax_backend_to_tfl.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

0 comments on commit 604adec

Please sign in to comment.