TensorFlow的训练模型在Android和Java的应用及调用
来源:互联网 发布:化妆品 知乎 编辑:程序博客网 时间:2024/05/29 16:00
环境:Windows 7
当我们开始学习编程的时候,第一件事往往是学习打印"Hello World"。就好比编程入门有Hello World,机器学习入门有MNIST。
MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片:
它也包含每一张图片对应的标签,告诉我们这个是数字几。比如,上面这四张图片的标签分别是5,0,4,1。
那我我们就将TensorFlow里的一个训练后的模型数据集,在Android里实现调用使用。
Tensorflow训练模型通常使用Python api编写,训练模型保存为二进制pb文件,内含数据集。
https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip 这个是google给出的一个图像识别的训练模型集,供测试。
里面有2个文件:
第一个txt文件展示了这个pb训练模型可以识别的东西有哪些。
第二个pb文件为训练模型数据集,有51.3M大小。
那么我们接下来就是在android或Java里调用API使用他这个训练模型,实现图像识别功能。
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android 这个是TensorFlow官方的Demo源码。
Android想要使用要编译so,毕竟是跨平台调用。
jni在官方Demo里也附带了。
Android和TensorFlow调用API的aar库可以在gradle里引用:
compile 'org.tensorflow:tensorflow-android:+'
基本结构:
基本API调用训练模型如下代码类似:
TensorFlowInferenceInterface tfi = new TensorFlowInferenceInterface("F:/tf_mode/output_graph.pb","imageType"); final Operation operation = tfi.graphOperation("y_conv_add"); Output output = operation.output(0); Shape shape = output.shape(); final int numClasses = (int) shape.size(1);
主要的类就是TensorFlowInferenceInterface 、Operation。
那么接下来把官方Demo的这个类调用给出:
他这个是Android的Assets目录读取训练模型, 从
c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
这句可以看出。
那么我们可以根据实际训练模型pb文件的位置进行修改引用。
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.Licensed under the Apache License, Version 2.0 (the "License");you may not use this file except in compliance with the License.You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0Unless required by applicable law or agreed to in writing, softwaredistributed under the License is distributed on an "AS IS" BASIS,WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.See the License for the specific language governing permissions andlimitations under the License.==============================================================================*/package org.tensorflow.demo;import android.content.res.AssetManager;import android.graphics.Bitmap;import android.os.Trace;import android.util.Log;import java.io.BufferedReader;import java.io.IOException;import java.io.InputStreamReader;import java.util.ArrayList;import java.util.Comparator;import java.util.List;import java.util.PriorityQueue;import java.util.Vector;import org.tensorflow.Operation;import org.tensorflow.contrib.android.TensorFlowInferenceInterface;/** A classifier specialized to label images using TensorFlow. */public class TensorFlowImageClassifier implements Classifier { private static final String TAG = "TensorFlowImageClassifier"; // Only return this many results with at least this confidence. private static final int MAX_RESULTS = 3; private static final float THRESHOLD = 0.1f; // Config values. private String inputName; private String outputName; private int inputSize; private int imageMean; private float imageStd; // Pre-allocated buffers. private Vector<String> labels = new Vector<String>(); private int[] intValues; private float[] floatValues; private float[] outputs; private String[] outputNames; private boolean logStats = false; private TensorFlowInferenceInterface inferenceInterface; private TensorFlowImageClassifier() {} /** * Initializes a native TensorFlow session for classifying images. * * @param assetManager The asset manager to be used to load assets. * @param modelFilename The filepath of the model GraphDef protocol buffer. * @param labelFilename The filepath of label file for classes. * @param inputSize The input size. A square image of inputSize x inputSize is assumed. * @param imageMean The assumed mean of the image values. * @param imageStd The assumed std of the image values. * @param inputName The label of the image input node. * @param outputName The label of the output node. * @throws IOException */ public static Classifier create( AssetManager assetManager, String modelFilename, String labelFilename, int inputSize, int imageMean, float imageStd, String inputName, String outputName) { TensorFlowImageClassifier c = new TensorFlowImageClassifier(); c.inputName = inputName; c.outputName = outputName; // Read the label names into memory. // TODO(andrewharp): make this handle non-assets. String actualFilename = labelFilename.split("file:///android_asset/")[1]; Log.i(TAG, "Reading labels from: " + actualFilename); BufferedReader br = null; try { br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename))); String line; while ((line = br.readLine()) != null) { c.labels.add(line); } br.close(); } catch (IOException e) { throw new RuntimeException("Problem reading label file!" , e); } c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename); // The shape of the output is [N, NUM_CLASSES], where N is the batch size. final Operation operation = c.inferenceInterface.graphOperation(outputName); final int numClasses = (int) operation.output(0).shape().size(1); Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses); // Ideally, inputSize could have been retrieved from the shape of the input operation. Alas, // the placeholder node for input in the graphdef typically used does not specify a shape, so it // must be passed in as a parameter. c.inputSize = inputSize; c.imageMean = imageMean; c.imageStd = imageStd; // Pre-allocate buffers. c.outputNames = new String[] {outputName}; c.intValues = new int[inputSize * inputSize]; c.floatValues = new float[inputSize * inputSize * 3]; c.outputs = new float[numClasses]; return c; } @Override public List<Recognition> recognizeImage(final Bitmap bitmap) { // Log this method so that it can be analyzed with systrace. Trace.beginSection("recognizeImage"); Trace.beginSection("preprocessBitmap"); // Preprocess the image data from 0-255 int to normalized float based // on the provided parameters. bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); for (int i = 0; i < intValues.length; ++i) { final int val = intValues[i]; floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd; floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd; floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd; } Trace.endSection(); // Copy the input data into TensorFlow. Trace.beginSection("feed"); inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); Trace.endSection(); // Run the inference call. Trace.beginSection("run"); inferenceInterface.run(outputNames, logStats); Trace.endSection(); // Copy the output Tensor back into the output array. Trace.beginSection("fetch"); inferenceInterface.fetch(outputName, outputs); Trace.endSection(); // Find the best classifications. PriorityQueue<Recognition> pq = new PriorityQueue<Recognition>( 3, new Comparator<Recognition>() { @Override public int compare(Recognition lhs, Recognition rhs) { // Intentionally reversed to put high confidence at the head of the queue. return Float.compare(rhs.getConfidence(), lhs.getConfidence()); } }); for (int i = 0; i < outputs.length; ++i) { if (outputs[i] > THRESHOLD) { pq.add( new Recognition( "" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null)); } } final ArrayList<Recognition> recognitions = new ArrayList<Recognition>(); int recognitionsSize = Math.min(pq.size(), MAX_RESULTS); for (int i = 0; i < recognitionsSize; ++i) { recognitions.add(pq.poll()); } Trace.endSection(); // "recognizeImage" return recognitions; } @Override public void enableStatLogging(boolean logStats) { this.logStats = logStats; } @Override public String getStatString() { return inferenceInterface.getStatString(); } @Override public void close() { inferenceInterface.close(); }}
新版本的api改了下,那我给出旧版本的Android Studio版本的Demo。
https://github.com/Nilhcem/tensorflow-classifier-android
这个是国外的一个开发者编译好so库的一个旧的Demo调用版本。大家可以参考下,和新版使用方法大同小异。
- TensorFlow的训练模型在Android和Java的应用及调用
- TensorFlow的训练模型在Android和Java的应用及调用
- tensorflow训练好的模型中java调用
- C++调用tensorflow 训练好的模型
- c++调用python训练的tensorflow模型
- tensorflow训练的模型在java中的使用
- tensorflow训练的模型在java中的使用
- tensorflow 的模型保存和调用
- 将tensorflow训练好的模型移植到android
- 将tensorflow训练好的模型移植到android
- 将tensorflow训练好的模型移植到android
- 将tensorflow训练好的模型移植到android
- 将tensorflow训练好的模型移植到android
- java加载tensorflow训练好的模型部署成service
- keras基于theano和tensorflow训练的模型相互转换
- TensorFlow 训练好模型参数的保存和恢复代码
- TensorFlow在MNIST中的应用-训练过程的可视化
- Stanford NER模型使用,训练自己的NER模型,终端使用和java调用
- Git基本使用流程
- 翻译需要截图
- 走穿23种设计模式-8适配器模式详解
- 关于数组的方法整理,求补充!!!
- Linux CentOS7 安装 配置 Redis
- TensorFlow的训练模型在Android和Java的应用及调用
- JDBC
- 第三题
- BZOJ 1999 [Noip2007]树网的核(2282 [Sdoi2011]消防)
- g++: command not found的解决 G++没有装或却没有更新 以下方法都可以试试: centos: yum -y update gcc yum -y install gcc+
- 玲珑OJ 1171
- C/C++:从命令行获取参数
- 指针变量作为函数参数问题
- 每日一句