用Java加载TensorFlow训练的图像分类模型

来源:互联网 发布:淘宝空间图片协议 编辑:程序博客网 时间:2024/06/03 21:21

在TensorFlow中训练好的图像模型,实际部署时常用C++,但在Java中同样可以部署。图像分类的预处理比较简单,只需减去均值、再除以缩放系数(归一化)即可。今天上午刚和同事一起跑通了色情图片识别模型。
先看依赖。TensorFlow的版本必须是1.4.0及以上:我试过1.2.1和1.3.0,这两个版本缺少 org.tensorflow.types.UInt8 这个类,无法运行成功;该类从1.4.0版本才开始提供。

  1. <dependency>
  2. <groupId>org.tensorflow</groupId>
  3. <artifactId>tensorflow</artifactId>
  4. <version>1.4.0</version>
  5. </dependency>


直接看代码(注意:下面的清单省略了类声明,需要把这些方法放进一个类中,例如 public class LabelImage { … },清单最后一行的 } 就是该类的结束括号):

  1. package com.meituan.test;
  2. import java.io.IOException;
  3. import java.io.PrintStream;
  4. import java.nio.charset.Charset;
  5. import java.nio.file.Files;
  6. import java.nio.file.Path;
  7. import java.nio.file.Paths;
  8. import java.util.Arrays;
  9. import java.util.List;
  10. import org.tensorflow.DataType;
  11. import org.tensorflow.Graph;
  12. import org.tensorflow.Output;
  13. import org.tensorflow.Session;
  14. import org.tensorflow.Tensor;
  15. import org.tensorflow.TensorFlow;
  16. import org.tensorflow.types.UInt8;
  17. public static void main(String[] args) {
  18. String modelDir = "/Users/shuubiasahi/Downloads/inception5h";
  19. String imageFile = "/Users/shuubiasahi/Desktop/timg.jpeg";
  20. byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "sexy_inception_v4_freeze.pb"));
  21. List<String> labels =
  22. readAllLinesOrExit(Paths.get(modelDir, "sexy_slim_labels.txt"));
  23. byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
  24. Long time=System.currentTimeMillis();
  25. try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
  26. float[] labelProbabilities = executeInceptionGraph(graphDef, image);
  27. int bestLabelIdx = maxIndex(labelProbabilities);
  28. System.out.println(
  29. String.format("BEST MATCH: %s (%.2f%% likely)",
  30. labels.get(bestLabelIdx),
  31. labelProbabilities[bestLabelIdx] * 100f));
  32. }
  33. }
  34. private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
  35. try (Graph g = new Graph()) {
  36. GraphBuilder b = new GraphBuilder(g);
  37. // Some constants specific to the pre-trained model at:
  38. // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
  39. //
  40. // - The model was trained with images scaled to 224x224 pixels.
  41. // - The colors, represented as R, G, B in 1-byte each were converted to
  42. // float using (value - Mean)/Scale.
  43. final int H = 299;
  44. final int W = 299;
  45. final float mean = 128f;
  46. final float scale = 128f;
  47. // Since the graph is being constructed once per execution here, we can use a constant for the
  48. // input image. If the graph were to be re-used for multiple input images, a placeholder would
  49. // have been more appropriate.
  50. final Output<String> input = b.constant("input", imageBytes);
  51. final Output<Float> output =
  52. b.div(
  53. b.sub(
  54. b.resizeBilinear(
  55. b.expandDims(
  56. b.cast(b.decodeJpeg(input, 3), Float.class),
  57. b.constant("make_batch", 0)),
  58. b.constant("size", new int[] {H, W})),
  59. b.constant("mean", mean)),
  60. b.constant("scale", scale));
  61. try (Session s = new Session(g)) {
  62. return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
  63. }
  64. }
  65. }
  66. private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
  67. try (Graph g = new Graph()) {
  68. g.importGraphDef(graphDef);
  69. try (Session s = new Session(g);
  70. Tensor<Float> result =
  71. s.runner().feed("input", image).fetch("InceptionV4/Logits/Predictions").run().get(0).expect(Float.class)) {
  72. final long[] rshape = result.shape();
  73. if (result.numDimensions() != 2 || rshape[0] != 1) {
  74. throw new RuntimeException(
  75. String.format(
  76. "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
  77. Arrays.toString(rshape)));
  78. }
  79. int nlabels = (int) rshape[1];
  80. return result.copyTo(new float[1][nlabels])[0];
  81. }
  82. }
  83. }
  84. private static int maxIndex(float[] probabilities) {
  85. int best = 0;
  86. for (int i = 1; i < probabilities.length; ++i) {
  87. if (probabilities[i] > probabilities[best]) {
  88. best = i;
  89. }
  90. }
  91. return best;
  92. }
  93. private static byte[] readAllBytesOrExit(Path path) {
  94. try {
  95. return Files.readAllBytes(path);
  96. } catch (IOException e) {
  97. System.err.println("Failed to read [" + path + "]: " + e.getMessage());
  98. System.exit(1);
  99. }
  100. return null;
  101. }
  102. private static List<String> readAllLinesOrExit(Path path) {
  103. try {
  104. return Files.readAllLines(path, Charset.forName("UTF-8"));
  105. } catch (IOException e) {
  106. System.err.println("Failed to read [" + path + "]: " + e.getMessage());
  107. System.exit(0);
  108. }
  109. return null;
  110. }
  111. // In the fullness of time, equivalents of the methods of this class should be auto-generated from
  112. // the OpDefs linked into libtensorflow_jni.so. That would match what is done in other languages
  113. // like Python, C++ and Go.
  114. static class GraphBuilder {
  115. GraphBuilder(Graph g) {
  116. this.g = g;
  117. }
  118. Output<Float> div(Output<Float> x, Output<Float> y) {
  119. return binaryOp("Div", x, y);
  120. }
  121. <T> Output<T> sub(Output<T> x, Output<T> y) {
  122. return binaryOp("Sub", x, y);
  123. }
  124. <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
  125. return binaryOp3("ResizeBilinear", images, size);
  126. }
  127. <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
  128. return binaryOp3("ExpandDims", input, dim);
  129. }
  130. <T, U> Output<U> cast(Output<T> value, Class<U> type) {
  131. DataType dtype = DataType.fromClass(type);
  132. return g.opBuilder("Cast", "Cast")
  133. .addInput(value)
  134. .setAttr("DstT", dtype)
  135. .build()
  136. .<U>output(0);
  137. }
  138. Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
  139. return g.opBuilder("DecodeJpeg", "DecodeJpeg")
  140. .addInput(contents)
  141. .setAttr("channels", channels)
  142. .build()
  143. .<UInt8>output(0);
  144. }
  145. <T> Output<T> constant(String name, Object value, Class<T> type) {
  146. try (Tensor<T> t = Tensor.<T>create(value, type)) {
  147. return g.opBuilder("Const", name)
  148. .setAttr("dtype", DataType.fromClass(type))
  149. .setAttr("value", t)
  150. .build()
  151. .<T>output(0);
  152. }
  153. }
  154. Output<String> constant(String name, byte[] value) {
  155. return this.constant(name, value, String.class);
  156. }
  157. Output<Integer> constant(String name, int value) {
  158. return this.constant(name, value, Integer.class);
  159. }
  160. Output<Integer> constant(String name, int[] value) {
  161. return this.constant(name, value, Integer.class);
  162. }
  163. Output<Float> constant(String name, float value) {
  164. return this.constant(name, value, Float.class);
  165. }
  166. private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
  167. return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
  168. }
  169. private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
  170. return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
  171. }
  172. private Graph g;
  173. }
  174. }


其中有两个关键参数:输入节点名和输出节点名(代码中的 "input" 和 "InceptionV4/Logits/Predictions"),它们分别对应TensorFlow模型中的输入x和输出y。换成自己的模型时,把这两个名称改成保存模型时所用的节点名即可。


结果:

  1. 2017-11-07 19:55:01.650240: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.2 AVX AVX2 FMA
  2. BEST MATCH: sexy (99.88% likely)


原创粉丝点击