Java Web 端调用 TensorFlow 模型

来源:互联网 发布:上海地铁免费查询软件 编辑:程序博客网 时间:2024/06/06 02:31

公司想做一个 Web 端的图像识别功能,可网上相关的例子很少。官网上倒是能找到一个 Java 示例代码(LabelImage)。

不过直接拿来用可能会出问题:我们 Web 端用的是 TensorFlow 1.3.0,可官网已经是 1.4.0 了,代码又不一样了……(为什么要说"又"呢……)例如从 1.4.0 起 Tensor 变成了泛型类 Tensor<T>,所以新版示例不能直接用在 1.3.0 上。

可以下载老版本(1.3.0)的源码,从里面找到对应版本的示例代码。

这里还是贴一下 1.3.0 版本的代码吧:

/**
 * Classifies a single image with a pre-trained Inception graph using the TensorFlow 1.3.0 Java
 * API (adapted from the official LabelImage example; in 1.3.0 {@code Tensor} is not generic).
 *
 * <p>Usage: {@code LabelImage [modelDir] [imageFile]}. Missing arguments fall back to the sample
 * paths below. {@code modelDir} must contain {@code tensorflow_inception_graph.pb} and
 * {@code imagenet_comp_graph_label_strings.txt}.
 */
public class LabelImage {
  /** Model directory used when no first argument is given. */
  private static final String DEFAULT_MODEL_DIR = "C:\\sts";
  /** Image file used when no second argument is given. */
  private static final String DEFAULT_IMAGE_FILE = "C:\\sts\\timg.jpg";

  public static void main(String[] args) {
    String modelDir = args.length > 0 ? args[0] : DEFAULT_MODEL_DIR;
    String imageFile = args.length > 1 ? args[1] : DEFAULT_IMAGE_FILE;

    byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));
    List<String> labels =
        readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
    byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));

    try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
      float[] labelProbabilities = executeInceptionGraph(graphDef, image);
      int bestLabelIdx = maxIndex(labelProbabilities);
      System.out.println(
          String.format(
              "BEST MATCH: %s (%.2f%% likely)",
              labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
    }
  }

  /**
   * Decodes a JPEG and normalizes it by running a small helper graph:
   * decode (3 channels) -> cast to float -> add batch dim -> bilinear resize to H x W ->
   * (x - mean) / scale. The result therefore has shape [1, H, W, 3].
   *
   * @param imageBytes raw JPEG bytes
   * @return the normalized image tensor; the caller is responsible for closing it
   */
  private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
    try (Graph g = new Graph()) {
      GraphBuilder b = new GraphBuilder(g);
      // Preprocessing constants taken from the official example; presumably they match what
      // this Inception graph was trained with — confirm before swapping in another model.
      final int H = 224;
      final int W = 224;
      final float mean = 117f;
      final float scale = 1f;

      final Output input = b.constant("input", imageBytes);
      final Output output =
          b.div(
              b.sub(
                  b.resizeBilinear(
                      b.expandDims(
                          b.cast(b.decodeJpeg(input, 3), DataType.FLOAT),
                          b.constant("make_batch", 0)),
                      b.constant("size", new int[] {H, W})),
                  b.constant("mean", mean)),
              b.constant("scale", scale));
      try (Session s = new Session(g)) {
        return s.runner().fetch(output.op().name()).run().get(0);
      }
    }
  }

  /**
   * Runs the Inception graph on a normalized image.
   *
   * @param graphDef serialized GraphDef containing nodes named "input" and "output"
   * @param image normalized image tensor (see {@link #constructAndExecuteGraphToNormalizeImage})
   * @return one probability per label
   * @throws RuntimeException if the "output" tensor is not shaped [1, N]
   */
  private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
    try (Graph g = new Graph()) {
      g.importGraphDef(graphDef);
      try (Session s = new Session(g);
          Tensor result = s.runner().feed("input", image).fetch("output").run().get(0)) {
        final long[] rshape = result.shape();
        if (result.numDimensions() != 2 || rshape[0] != 1) {
          throw new RuntimeException(
              String.format(
                  "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                  Arrays.toString(rshape)));
        }
        int nlabels = (int) rshape[1];
        return result.copyTo(new float[1][nlabels])[0];
      }
    }
  }

  /** Returns the index of the largest value (first one on ties; 0 for an empty array). */
  private static int maxIndex(float[] probabilities) {
    int best = 0;
    for (int i = 1; i < probabilities.length; ++i) {
      if (probabilities[i] > probabilities[best]) {
        best = i;
      }
    }
    return best;
  }

  /** Reads a whole file, or prints the error and exits the JVM with status 1. */
  private static byte[] readAllBytesOrExit(Path path) {
    try {
      return Files.readAllBytes(path);
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      System.exit(1);
    }
    return null; // unreachable: System.exit does not return
  }

  /**
   * Reads a UTF-8 text file as lines, or prints the error and exits the JVM with status 1.
   * (Previously exited with status 0, which reported success on failure.)
   */
  private static List<String> readAllLinesOrExit(Path path) {
    try {
      return Files.readAllLines(path, Charset.forName("UTF-8"));
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      System.exit(1);
    }
    return null; // unreachable: System.exit does not return
  }

  /** Thin helper for adding operations to a {@link Graph}. */
  static class GraphBuilder {
    private final Graph g;

    GraphBuilder(Graph g) {
      this.g = g;
    }

    Output div(Output x, Output y) {
      return binaryOp("Div", x, y);
    }

    Output sub(Output x, Output y) {
      return binaryOp("Sub", x, y);
    }

    Output resizeBilinear(Output images, Output size) {
      return binaryOp("ResizeBilinear", images, size);
    }

    Output expandDims(Output input, Output dim) {
      return binaryOp("ExpandDims", input, dim);
    }

    Output cast(Output value, DataType dtype) {
      return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0);
    }

    Output decodeJpeg(Output contents, long channels) {
      return g.opBuilder("DecodeJpeg", "DecodeJpeg")
          .addInput(contents)
          .setAttr("channels", channels)
          .build()
          .output(0);
    }

    /** Adds a Const node; the temporary Tensor is closed once its value is copied into the op. */
    Output constant(String name, Object value) {
      try (Tensor t = Tensor.create(value)) {
        return g.opBuilder("Const", name)
            .setAttr("dtype", t.dataType())
            .setAttr("value", t)
            .build()
            .output(0);
      }
    }

    // Uses the op type as the node name, so each binary op type can appear only once per graph.
    private Output binaryOp(String type, Output in1, Output in2) {
      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0);
    }
  }
}

虽然有了例子,可公司的要求更高一些:要用 SSD 模型实现目标检测(跟踪),却没找到现成的例子,也搞不明白该怎么调用,于是这个任务就交给我来魔改了。代码基本是照搬 Android 版 TensorFlow 库(tensorflow.jar)里的实现。

TensorFlowInferenceInterface.java——类名是不是很眼熟?我基本上就是复制粘贴过来的:

/**
 * Desktop-Java port of the Android {@code TensorFlowInferenceInterface}, written against the
 * TensorFlow 1.3.0 Java API (where {@code Tensor} is not yet generic).
 *
 * <p>Per inference: call one or more {@code feed(...)} methods, then {@link #run(String[])} with
 * the output names to compute, then one {@code fetch(...)} per requested output. Feed tensors are
 * closed by {@code run}; fetched tensors stay alive until the next {@code run} or {@link #close()}.
 *
 * <p>Node names may carry an output-index suffix ({@code "node:1"}); without one, output 0 is used.
 *
 * <p>No synchronization is performed; feed/run/fetch all mutate shared instance state.
 */
public class TensorFlowInferenceInterface {
  /** Chunk size used when reading a serialized graph from a stream. */
  private static final int BUFFER_SIZE = 16384;

  private final String modelName;
  private final Graph g;
  private final Session sess;
  private Session.Runner runner;

  private final List<String> feedNames = new ArrayList<String>();
  private final List<Tensor> feedTensors = new ArrayList<Tensor>();
  private final List<String> fetchNames = new ArrayList<String>();
  private List<Tensor> fetchTensors = new ArrayList<Tensor>();

  /**
   * Loads a frozen graph from a file.
   *
   * @param model path of the serialized GraphDef file
   * @throws RuntimeException if the file cannot be read or is not a valid graph
   */
  public TensorFlowInferenceInterface(String model) {
    this.modelName = model;
    this.g = new Graph();
    this.sess = new Session(this.g);
    this.runner = this.sess.runner();

    // try-with-resources closes the file on every path, and readAllBytes loops until EOF
    // instead of trusting available() + a single read().
    try (InputStream is = new FileInputStream(model)) {
      loadGraph(readAllBytes(is), this.g);
    } catch (IOException e) {
      releaseNativeResources();
      throw new RuntimeException("Failed to load model from '" + model + "'", e);
    }
  }

  /**
   * Loads a frozen graph from a stream. The stream is read to EOF but not closed; the caller
   * owns it.
   *
   * @throws RuntimeException if the stream cannot be read or is not a valid graph
   */
  public TensorFlowInferenceInterface(InputStream is) {
    this.modelName = "";
    this.g = new Graph();
    this.sess = new Session(this.g);
    this.runner = this.sess.runner();

    try {
      loadGraph(readAllBytes(is), this.g);
    } catch (IOException e) {
      releaseNativeResources();
      throw new RuntimeException("Failed to load model from the input stream", e);
    }
  }

  /** Wraps an already-built graph. Note: {@link #close()} will close this graph as well. */
  public TensorFlowInferenceInterface(Graph graph) {
    this.modelName = "";
    this.g = graph;
    this.sess = new Session(graph);
    this.runner = this.sess.runner();
  }

  /**
   * Runs the graph, computing the given outputs from the tensors fed since the last run.
   * Closes the previous fetch results and all current feeds.
   *
   * @param outputNames names of the outputs to compute (optionally {@code "node:index"})
   */
  public void run(String[] outputNames) {
    closeFetches();
    for (String outputName : outputNames) {
      fetchNames.add(outputName);
      TensorId tid = TensorId.parse(outputName);
      runner.fetch(tid.name, tid.outputIndex);
    }
    try {
      fetchTensors = runner.run();
    } catch (RuntimeException e) {
      // No tensors were produced; drop the names so fetch() reports "not provided to run()"
      // instead of failing with an index error.
      fetchNames.clear();
      throw e;
    } finally {
      closeFeeds();
      runner = sess.runner();
    }
  }

  /** Returns the underlying graph. */
  public Graph graph() {
    return g;
  }

  /**
   * Looks up an operation by name.
   *
   * @throws RuntimeException if the graph has no such node
   */
  public Operation graphOperation(String operationName) {
    Operation op = g.operation(operationName);
    if (op == null) {
      throw new RuntimeException(
          "Node '" + operationName + "' does not exist in model '" + modelName + "'");
    }
    return op;
  }

  /** Releases all feed/fetch tensors, the session and the graph. */
  public void close() {
    closeFeeds();
    closeFetches();
    sess.close();
    g.close();
  }

  /** Safety net kept from the Android original; prefer calling {@link #close()} explicitly. */
  @Override
  protected void finalize() throws Throwable {
    try {
      close();
    } finally {
      super.finalize();
    }
  }

  // ---- feed: wrap the data in a Tensor of the given shape and queue it for the next run() ----

  public void feed(String inputName, float[] src, long... dims) {
    addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
  }

  public void feed(String inputName, int[] src, long... dims) {
    addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src)));
  }

  public void feed(String inputName, long[] src, long... dims) {
    addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src)));
  }

  public void feed(String inputName, double[] src, long... dims) {
    addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src)));
  }

  /** Feeds raw bytes as a UINT8 tensor. */
  public void feed(String inputName, byte[] src, long... dims) {
    addFeed(inputName, Tensor.create(DataType.UINT8, dims, ByteBuffer.wrap(src)));
  }

  /** Feeds a scalar STRING tensor. */
  public void feedString(String inputName, byte[] src) {
    addFeed(inputName, Tensor.create(src));
  }

  /** Feeds a vector of STRING values. */
  public void feedString(String inputName, byte[][] src) {
    addFeed(inputName, Tensor.create(src));
  }

  public void feed(String inputName, FloatBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(dims, src));
  }

  public void feed(String inputName, IntBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(dims, src));
  }

  public void feed(String inputName, LongBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(dims, src));
  }

  public void feed(String inputName, DoubleBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(dims, src));
  }

  /** Feeds a byte buffer as a UINT8 tensor. */
  public void feed(String inputName, ByteBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(DataType.UINT8, dims, src));
  }

  // ---- fetch: copy an output computed by the last run() into the caller's array/buffer ----

  public void fetch(String outputName, float[] dst) {
    fetch(outputName, FloatBuffer.wrap(dst));
  }

  public void fetch(String outputName, int[] dst) {
    fetch(outputName, IntBuffer.wrap(dst));
  }

  public void fetch(String outputName, long[] dst) {
    fetch(outputName, LongBuffer.wrap(dst));
  }

  public void fetch(String outputName, double[] dst) {
    fetch(outputName, DoubleBuffer.wrap(dst));
  }

  public void fetch(String outputName, byte[] dst) {
    fetch(outputName, ByteBuffer.wrap(dst));
  }

  public void fetch(String outputName, FloatBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }

  public void fetch(String outputName, IntBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }

  public void fetch(String outputName, LongBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }

  public void fetch(String outputName, DoubleBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }

  public void fetch(String outputName, ByteBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }

  /** Reads {@code in} until EOF (does not close it). */
  private static byte[] readAllBytes(InputStream in) throws IOException {
    ByteArrayOutputStream out = new ByteArrayOutputStream(Math.max(in.available(), BUFFER_SIZE));
    byte[] chunk = new byte[BUFFER_SIZE];
    int n;
    while ((n = in.read(chunk, 0, chunk.length)) != -1) {
      out.write(chunk, 0, n);
    }
    return out.toByteArray();
  }

  /** Closes the session and graph; used when a constructor fails after creating them. */
  private void releaseNativeResources() {
    sess.close();
    g.close();
  }

  private void loadGraph(byte[] graphDef, Graph graph) throws IOException {
    try {
      graph.importGraphDef(graphDef);
    } catch (IllegalArgumentException e) {
      throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage(), e);
    }
  }

  private void addFeed(String inputName, Tensor t) {
    TensorId tid = TensorId.parse(inputName);
    runner.feed(tid.name, tid.outputIndex, t);
    feedNames.add(inputName);
    feedTensors.add(t);
  }

  /** Returns the tensor produced for {@code outputName} by the last run(). */
  private Tensor getTensor(String outputName) {
    int idx = fetchNames.indexOf(outputName);
    if (idx < 0) {
      throw new RuntimeException(
          "Node '" + outputName + "' was not provided to run(), so it cannot be read");
    }
    return fetchTensors.get(idx);
  }

  private void closeFeeds() {
    for (Tensor t : feedTensors) {
      t.close();
    }
    feedTensors.clear();
    feedNames.clear();
  }

  private void closeFetches() {
    for (Tensor t : fetchTensors) {
      t.close();
    }
    fetchTensors.clear();
    fetchNames.clear();
  }

  /** A node name plus output index parsed from {@code "name"} or {@code "name:index"}. */
  private static class TensorId {
    String name;
    int outputIndex;

    private TensorId() {}

    /**
     * Splits on the last ':'. A missing or non-numeric suffix means output 0 of the whole
     * string.
     */
    public static TensorId parse(String spec) {
      TensorId tid = new TensorId();
      int colon = spec.lastIndexOf(':');
      tid.outputIndex = 0;
      tid.name = spec;
      if (colon >= 0) {
        try {
          tid.outputIndex = Integer.parseInt(spec.substring(colon + 1));
          tid.name = spec.substring(0, colon);
        } catch (NumberFormatException e) {
          // Not an index suffix; treat the whole string as the node name.
          tid.outputIndex = 0;
          tid.name = spec;
        }
      }
      return tid;
    }
  }
}

Classifier.java

/** An image recognizer that returns labelled, scored bounding boxes. */
public interface Classifier {
  /** One recognition result: id, label, confidence and a bounding box. */
  public class Recognition {
    /** Identifier of the result (may be null). */
    private final String id;
    /** Human-readable label (may be null). */
    private final String title;
    /** Score, printed by {@link #toString()} as a percentage (may be null). */
    private final Float confidence;
    /** Box edges; units are whatever the producer supplies. */
    private float left, top, right, bottom;

    public Recognition(
        final String id,
        final String title,
        final Float confidence,
        float left,
        float top,
        float right,
        float bottom) {
      this.id = id;
      this.title = title;
      this.confidence = confidence;
      this.left = left;
      this.top = top;
      this.right = right;
      this.bottom = bottom;
    }

    public String getId() {
      return id;
    }

    public String getTitle() {
      return title;
    }

    public Float getConfidence() {
      return confidence;
    }

    public float getLeft() {
      return left;
    }

    public void setLeft(float left) {
      this.left = left;
    }

    public float getTop() {
      return top;
    }

    public void setTop(float top) {
      this.top = top;
    }

    public float getRight() {
      return right;
    }

    public void setRight(float right) {
      this.right = right;
    }

    public float getBottom() {
      return bottom;
    }

    public void setBottom(float bottom) {
      this.bottom = bottom;
    }

    /**
     * Formats as {@code "[id] title (NN.N%) left top right bottom"}, omitting null parts. The box
     * is printed as all four coordinates whenever any of them is non-zero. Printing only the
     * non-zero ones, as before, made e.g. a box touching the left edge indistinguishable from a
     * shifted box.
     */
    @Override
    public String toString() {
      StringBuilder sb = new StringBuilder();
      if (id != null) {
        sb.append('[').append(id).append("] ");
      }
      if (title != null) {
        sb.append(title).append(' ');
      }
      if (confidence != null) {
        sb.append(String.format("(%.1f%%) ", confidence * 100.0f));
      }
      if (left != 0 || top != 0 || right != 0 || bottom != 0) {
        sb.append(left).append(' ')
            .append(top).append(' ')
            .append(right).append(' ')
            .append(bottom);
      }
      return sb.toString().trim();
    }
  }

  /**
   * Recognizes objects in an image given as packed pixel ints.
   *
   * @param byteValues one int per pixel (see the implementation for the expected layout)
   * @return the recognitions found
   */
  List<Recognition> recognizeImage(int[] byteValues);

  /** Releases resources held by the classifier. */
  void close();
}

TensorFlowObjectDetectionAPIModel.java

public class TensorFlowObjectDetectionAPIModel implements Classifier{    private static final int MAX_RESULTS = 100;    private String inputName;    private int inputSize;    private Vector<String> labels = new Vector<String>();    private byte[] byteValues;    private float[] outputLocations;    private float[] outputScores;    private float[] outputClasses;    private float[] outputNumDetections;    private String[] outputNames;    private TensorFlowInferenceInterface inferenceInterface;    public static Classifier create(            final String modelFilename,            final String labelFilename,            final int inputSize) throws IOException {        final TensorFlowObjectDetectionAPIModel d = new TensorFlowObjectDetectionAPIModel();        InputStream labelsInput = new FileInputStream(labelFilename);        BufferedReader br = null;        br = new BufferedReader(new InputStreamReader(labelsInput));        String line;        while ((line = br.readLine()) != null) {            d.labels.add(line);        }        br.close();        d.inferenceInterface = new TensorFlowInferenceInterface(modelFilename);        final Graph g = d.inferenceInterface.graph();        d.inputName = "image_tensor";        final Operation inputOp = g.operation(d.inputName);        if (inputOp == null) {            throw new RuntimeException("Failed to find input Node '" + d.inputName + "'");        }        d.inputSize = inputSize;        final Operation outputOp1 = g.operation("detection_scores");        if (outputOp1 == null) {            throw new RuntimeException("Failed to find output Node 'detection_scores'");        }        final Operation outputOp2 = g.operation("detection_boxes");        if (outputOp2 == null) {            throw new RuntimeException("Failed to find output Node 'detection_boxes'");        }        final Operation outputOp3 = g.operation("detection_classes");        if (outputOp3 == null) {            throw new RuntimeException("Failed to find output Node 
'detection_classes'");        }        d.outputNames = new String[] {"detection_boxes", "detection_scores",                "detection_classes", "num_detections"};        d.byteValues = new byte[d.inputSize * d.inputSize * 3];        d.outputScores = new float[MAX_RESULTS];        d.outputLocations = new float[MAX_RESULTS * 4];        d.outputClasses = new float[MAX_RESULTS];        d.outputNumDetections = new float[1];        return d;    }    private TensorFlowObjectDetectionAPIModel() {}    @Override    public List<Recognition> recognizeImage(int[] intValues) {        for (int i = 0; i < intValues.length; ++i) {            byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF);            byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF);            byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF);        }        inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3);        inferenceInterface.run(outputNames);        outputLocations = new float[MAX_RESULTS * 4];        outputScores = new float[MAX_RESULTS];        outputClasses = new float[MAX_RESULTS];        outputNumDetections = new float[1];        inferenceInterface.fetch(outputNames[0], outputLocations);        inferenceInterface.fetch(outputNames[1], outputScores);        inferenceInterface.fetch(outputNames[2], outputClasses);        inferenceInterface.fetch(outputNames[3], outputNumDetections);        final PriorityQueue<Recognition> pq =                new PriorityQueue<Recognition>(1,new Comparator<Recognition>() {                    public int compare(final Recognition lhs, final Recognition rhs) {                        return Float.compare(rhs.getConfidence(), lhs.getConfidence());                    }                });        for (int i = 0; i < outputScores.length; ++i) {            float left = outputLocations[4 * i + 1] * inputSize;            float top = outputLocations[4 * i] * inputSize;            float right = outputLocations[4 * i + 3] * 
inputSize;            float bottom = outputLocations[4 * i + 2] * inputSize;            pq.add(                    new Recognition("" + i, labels.get((int) outputClasses[i]), outputScores[i], left,top,right,bottom));        }        final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();        for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) {            recognitions.add(pq.poll());        }        return recognitions;    }    @Override    public void close() {        inferenceInterface.close();    }}

DetectionImage.java:这是调用入口(带 main 方法的测试类),使用时只需复制其中的实现代码即可。

public class DetectionImage {    public static void main(String[] args) {        int input_size = 300;        Classifier d = null;        try {            d = TensorFlowObjectDetectionAPIModel.create("C:\\sts\\ssd_mobilenet_v1_android_export.pb",                    "C:\\sts\\coco_labels_list.txt", input_size);        } catch (IOException e) {            // TODO Auto-generated catch block            e.printStackTrace();        }        if(d != null){            File file = new File("C:\\sts\\person.jpg");            Image img = null;            try {                img = ImageIO.read(file);            } catch (IOException e) {                // TODO Auto-generated catch block                e.printStackTrace();            }            if(img != null){                int width = img.getWidth(null);                int height = img.getHeight(null);                BufferedImage image = new BufferedImage(input_size, input_size, BufferedImage.TYPE_INT_RGB);                image.getGraphics().drawImage(img, 0, 0, input_size, input_size,0,0,width,height,null);                int[] rgbs = new int[input_size * input_size];                image.getRGB(0, 0, input_size, input_size, rgbs, 0, input_size);                List<Recognition> results = d.recognizeImage(rgbs);                for (Recognition result : results) {                    System.out.println(result.toString());                }            }        }    }}

输出结果:

[0] person (99.1%) 113.1657 27.434679 183.88785 296.03595

没想到 Java 用起来还挺 6 的,果然 Android 不行了还可以转行 Java……反正还是程序猿……


































原创粉丝点击