/* * Copyright 2017 The Android Things Samples Authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.example.androidthings.imageclassifier.classifier; import android.content.Context; import android.content.res.AssetManager; import android.graphics.Bitmap; import android.util.Log; import org.tensorflow.contrib.android.TensorFlowInferenceInterface; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.PriorityQueue; /** * A classifier specialized to label images using TensorFlow. */ public class TensorFlowImageClassifier implements Classifier { private static final String TAG = "TFImageClassifier"; // These are the settings for the original v1 Inception model. If you want to // use a model that's been produced from the TensorFlow for Poets codelab, // you'll need to set IMAGE_SIZE = 299, IMAGE_MEAN = 128, IMAGE_STD = 128, // INPUT_NAME = "Mul:0", and OUTPUT_NAME = "final_result:0". // You'll also need to update the MODEL_FILE and LABEL_FILE paths to point to // the ones you produced. public static final int INPUT_SIZE = 224; // Note: the actual number of classes for Inception is 1001, but the output layer size is 1008. private static final int NUM_CLASSES = 1008; private static final int IMAGE_MEAN = 117; private static final float IMAGE_STD = 1; private static final String INPUT_NAME = "input:0"; private static final String OUTPUT_NAME = "output:0"; private static final String MODEL_FILE = "file:///android_asset/tensorflow_inception_graph.pb"; private static final String LABEL_FILE = "file:///android_asset/imagenet_comp_graph_label_strings.txt"; // Only return this many results with at least this confidence. private static final int MAX_RESULTS = 3; private static final float THRESHOLD = 0.1f; // Config values. private String inputName; private String outputName; private int inputSize; private int imageMean; private float imageStd; // Pre-allocated buffers. private List<String> labels; private float[] floatValues; private float[] outputs; private String[] outputNames; private int[] intValues; private TensorFlowInferenceInterface inferenceInterface; /** * Initializes a native TensorFlow session for classifying images. * * @param context The context from which to get the asset manager to be used to load assets. */ public TensorFlowImageClassifier(Context context) { this(context.getAssets(), MODEL_FILE, LABEL_FILE, NUM_CLASSES, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD, INPUT_NAME, OUTPUT_NAME); } /** * Initializes a native TensorFlow session for classifying images. * * @param assetManager The asset manager to be used to load assets. * @param modelFilename The filepath of the model GraphDef protocol buffer. * @param labelFilename The filepath of label file for classes. * @param numClasses The number of classes output by the model. * @param inputSize The input size. A square image of inputSize x inputSize is assumed. * @param imageMean The assumed mean of the image values. * @param imageStd The assumed std of the image values. * @param inputName The label of the image input node. * @param outputName The label of the output node. */ public TensorFlowImageClassifier(AssetManager assetManager, String modelFilename, String labelFilename, int numClasses, int inputSize, int imageMean, float imageStd, String inputName, String outputName) { this.inputName = inputName; this.outputName = outputName; // Read the label names into memory. String actualFilename = labelFilename.split("file:///android_asset/")[1]; this.labels = readLabels(assetManager, actualFilename); Log.i(TAG, "Read " + labels.size() + ", " + numClasses + " specified"); this.inputSize = inputSize; this.imageMean = imageMean; this.imageStd = imageStd; // Pre-allocate buffers. this.outputNames = new String[]{outputName}; this.floatValues = new float[inputSize * inputSize * 3]; this.outputs = new float[numClasses]; this.intValues = new int[inputSize * inputSize]; this.inferenceInterface = new TensorFlowInferenceInterface(); this.inferenceInterface.initializeTensorFlow(assetManager, modelFilename); } private ArrayList<String> readLabels(AssetManager assetManager, String filename) { ArrayList<String> result = new ArrayList<>(); try { BufferedReader br = null; br = new BufferedReader(new InputStreamReader(assetManager.open(filename))); String line; while ((line = br.readLine()) != null) { result.add(line); } br.close(); } catch (IOException ex) { throw new IllegalStateException("Cannot read labels from " + filename); } return result; } @Override public List<Recognition> recognizeImage(final Bitmap bitmap) { bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); // Preprocess the image data from 0-255 int to normalized float based // on the provided parameters. for (int i = 0; i < intValues.length; ++i) { final int val = intValues[i]; floatValues[i * 3] = (((val >> 16) & 0xFF) - imageMean) / imageStd; floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd; floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd; } // Copy the input data into TensorFlow. inferenceInterface.fillNodeFloat( inputName, new int[]{1, inputSize, inputSize, 3}, floatValues); // Run the inference call. inferenceInterface.runInference(outputNames); // Copy the output Tensor back into the output array. inferenceInterface.readNodeFloat(outputName, outputs); // Find the best classifications. PriorityQueue<Recognition> pq = new PriorityQueue<Recognition>(3, new Comparator<Recognition>() { @Override public int compare(Recognition lhs, Recognition rhs) { // Intentionally reversed to put high confidence at the head of the queue. return Float.compare(rhs.getConfidence(), lhs.getConfidence()); } }); for (int i = 0; i < outputs.length; ++i) { if (outputs[i] > THRESHOLD) { pq.add(new Recognition("" + i, labels.get(i), outputs[i], null)); } } ArrayList<Recognition> recognitions = new ArrayList<Recognition>(); int recognitionsSize = Math.min(pq.size(), MAX_RESULTS); for (int i = 0; i < recognitionsSize; ++i) { recognitions.add(pq.poll()); } return recognitions; } @Override public void close() { inferenceInterface.close(); } }