java读取tensorflow中图像的分类模型
来源:互联网 发布:淘宝空间图片协议 编辑:程序博客网 时间:2024/06/03 21:21
在tensorflow中训练好的图像模型,实际部署时常用的是C++,其实在java中也可以部署。在图像分类中,图像的预处理较为简单,只要做去均值和除以方差(归一化)就可以使用。上午刚刚跟同事跑通了色情识别模型。
先看下依赖吧。tensorflow的版本必须是1.4.0及以上:我看了下1.2.1和1.3.0,其中缺少org.tensorflow.types.UInt8这个类,因此无法运行成功,该类从1.4.0版本才开始加入。
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.4.0</version>
</dependency>
直接看代码(注意:下面的方法需要放在一个类的定义中,此处类声明一行缺失,只保留了末尾的右花括号):
package com.meituan.test;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.types.UInt8;
/**
 * Classifies one JPEG image with a frozen TensorFlow graph and prints the best-scoring label.
 *
 * <p>Loads the graph definition and the label list from the model directory, converts the image
 * into a normalized float tensor, runs the model and prints the label with the highest score
 * together with that score as a percentage.
 *
 * @param args optional overrides: {@code args[0]} is the model directory and {@code args[1]} is
 *     the image file; any argument that is absent falls back to the built-in default path
 */
public static void main(String[] args) {
  // The defaults keep the original hard-coded locations, so running without arguments behaves
  // exactly as before; arguments make the program usable on other machines.
  final int argCount = (args == null) ? 0 : args.length;
  String modelDir = argCount > 0 ? args[0] : "/Users/shuubiasahi/Downloads/inception5h";
  String imageFile = argCount > 1 ? args[1] : "/Users/shuubiasahi/Desktop/timg.jpeg";

  byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "sexy_inception_v4_freeze.pb"));
  List<String> labels = readAllLinesOrExit(Paths.get(modelDir, "sexy_slim_labels.txt"));
  byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));

  // The image tensor holds native memory; try-with-resources releases it after inference.
  try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
    float[] labelProbabilities = executeInceptionGraph(graphDef, image);
    int bestLabelIdx = maxIndex(labelProbabilities);
    System.out.println(
        String.format(
            "BEST MATCH: %s (%.2f%% likely)",
            labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
  }
}
/**
 * Builds and runs a small one-off graph that converts JPEG bytes into a normalized float batch.
 *
 * <p>Pipeline: decode JPEG with 3 channels, cast to float, add a leading batch dimension,
 * bilinear-resize to 299x299, subtract 128, then divide by 128.
 *
 * @param imageBytes raw contents of a JPEG file
 * @return the tensor fetched from the final Div op; it is not closed here, so the caller must
 *     close it
 */
private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
  try (Graph graph = new Graph()) {
    GraphBuilder ops = new GraphBuilder(graph);

    // Preprocessing constants used for this model: the image is resized to 299x299 and each
    // colour value is converted to float using (value - mean) / scale.
    final int targetHeight = 299;
    final int targetWidth = 299;
    final float pixelMean = 128f;
    final float pixelScale = 128f;

    // The graph is rebuilt on every call, so the image can be embedded as a constant. A graph
    // reused across many images would be better served by a placeholder.
    final Output<String> jpeg = ops.constant("input", imageBytes);
    final Output<UInt8> decoded = ops.decodeJpeg(jpeg, 3);
    final Output<Float> asFloat = ops.cast(decoded, Float.class);
    final Output<Float> batched = ops.expandDims(asFloat, ops.constant("make_batch", 0));
    final Output<Float> resized =
        ops.resizeBilinear(batched, ops.constant("size", new int[] {targetHeight, targetWidth}));
    final Output<Float> centered = ops.sub(resized, ops.constant("mean", pixelMean));
    final Output<Float> normalized = ops.div(centered, ops.constant("scale", pixelScale));

    try (Session session = new Session(graph)) {
      return session.runner().fetch(normalized.op().name()).run().get(0).expect(Float.class);
    }
  }
}
/**
 * Imports the serialized model graph, feeds it the image and returns the scores for each label.
 *
 * <p>The image is fed to the node named {@code input} and the result is fetched from
 * {@code InceptionV4/Logits/Predictions}.
 *
 * @param graphDef serialized graph definition of the model
 * @param image preprocessed image tensor to feed; it is not closed by this method
 * @return the single row of the model's {@code [1, N]} output
 * @throws RuntimeException if the output is not two-dimensional with a first dimension of 1
 */
private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
  try (Graph model = new Graph()) {
    model.importGraphDef(graphDef);
    // Both the session and the fetched tensor are released when this try block exits.
    try (Session session = new Session(model);
        Tensor<Float> scores =
            session
                .runner()
                .feed("input", image)
                .fetch("InceptionV4/Logits/Predictions")
                .run()
                .get(0)
                .expect(Float.class)) {
      final long[] dims = scores.shape();
      final boolean isSingleRowMatrix = scores.numDimensions() == 2 && dims[0] == 1;
      if (!isSingleRowMatrix) {
        throw new RuntimeException(
            String.format(
                "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                Arrays.toString(dims)));
      }
      final int labelCount = (int) dims[1];
      float[][] batch = new float[1][labelCount];
      return scores.copyTo(batch)[0];
    }
  }
}
/**
 * Finds the position of the largest value in the array.
 *
 * <p>On ties the earliest position wins. An empty array yields 0.
 *
 * @param probabilities values to scan
 * @return index of the first maximum element, or 0 when the array is empty
 */
private static int maxIndex(float[] probabilities) {
  int winner = 0;
  int candidate = 1;
  while (candidate < probabilities.length) {
    // Strict comparison keeps the earliest index when values are equal.
    if (probabilities[winner] < probabilities[candidate]) {
      winner = candidate;
    }
    candidate++;
  }
  return winner;
}
/**
 * Reads a whole file into memory, terminating the JVM with status 1 if it cannot be read.
 *
 * @param path file to read
 * @return the file's bytes
 */
private static byte[] readAllBytesOrExit(Path path) {
  byte[] contents = null;
  try {
    contents = Files.readAllBytes(path);
  } catch (IOException e) {
    System.err.println("Failed to read [" + path + "]: " + e.getMessage());
    System.exit(1);
  }
  return contents;
}
/**
 * Reads all lines of a UTF-8 text file, terminating the JVM if the file cannot be read.
 *
 * @param path file to read
 * @return the file's lines, in order
 */
private static List<String> readAllLinesOrExit(Path path) {
  try {
    return Files.readAllLines(path, Charset.forName("UTF-8"));
  } catch (IOException e) {
    System.err.println("Failed to read [" + path + "]: " + e.getMessage());
    // A failed read is an error: exit with a non-zero status (matching readAllBytesOrExit) so
    // that callers and scripts do not mistake the failure for success.
    System.exit(1);
  }
  // Not reached in practice; required because the compiler does not know System.exit never returns.
  return null;
}
// Hand-written wrappers for the few TensorFlow ops this program needs. In the fullness of time,
// equivalents of these methods should be auto-generated from the OpDefs linked into
// libtensorflow_jni.so, which would match what is done in other languages like Python, C++ and Go.
static class GraphBuilder {
  /** Graph that every op created by this builder is added to. */
  private final Graph graph;

  GraphBuilder(Graph g) {
    this.graph = g;
  }

  /** Adds a "Div" op computing {@code x / y}. */
  Output<Float> div(Output<Float> x, Output<Float> y) {
    return twoInputOp("Div", x, y);
  }

  /** Adds a "Sub" op computing {@code x - y}. */
  <T> Output<T> sub(Output<T> x, Output<T> y) {
    return twoInputOp("Sub", x, y);
  }

  /** Adds a "ResizeBilinear" op taking {@code images} and the target {@code size}. */
  <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
    return twoInputOp("ResizeBilinear", images, size);
  }

  /** Adds an "ExpandDims" op taking {@code input} and the dimension index {@code dim}. */
  <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
    return twoInputOp("ExpandDims", input, dim);
  }

  /** Adds a "Cast" op converting {@code value} to the data type matching {@code type}. */
  <T, U> Output<U> cast(Output<T> value, Class<U> type) {
    final DataType targetType = DataType.fromClass(type);
    return graph
        .opBuilder("Cast", "Cast")
        .addInput(value)
        .setAttr("DstT", targetType)
        .build()
        .<U>output(0);
  }

  /** Adds a "DecodeJpeg" op with its {@code channels} attribute set as given. */
  Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
    return graph
        .opBuilder("DecodeJpeg", "DecodeJpeg")
        .addInput(contents)
        .setAttr("channels", channels)
        .build()
        .<UInt8>output(0);
  }

  /**
   * Adds a "Const" op named {@code name} whose value is a tensor created from {@code value}.
   *
   * <p>The temporary tensor used to populate the op is closed before this method returns.
   */
  <T> Output<T> constant(String name, Object value, Class<T> type) {
    try (Tensor<T> literal = Tensor.<T>create(value, type)) {
      return graph
          .opBuilder("Const", name)
          .setAttr("dtype", DataType.fromClass(type))
          .setAttr("value", literal)
          .build()
          .<T>output(0);
    }
  }

  /** Byte-array constant, created with element class {@code String}. */
  Output<String> constant(String name, byte[] value) {
    return constant(name, value, String.class);
  }

  /** Scalar int constant. */
  Output<Integer> constant(String name, int value) {
    return constant(name, value, Integer.class);
  }

  /** Int array constant. */
  Output<Integer> constant(String name, int[] value) {
    return constant(name, value, Integer.class);
  }

  /** Scalar float constant. */
  Output<Float> constant(String name, float value) {
    return constant(name, value, Float.class);
  }

  /**
   * Adds an op of the given type (also used as the op's name) with two inputs and returns its
   * first output.
   */
  private <T> Output<T> twoInputOp(String opType, Output<?> first, Output<?> second) {
    return graph.opBuilder(opType, opType).addInput(first).addInput(second).build().<T>output(0);
  }
}
}
其中有两个关键的参数:input和output的节点名称(代码中分别是"input"和"InceptionV4/Logits/Predictions"),它们对应tensorflow模型中的输入x和输出y。换成自己的模型时,替换模型文件和标签文件这两个文件,并相应修改这两个名称即可。
结果:
2017-11-07 19:55:01.650240: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.2 AVX AVX2 FMA
BEST MATCH: sexy (99.88% likely)
阅读全文
0 0
- java读取tensorflow中图像的分类模型
- 如何重新训练Tensorflow图像分类模型
- Tensorflow 的安装和用InceptionV3训练新的图像分类模型
- Tensorflow学习(7)用别人训练好的模型进行图像分类
- 利用opencv3读取tensorflow model,对图像进行分类
- TensorFlow之CNN图像分类及模型保存与调用
- 谷歌开源图像分类工具TF-Slim,定义TensorFlow复杂模型
- tensorflow+图像分类使用的一些错误
- 6.TensorFlow模型的保存和读取
- Tensorflow 使用slim框架下的分类模型进行分类
- Tensorflow 使用slim框架下的分类模型进行分类
- Tensorflow中图像的预处理
- TensorFlow保存读取模型
- tensorflow训练好的模型中java调用
- java调用tensorflow模型进行图片分类识别
- tensorflow 图像分类实战解析
- TensorFlow-Slim图像分类库
- 基于tensorflow + Vgg16进行图像分类识别的实验
- 使用终端查看mysql数据中文出现乱码解决
- 用JavaScript将long型数据转换成date型或datetime型
- 第十周LeetCode算法题两道
- 用单例模式封装实现一个数据库类
- 创建登录界面
- java读取tensorflow中图像的分类模型
- 10.25第九周java作业
- [Redis学习笔记]-Redis 发布订阅(充当消息组件)
- mxnet 基础学习笔记(李沐课)
- 51nod 1272 最大距离
- SVN_SERVER的搭建
- 基于Mathematica的机器人仿真环境(机械臂篇)
- 多线程知识点总结二
- NVIDIA CUDA Compiler Driver NVCC