diff --git a/tensorflow/lite/g3doc/examples/keras/BUILD b/tensorflow/lite/g3doc/examples/keras/BUILD deleted file mode 100644 index a03fdf31d62966..00000000000000 --- a/tensorflow/lite/g3doc/examples/keras/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//third_party/py/tensorflow_docs/google:tf_org.bzl", "tf_org_check_links", "tf_org_notebook_strict_test") - -# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) - -licenses(["notice"]) - -tf_org_check_links(name = "check_links") - -tf_org_notebook_strict_test( - name = "jax_backend_resnet50_example", - size = "medium", - ipynb = "keras_jax_backend_to_tfl.ipynb", - tags = [ - "requires-net:external", - ], - deps = [ - "//third_party/py/PIL:pil", - "//third_party/py/keras", - "//third_party/py/numpy", - "//third_party/py/requests", - ], -) diff --git a/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb b/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb deleted file mode 100644 index ec811e81b68dd6..00000000000000 --- a/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb +++ /dev/null @@ -1,263 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "eB-dvsMI09O4" - }, - "source": [ - "##### Copyright 2024 The TensorFlow Authors." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "mwvC53CC1K3n" - }, - "outputs": [], - "source": [ - "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w61LHyw7v_yx" - }, - "source": [ - "Converting Keras to TFLite (via the JAX backend)\n", - "==========" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zQXWQ7y11eIR" - }, - "source": [ - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/examples/jax_conversion/jax_to_tflite_resnet50\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", - " \u003c/td\u003e\n", - "\u003c/table\u003e" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2-t7bCE0lGsH" - }, - "outputs": [], - "source": [ - "import os\n", - "\n", - "os.environ[\"KERAS_BACKEND\"] = \"jax\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SNUGBsILwSSs" - }, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Ht0UjgDxliW9" - }, - "outputs": [], - "source": [ - "import keras\n", - "import tensorflow as tf\n", - "import numpy as np" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kZhJOer0wXWP" - }, - "source": [ - "## Get the test image data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "FirUqiycez0X" - }, - "outputs": [], - "source": [ - "from PIL import Image\n", - "import requests\n", - "\n", - "url = \"https://storage.googleapis.com/download.tensorflow.org/example_images/astrid_l_shaped.jpg\"\n", - "image = Image.open(requests.get(url, stream=True).raw)\n", - "image = image.resize((224, 224))\n", - "input_image = np.array(image)\n", - "input_image = np.expand_dims(input_image, axis=0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CHMbP6ZWwcV_" - }, - "source": [ - "## Instatiate a Resnet50 model from the Keras models library" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "CbJyUj1IoqF6" - }, - "outputs": [], - "source": [ - "jax_model = keras.applications.resnet.ResNet50(include_top=True, weights=\"imagenet\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "B5yqkEEXwo13" - }, - "source": [ - "## Run the keras JAX model with the test input" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "IBHBetUqfhDA" - }, - "outputs": [], - "source": [ - "input_data = keras.applications.resnet50.preprocess_input(input_image)\n", - "jax_model_output = jax_model(input_data)\n", - "\n", - "decoded_preds = keras.applications.resnet.decode_predictions(jax_model_output, top=1)[\n", - " 0\n", - "][0]\n", - "print(\"Predicted class:\", decoded_preds[1])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HaUJGHGpw0KD" - }, - "source": [ - "## Save the Keras JAX model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "IZ4YqZLTrGc6" - }, - "outputs": [], - "source": [ - "saved_model_dir = \"resnet50_saved_model\"\n", - "jax_model.export(saved_model_dir)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bNSGGXfhw4uO" - }, - "source": [ - "## Convert to a TFLite model file" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "MdJz2eKqsEhA" - }, - "outputs": [], - "source": [ - "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n", - "tflite_model = converter.convert()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NZ4DjfGSWS7O" - }, - "source": [ - "## Run using TFLite Runtime" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "CtMSYAkwWWVm" - }, - "outputs": [], - "source": [ - "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n", - "interpreter.allocate_tensors()\n", - "\n", - "input_details = interpreter.get_input_details()[0]\n", - "interpreter.set_tensor(input_details[\"index\"], input_data)\n", - "interpreter.invoke()\n", - "\n", - "output_details = interpreter.get_output_details()\n", - "output_data = interpreter.get_tensor(output_details[0][\"index\"])\n", - "\n", - "tfl_predicted_class_idx = keras.applications.resnet.decode_predictions(\n", - " output_data, top=1\n", - ")[0][0]\n", - "print(\"Predicted class:\", tfl_predicted_class_idx[1])" - ] - } - ], - "metadata": { - "colab": { - "name": "keras_jax_backend_to_tfl.ipynb", - "provenance": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -}