Calling a TensorFlow model from a Java web backend
Source: Internet  Editor: 程序博客网  Time: 2024/06/06 02:31
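My company wanted an image-recognition feature on the web side, but there are very few examples online. The official TensorFlow site does have one: the Java example code.
Using it as-is may cause problems, though. Our web side runs TensorFlow 1.3.0, while the official site has already moved on to 1.4.0, and the code has changed yet again... (don't ask why I say "again").
You can find the matching example code by downloading the source of the older release.
I'll paste the 1.3.0 code here anyway: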
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class LabelImage {
  public static void main(String[] args) {
    String modelDir = "C:\\sts";
    String imageFile = "C:\\sts\\timg.jpg";

    byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));
    List<String> labels =
        readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
    byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));

    try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
      float[] labelProbabilities = executeInceptionGraph(graphDef, image);
      int bestLabelIdx = maxIndex(labelProbabilities);
      System.out.println(
          String.format(
              "BEST MATCH: %s (%.2f%% likely)",
              labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
    }
  }

  // Builds a small graph that decodes the JPEG, resizes it to 224x224
  // and normalizes it to the input format the Inception model expects.
  private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
    try (Graph g = new Graph()) {
      GraphBuilder b = new GraphBuilder(g);
      final int H = 224;
      final int W = 224;
      final float mean = 117f;
      final float scale = 1f;

      final Output input = b.constant("input", imageBytes);
      final Output output =
          b.div(
              b.sub(
                  b.resizeBilinear(
                      b.expandDims(
                          b.cast(b.decodeJpeg(input, 3), DataType.FLOAT),
                          b.constant("make_batch", 0)),
                      b.constant("size", new int[] {H, W})),
                  b.constant("mean", mean)),
              b.constant("scale", scale));
      try (Session s = new Session(g)) {
        return s.runner().fetch(output.op().name()).run().get(0);
      }
    }
  }

  // Imports the frozen Inception graph and runs it on the normalized image.
  private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
    try (Graph g = new Graph()) {
      g.importGraphDef(graphDef);
      try (Session s = new Session(g);
          Tensor result = s.runner().feed("input", image).fetch("output").run().get(0)) {
        final long[] rshape = result.shape();
        if (result.numDimensions() != 2 || rshape[0] != 1) {
          throw new RuntimeException(
              String.format(
                  "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                  Arrays.toString(rshape)));
        }
        int nlabels = (int) rshape[1];
        return result.copyTo(new float[1][nlabels])[0];
      }
    }
  }

  private static int maxIndex(float[] probabilities) {
    int best = 0;
    for (int i = 1; i < probabilities.length; ++i) {
      if (probabilities[i] > probabilities[best]) {
        best = i;
      }
    }
    return best;
  }

  private static byte[] readAllBytesOrExit(Path path) {
    try {
      return Files.readAllBytes(path);
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      System.exit(1);
    }
    return null;
  }

  private static List<String> readAllLinesOrExit(Path path) {
    try {
      return Files.readAllLines(path, Charset.forName("UTF-8"));
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      // Exit with a non-zero status so the failure is visible to the caller.
      System.exit(1);
    }
    return null;
  }

  // Thin helper that hides the verbose opBuilder calls.
  static class GraphBuilder {
    GraphBuilder(Graph g) {
      this.g = g;
    }

    Output div(Output x, Output y) {
      return binaryOp("Div", x, y);
    }

    Output sub(Output x, Output y) {
      return binaryOp("Sub", x, y);
    }

    Output resizeBilinear(Output images, Output size) {
      return binaryOp("ResizeBilinear", images, size);
    }

    Output expandDims(Output input, Output dim) {
      return binaryOp("ExpandDims", input, dim);
    }

    Output cast(Output value, DataType dtype) {
      return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0);
    }

    Output decodeJpeg(Output contents, long channels) {
      return g.opBuilder("DecodeJpeg", "DecodeJpeg")
          .addInput(contents)
          .setAttr("channels", channels)
          .build()
          .output(0);
    }

    Output constant(String name, Object value) {
      try (Tensor t = Tensor.create(value)) {
        return g.opBuilder("Const", name)
            .setAttr("dtype", t.dataType())
            .setAttr("value", t)
            .build()
            .output(0);
      }
    }

    private Output binaryOp(String type, Output in1, Output in2) {
      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0);
    }

    private Graph g;
  }
}
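The preprocessing graph in constructAndExecuteGraphToNormalizeImage boils down to (value - mean) / scale applied to every channel value after decoding and resizing. As a rough illustration only, the same arithmetic in plain Java looks like the sketch below; the class and method names are made up for this example, and it does not replace the TensorFlow graph (it skips the JPEG decoding and the bilinear resize).

```java
import java.util.Arrays;

class NormalizeSketch {
    // Hypothetical helper: applies (value - mean) / scale to each channel value,
    // which is what the Sub and Div ops in the graph compute.
    static float[] normalize(int[] channelValues, float mean, float scale) {
        float[] out = new float[channelValues.length];
        for (int i = 0; i < channelValues.length; i++) {
            out[i] = (channelValues[i] - mean) / scale;
        }
        return out;
    }

    public static void main(String[] args) {
        // mean = 117 and scale = 1 are the constants used in the sample above.
        float[] n = normalize(new int[] {0, 117, 255}, 117f, 1f);
        System.out.println(Arrays.toString(n)); // [-117.0, 0.0, 138.0]
    }
}
```

With scale = 1 the division is a no-op, so the model effectively receives pixel values shifted to the range -117 to 138.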
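That example works, but the company's requirements go further: we need to run an SSD model for object tracking. I couldn't find an example for that and couldn't figure out how to call it, so the job of hacking something together landed on me. The code below is copied straight from the classes inside Android's tensorflow.jar.
TensorFlowInferenceInterface.java. The class name should look familiar; I basically just copied and pasted it.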
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class TensorFlowInferenceInterface {
  private final String modelName;
  private final Graph g;
  private final Session sess;
  private Session.Runner runner;
  private List<String> feedNames = new ArrayList<>();
  private List<Tensor> feedTensors = new ArrayList<>();
  private List<String> fetchNames = new ArrayList<>();
  private List<Tensor> fetchTensors = new ArrayList<>();

  // Loads a frozen graph (.pb) from a file path instead of Android assets.
  public TensorFlowInferenceInterface(String modelPath) {
    this.modelName = modelPath;
    this.g = new Graph();
    this.sess = new Session(this.g);
    this.runner = this.sess.runner();
    try (InputStream is = new FileInputStream(modelPath)) {
      // Relies on available() reporting the whole file size, as the Android code does.
      byte[] graphDef = new byte[is.available()];
      int numBytesRead = is.read(graphDef);
      if (numBytesRead != graphDef.length) {
        throw new IOException(
            "read error: read only " + numBytesRead
                + " of the graph, expected to read " + graphDef.length);
      }
      loadGraph(graphDef, this.g);
    } catch (IOException e) {
      throw new RuntimeException("Failed to load model from '" + modelPath + "'", e);
    }
  }

  public TensorFlowInferenceInterface(InputStream is) {
    this.modelName = "";
    this.g = new Graph();
    this.sess = new Session(this.g);
    this.runner = this.sess.runner();
    try {
      int baosInitSize = is.available() > 16384 ? is.available() : 16384;
      ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize);
      byte[] buf = new byte[16384];
      int numBytesRead;
      while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) {
        baos.write(buf, 0, numBytesRead);
      }
      loadGraph(baos.toByteArray(), this.g);
    } catch (IOException e) {
      throw new RuntimeException("Failed to load model from the input stream", e);
    }
  }

  public TensorFlowInferenceInterface(Graph g) {
    this.modelName = "";
    this.g = g;
    this.sess = new Session(g);
    this.runner = this.sess.runner();
  }

  // Runs inference for the tensors previously passed to feed() and keeps the
  // requested outputs so they can be read with fetch().
  public void run(String[] outputNames) {
    closeFetches();
    for (String o : outputNames) {
      fetchNames.add(o);
      TensorId tid = TensorId.parse(o);
      runner.fetch(tid.name, tid.outputIndex);
    }
    try {
      fetchTensors = runner.run();
    } finally {
      // Always release the fed tensors and start a fresh runner for the next call.
      closeFeeds();
      runner = sess.runner();
    }
  }

  public Graph graph() {
    return g;
  }

  public Operation graphOperation(String operationName) {
    Operation operation = g.operation(operationName);
    if (operation == null) {
      throw new RuntimeException(
          "Node '" + operationName + "' does not exist in model '" + modelName + "'");
    }
    return operation;
  }

  public void close() {
    closeFeeds();
    closeFetches();
    sess.close();
    g.close();
  }

  @Override
  protected void finalize() throws Throwable {
    try {
      close();
    } finally {
      super.finalize();
    }
  }

  // feed(): copy the given data into a tensor of shape dims, bound to inputName.
  public void feed(String inputName, float[] src, long... dims) {
    addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
  }

  public void feed(String inputName, int[] src, long... dims) {
    addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src)));
  }

  public void feed(String inputName, long[] src, long... dims) {
    addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src)));
  }

  public void feed(String inputName, double[] src, long... dims) {
    addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src)));
  }

  public void feed(String inputName, byte[] src, long... dims) {
    addFeed(inputName, Tensor.create(DataType.UINT8, dims, ByteBuffer.wrap(src)));
  }

  public void feedString(String inputName, byte[] src) {
    addFeed(inputName, Tensor.create(src));
  }

  public void feedString(String inputName, byte[][] src) {
    addFeed(inputName, Tensor.create(src));
  }

  public void feed(String inputName, FloatBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(dims, src));
  }

  public void feed(String inputName, IntBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(dims, src));
  }

  public void feed(String inputName, LongBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(dims, src));
  }

  public void feed(String inputName, DoubleBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(dims, src));
  }

  public void feed(String inputName, ByteBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(DataType.UINT8, dims, src));
  }

  // fetch(): copy an output tensor produced by the last run() into dst.
  public void fetch(String outputName, float[] dst) {
    fetch(outputName, FloatBuffer.wrap(dst));
  }

  public void fetch(String outputName, int[] dst) {
    fetch(outputName, IntBuffer.wrap(dst));
  }

  public void fetch(String outputName, long[] dst) {
    fetch(outputName, LongBuffer.wrap(dst));
  }

  public void fetch(String outputName, double[] dst) {
    fetch(outputName, DoubleBuffer.wrap(dst));
  }

  public void fetch(String outputName, byte[] dst) {
    fetch(outputName, ByteBuffer.wrap(dst));
  }

  public void fetch(String outputName, FloatBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }

  public void fetch(String outputName, IntBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }

  public void fetch(String outputName, LongBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }

  public void fetch(String outputName, DoubleBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }

  public void fetch(String outputName, ByteBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }

  private void loadGraph(byte[] graphDef, Graph g) throws IOException {
    try {
      g.importGraphDef(graphDef);
    } catch (IllegalArgumentException e) {
      throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage());
    }
  }

  private void addFeed(String inputName, Tensor t) {
    TensorId tid = TensorId.parse(inputName);
    runner.feed(tid.name, tid.outputIndex, t);
    feedNames.add(inputName);
    feedTensors.add(t);
  }

  private Tensor getTensor(String outputName) {
    int i = 0;
    for (String n : fetchNames) {
      if (n.equals(outputName)) {
        return fetchTensors.get(i);
      }
      ++i;
    }
    throw new RuntimeException(
        "Node '" + outputName + "' was not provided to run(), so it cannot be read");
  }

  private void closeFeeds() {
    for (Tensor t : feedTensors) {
      t.close();
    }
    feedTensors.clear();
    feedNames.clear();
  }

  private void closeFetches() {
    for (Tensor t : fetchTensors) {
      t.close();
    }
    fetchTensors.clear();
    fetchNames.clear();
  }

  // Splits "name:index" into the operation name and its output index.
  private static class TensorId {
    String name;
    int outputIndex;

    private TensorId() {}

    public static TensorId parse(String name) {
      TensorId tid = new TensorId();
      int colonIndex = name.lastIndexOf(':');
      if (colonIndex < 0) {
        tid.outputIndex = 0;
        tid.name = name;
        return tid;
      }
      try {
        tid.outputIndex = Integer.parseInt(name.substring(colonIndex + 1));
        tid.name = name.substring(0, colonIndex);
      } catch (NumberFormatException e) {
        tid.outputIndex = 0;
        tid.name = name;
      }
      return tid;
    }
  }
}
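The TensorId.parse helper at the end of that class is what lets feed() and run() accept names such as "detection_boxes" or "detection_boxes:0". A standalone sketch of the same parsing rule, with hypothetical class names chosen for this example, shows how a few inputs are split:

```java
class TensorNameSketch {
    // Hypothetical holder for the two parts of a tensor name.
    static final class Parsed {
        final String name;
        final int outputIndex;

        Parsed(String name, int outputIndex) {
            this.name = name;
            this.outputIndex = outputIndex;
        }
    }

    // Same rule as TensorId.parse: text after the last ':' is the output index
    // if it is a number; otherwise the whole string is the name and the index is 0.
    static Parsed parse(String s) {
        int colon = s.lastIndexOf(':');
        if (colon < 0) {
            return new Parsed(s, 0);
        }
        try {
            int index = Integer.parseInt(s.substring(colon + 1));
            return new Parsed(s.substring(0, colon), index);
        } catch (NumberFormatException e) {
            return new Parsed(s, 0);
        }
    }

    public static void main(String[] args) {
        String[] samples = {"image_tensor", "detection_boxes:0", "scope/op:1", "odd:name"};
        for (String s : samples) {
            Parsed p = parse(s);
            System.out.println(s + " -> name=" + p.name + ", index=" + p.outputIndex);
        }
    }
}
```

So a plain operation name and the same name with ":0" appended refer to the same output, which is why the detection code can pass bare names to run().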
Classifier.java
import java.util.List;

public interface Classifier {
  // One detection: label, confidence and bounding box in input-image pixels.
  public class Recognition {
    private final String id;
    private final String title;
    private final Float confidence;
    private float left, top, right, bottom;

    public Recognition(
        final String id,
        final String title,
        final Float confidence,
        float left,
        float top,
        float right,
        float bottom) {
      this.id = id;
      this.title = title;
      this.confidence = confidence;
      this.left = left;
      this.top = top;
      this.right = right;
      this.bottom = bottom;
    }

    public String getId() {
      return id;
    }

    public String getTitle() {
      return title;
    }

    public Float getConfidence() {
      return confidence;
    }

    public float getLeft() {
      return left;
    }

    public void setLeft(float left) {
      this.left = left;
    }

    public float getTop() {
      return top;
    }

    public void setTop(float top) {
      this.top = top;
    }

    public float getRight() {
      return right;
    }

    public void setRight(float right) {
      this.right = right;
    }

    public float getBottom() {
      return bottom;
    }

    public void setBottom(float bottom) {
      this.bottom = bottom;
    }

    @Override
    public String toString() {
      String resultString = "";
      if (id != null) {
        resultString += "[" + id + "] ";
      }
      if (title != null) {
        resultString += title + " ";
      }
      if (confidence != null) {
        resultString += String.format("(%.1f%%) ", confidence * 100.0f);
      }
      if (left != 0) {
        resultString += left + " ";
      }
      if (top != 0) {
        resultString += top + " ";
      }
      if (right != 0) {
        resultString += right + " ";
      }
      if (bottom != 0) {
        resultString += bottom + " ";
      }
      return resultString.trim();
    }
  }

  // intValues holds one packed RGB int per pixel of the (already resized) input image.
  List<Recognition> recognizeImage(int[] intValues);

  void close();
}
TensorFlowObjectDetectionAPIModel.java
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Vector;
import org.tensorflow.Graph;
import org.tensorflow.Operation;

public class TensorFlowObjectDetectionAPIModel implements Classifier {
  private static final int MAX_RESULTS = 100;

  private String inputName;
  private int inputSize;
  private Vector<String> labels = new Vector<String>();
  private byte[] byteValues;
  private float[] outputLocations;
  private float[] outputScores;
  private float[] outputClasses;
  private float[] outputNumDetections;
  private String[] outputNames;
  private TensorFlowInferenceInterface inferenceInterface;

  public static Classifier create(
      final String modelFilename, final String labelFilename, final int inputSize)
      throws IOException {
    final TensorFlowObjectDetectionAPIModel d = new TensorFlowObjectDetectionAPIModel();

    // Read the label file, one label per line; the line index is the class id.
    try (BufferedReader br =
        new BufferedReader(new InputStreamReader(new FileInputStream(labelFilename)))) {
      String line;
      while ((line = br.readLine()) != null) {
        d.labels.add(line);
      }
    }

    d.inferenceInterface = new TensorFlowInferenceInterface(modelFilename);
    final Graph g = d.inferenceInterface.graph();

    d.inputName = "image_tensor";
    final Operation inputOp = g.operation(d.inputName);
    if (inputOp == null) {
      throw new RuntimeException("Failed to find input Node '" + d.inputName + "'");
    }
    d.inputSize = inputSize;

    final Operation outputOp1 = g.operation("detection_scores");
    if (outputOp1 == null) {
      throw new RuntimeException("Failed to find output Node 'detection_scores'");
    }
    final Operation outputOp2 = g.operation("detection_boxes");
    if (outputOp2 == null) {
      throw new RuntimeException("Failed to find output Node 'detection_boxes'");
    }
    final Operation outputOp3 = g.operation("detection_classes");
    if (outputOp3 == null) {
      throw new RuntimeException("Failed to find output Node 'detection_classes'");
    }

    d.outputNames =
        new String[] {"detection_boxes", "detection_scores", "detection_classes", "num_detections"};
    d.byteValues = new byte[d.inputSize * d.inputSize * 3];
    d.outputScores = new float[MAX_RESULTS];
    d.outputLocations = new float[MAX_RESULTS * 4];
    d.outputClasses = new float[MAX_RESULTS];
    d.outputNumDetections = new float[1];
    return d;
  }

  private TensorFlowObjectDetectionAPIModel() {}

  @Override
  public List<Recognition> recognizeImage(int[] intValues) {
    // Unpack each packed RGB int into three bytes in R, G, B order.
    for (int i = 0; i < intValues.length; ++i) {
      byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF);
      byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF);
      byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF);
    }

    // Feed a [1, inputSize, inputSize, 3] uint8 tensor and run the detector.
    inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3);
    inferenceInterface.run(outputNames);

    // Copy the outputs into fresh arrays.
    outputLocations = new float[MAX_RESULTS * 4];
    outputScores = new float[MAX_RESULTS];
    outputClasses = new float[MAX_RESULTS];
    outputNumDetections = new float[1];
    inferenceInterface.fetch(outputNames[0], outputLocations);
    inferenceInterface.fetch(outputNames[1], outputScores);
    inferenceInterface.fetch(outputNames[2], outputClasses);
    inferenceInterface.fetch(outputNames[3], outputNumDetections);

    // Sort detections by descending confidence.
    final PriorityQueue<Recognition> pq =
        new PriorityQueue<Recognition>(
            1,
            new Comparator<Recognition>() {
              public int compare(final Recognition lhs, final Recognition rhs) {
                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
              }
            });

    // Boxes come back as [top, left, bottom, right], normalized to 0..1;
    // scale them to pixels of the inputSize x inputSize image.
    for (int i = 0; i < outputScores.length; ++i) {
      float left = outputLocations[4 * i + 1] * inputSize;
      float top = outputLocations[4 * i] * inputSize;
      float right = outputLocations[4 * i + 3] * inputSize;
      float bottom = outputLocations[4 * i + 2] * inputSize;
      pq.add(
          new Recognition(
              "" + i, labels.get((int) outputClasses[i]), outputScores[i],
              left, top, right, bottom));
    }

    final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
    // Compute the count once: pq.size() shrinks on every poll(), so using it
    // directly in the loop condition would return only half of the results.
    final int count = Math.min(pq.size(), MAX_RESULTS);
    for (int i = 0; i < count; ++i) {
      recognitions.add(pq.poll());
    }
    return recognitions;
  }

  @Override
  public void close() {
    inferenceInterface.close();
  }
}
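Two small pieces of arithmetic in recognizeImage are easy to get wrong: unpacking the packed RGB ints into bytes, and reading each box as four normalized values in the order top, left, bottom, right. The following is a minimal sketch of both steps in isolation, assuming the same layout as the code above; the class and method names are invented for the example.

```java
import java.util.Arrays;

class DetectionMathSketch {
    // Splits each packed 0xAARRGGBB int into three bytes in R, G, B order.
    static byte[] unpackRgb(int[] pixels) {
        byte[] out = new byte[pixels.length * 3];
        for (int i = 0; i < pixels.length; i++) {
            out[i * 3] = (byte) ((pixels[i] >> 16) & 0xFF);
            out[i * 3 + 1] = (byte) ((pixels[i] >> 8) & 0xFF);
            out[i * 3 + 2] = (byte) (pixels[i] & 0xFF);
        }
        return out;
    }

    // Converts detection i from normalized [top, left, bottom, right]
    // to pixel {left, top, right, bottom} for a size x size input image.
    static float[] toPixelBox(float[] boxes, int i, int size) {
        return new float[] {
            boxes[4 * i + 1] * size,
            boxes[4 * i] * size,
            boxes[4 * i + 3] * size,
            boxes[4 * i + 2] * size
        };
    }

    public static void main(String[] args) {
        byte[] rgb = unpackRgb(new int[] {0xFF102030});
        // Mask with 0xFF because Java bytes are signed.
        System.out.println((rgb[0] & 0xFF) + " " + (rgb[1] & 0xFF) + " " + (rgb[2] & 0xFF)); // 16 32 48

        float[] box = toPixelBox(new float[] {0.25f, 0.5f, 0.75f, 1.0f}, 0, 300);
        System.out.println(Arrays.toString(box)); // [150.0, 75.0, 300.0, 225.0]
    }
}
```

The alpha byte of each pixel is simply dropped, which is fine here because the model input has only three channels.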
DetectionImage.java. This is the driver class; you only need to copy the implementation code.
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.List;
import javax.imageio.ImageIO;

public class DetectionImage {
  public static void main(String[] args) {
    int input_size = 300;
    Classifier d = null;
    try {
      d =
          TensorFlowObjectDetectionAPIModel.create(
              "C:\\sts\\ssd_mobilenet_v1_android_export.pb",
              "C:\\sts\\coco_labels_list.txt",
              input_size);
    } catch (IOException e) {
      e.printStackTrace();
    }
    if (d != null) {
      File file = new File("C:\\sts\\person.jpg");
      Image img = null;
      try {
        img = ImageIO.read(file);
      } catch (IOException e) {
        e.printStackTrace();
      }
      if (img != null) {
        int width = img.getWidth(null);
        int height = img.getHeight(null);
        // Stretch the whole source image onto an input_size x input_size RGB canvas.
        BufferedImage image =
            new BufferedImage(input_size, input_size, BufferedImage.TYPE_INT_RGB);
        image
            .getGraphics()
            .drawImage(img, 0, 0, input_size, input_size, 0, 0, width, height, null);
        // Read the pixels back as packed RGB ints, one per pixel.
        int[] rgbs = new int[input_size * input_size];
        image.getRGB(0, 0, input_size, input_size, rgbs, 0, input_size);
        List<Classifier.Recognition> results = d.recognizeImage(rgbs);
        for (Classifier.Recognition result : results) {
          System.out.println(result.toString());
        }
      }
    }
  }
}
Output:
[0] person (99.1%) 113.1657 27.434679 183.88785 296.03595
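Going by Recognition.toString, the four numbers on that line are left, top, right and bottom, measured in pixels of the 300 x 300 image that was fed to the model. Because DetectionImage stretches the whole picture to that square, drawing the box on the original picture needs each axis scaled back separately. A rough sketch, where the class name, method name and the 600 x 1200 original size are all assumptions made for illustration:

```java
import java.util.Arrays;

class BoxRescaleSketch {
    // Maps a box from the inputSize x inputSize model input back to the
    // original image, assuming the image was stretched to fill the square.
    static float[] toOriginal(
            float left, float top, float right, float bottom,
            int inputSize, int origWidth, int origHeight) {
        float scaleX = (float) origWidth / inputSize;
        float scaleY = (float) origHeight / inputSize;
        return new float[] {left * scaleX, top * scaleY, right * scaleX, bottom * scaleY};
    }

    public static void main(String[] args) {
        // Made-up box and image size, chosen so the arithmetic is easy to follow.
        float[] box = toOriginal(75f, 150f, 225f, 300f, 300, 600, 1200);
        System.out.println(Arrays.toString(box)); // [150.0, 600.0, 450.0, 1200.0]
    }
}
```

If the resize step were changed to preserve the aspect ratio (for example by padding), this mapping would need to account for the padding as well.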
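I didn't expect this to work so smoothly from Java. Sure enough, if Android stops paying the bills I can always switch to Java... still a code monkey either way...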