java加载tensorflow训练好的模型部署成service
来源:互联网 发布:苏州爱知科技福利待遇 编辑:程序博客网 时间:2024/05/29 15:25
上一章节介绍了如何在Java中调用TensorFlow训练好的模型,这篇主要给出把模型部署成service的代码。另外,官方要求使用JDK 1.8,不过我改写了部分方法,用JDK 1.7也可以运行,代码如下:
首先是工具类(utils),包含用到的一些方法,例如把一段文本转化为一个tensor:
package com.dianping.text.classify.util;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.lang.StringUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
public class TersorflowUtils {
- private static Map<String, Integer> word_to_id = new HashMap<String, Integer>();
public static byte[] readAllBytesOrExit(Path path) {
try {
return Files.readAllBytes(path);
} catch (IOException e) {
System.err.println("Failed to read [" + path + "]: " + e.getMessage());
System.exit(1);
}
return null;
}
public static byte[] readAllBytes(String path) {
try {
InputStream in = new FileInputStream(path);
byte[] bytes = new byte[in.available()];
in.read(bytes);
in.close();
return bytes;
} catch (Exception e) {
return null;
}
}
/*
* 序列默人长度为300
*/
public static int[][] gettexttoid(String text) {
int[][] xpad = new int[1][300];
if (StringUtils.isBlank(text)) {
return xpad;
}
char[] chs = text.trim().toLowerCase().toCharArray();
List<Integer> list = new ArrayList<Integer>();
for (int i = 0; i < chs.length; i++) {
String element = Character.toString(chs[i]);
if (word_to_id.containsKey(element)) {
list.add(word_to_id.get(element));
}
}
if (list.size() == 0) {
return xpad;
}
int size = list.size();
Integer[] targetInter = (Integer[]) list.toArray(new Integer[size]);
/*
* 用于jdk1.8转换
*/
// int[] target=
// Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
int[] target = Intetoint(targetInter);
if (size <= 300) {
System.arraycopy(target, 0, xpad[0], xpad[0].length - size, target.length);
} else {
System.arraycopy(target, size - xpad[0].length, xpad[0], 0, xpad[0].length);
}
return xpad;
}
public static int[][] gettexttoid(String text, Map<String, Integer> map) {
int[][] xpad = new int[1][300];
if (StringUtils.isBlank(text)) {
return xpad;
}
char[] chs = text.trim().toLowerCase().toCharArray();
List<Integer> list = new ArrayList<Integer>();
for (int i = 0; i < chs.length; i++) {
String element = Character.toString(chs[i]);
if (map.containsKey(element)) {
list.add(map.get(element));
}
}
if (list.size() == 0) {
return xpad;
}
int size = list.size();
Integer[] targetInter = (Integer[]) list.toArray(new Integer[size]);
/*
* 用于jdk1.8转换
*/
// int[] target=
// Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
int[] target = Intetoint(targetInter);
if (size <= 300) {
System.arraycopy(target, 0, xpad[0], xpad[0].length - size, target.length);
} else {
System.arraycopy(target, size - xpad[0].length, xpad[0], 0, xpad[0].length);
}
return xpad;
}
/*
* 自定义长度
*/
public static int[][] gettexttoid(String text, int maxlen) {
if (maxlen < 1) {
throw new IllegalArgumentException("maxlen长度必须大于等于1");
}
int[][] xpad = new int[1][maxlen];
if (StringUtils.isBlank(text)) {
return xpad;
}
char[] chs = text.trim().toLowerCase().toCharArray();
List<Integer> list = new ArrayList<Integer>();
for (int i = 0; i < chs.length; i++) {
String element = Character.toString(chs[i]);
if (word_to_id.containsKey(element)) {
list.add(word_to_id.get(element));
}
}
if (list.size() == 0) {
return xpad;
}
int size = list.size();
Integer[] targetInter = (Integer[]) list.toArray(new Integer[size]);
/*
* 用于jdk1.8转换
*/
// int[] target=
// Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
int[] target = Intetoint(targetInter);
if (size <= maxlen) {
System.arraycopy(target, 0, xpad[0], xpad[0].length - size, target.length);
} else {
System.arraycopy(target, size - xpad[0].length, xpad[0], 0, xpad[0].length);
}
return xpad;
}
public static int[][] gettexttoid(String text, int maxlen, Map<String, Integer> map) {
if (maxlen < 1) {
throw new IllegalArgumentException("maxlen长度必须大于等于1");
}
int[][] xpad = new int[1][maxlen];
if (StringUtils.isBlank(text)) {
return xpad;
}
char[] chs = text.trim().toLowerCase().toCharArray();
List<Integer> list = new ArrayList<Integer>();
for (int i = 0; i < chs.length; i++) {
String element = Character.toString(chs[i]);
if (map.containsKey(element)) {
list.add(map.get(element));
}
}
if (list.size() == 0) {
return xpad;
}
int size = list.size();
Integer[] targetInter = (Integer[]) list.toArray(new Integer[size]);
/*
* 用于jdk1.8转换
*/
// int[] target=
// Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
int[] target = Intetoint(targetInter);
if (size <= maxlen) {
System.arraycopy(target, 0, xpad[0], xpad[0].length - size, target.length);
} else {
System.arraycopy(target, size - xpad[0].length, xpad[0], 0, xpad[0].length);
}
return xpad;
}
private static int[] Intetoint(Integer[] arr) {
int[] result = new int[arr.length];
for (int i = 0; i < arr.length; i++) {
result[i] = arr[i].intValue();
}
return result;
}
public static double getClassifyByBiLSTM(String text, Session sess, Map<String, Integer> map, Tensor keep_prob) {
if (StringUtils.isBlank(text)) {
return 0.0;
}
int[][] arr = gettexttoid(text, map);
Tensor input = Tensor.create(arr);
Tensor result = sess.runner().feed("input_x", input).feed("keep_prob", keep_prob).fetch("score/pred_y").run()
.get(0);
long[] rshape = result.shape();
int nlabels = (int) rshape[1];
int batchSize = (int) rshape[0];
float[][] logits = result.copyTo(new float[batchSize][nlabels]);
if (nlabels > 1 && batchSize > 0) {
return logits[0][1];
}
return 0.0;
}
}
其次是service的启动类:
package com.dianping.text.classify.base;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import com.dianping.text.classify.util.TersorflowUtils;
import com.dianping.text.classifybydl.api.service.Category;
/**
 * Abuse-text classifier. It loads a TensorFlow graph ({@code graph.model}) and a vocabulary
 * ({@code vocab_cnews.txt}) from the {@code modelabuse/} directory under the classpath root.
 */
public class Abuse implements Category {
    private static final Logger logger = LoggerFactory.getLogger(Abuse.class);
    /** Imported model graph. */
    private Graph g;
    /** Session over {@link #g}; stays null if loading failed. */
    private Session sess;
    /** Scalar 1.0f fed as keep_prob (presumably disables dropout at inference — confirm). */
    private Tensor keep_prob;
    /** Vocabulary: trimmed line text of vocab_cnews.txt mapped to its 0-based line index. */
    private Map<String, Integer> map;

    /**
     * Loads the vocabulary and the graph, then opens a session. Failures are logged. An
     * unreadable graph file terminates the JVM, because
     * {@link TersorflowUtils#readAllBytesOrExit} calls {@code System.exit}.
     */
    private void init() {
        g = new Graph();
        keep_prob = Tensor.create(1.0f);
        try {
            updataMap();
            byte[] graphDef = TersorflowUtils.readAllBytesOrExit(modelFile("graph.model"));
            g.importGraphDef(graphDef);
            sess = new Session(g);
        } catch (Exception e) {
            logger.error(" model load:", e);
        }
    }

    /**
     * Resolves a file in the {@code modelabuse/} directory under the classpath root.
     * Converting through {@code toURI()} decodes percent-escapes (spaces, non-ASCII directory
     * names). It also avoids the {@code /C:/...} form that {@code URL.getPath()} yields on
     * Windows.
     */
    private java.nio.file.Path modelFile(String name) throws java.net.URISyntaxException {
        return Paths.get(this.getClass().getResource("/").toURI()).resolve("modelabuse").resolve(name);
    }

    /**
     * (Re)loads the vocabulary. Each trimmed line of {@code vocab_cnews.txt} becomes a key whose
     * value is the line's 0-based index. The new map replaces the old one only after loading
     * finishes. On failure the error is logged and whatever was read so far is kept.
     */
    public void updataMap() {
        Map<String, Integer> vocab = new HashMap<>();
        int i = 0;
        // NOTE(review): assumes the vocabulary file is UTF-8 (the platform default was used before).
        try (BufferedReader buffer = new BufferedReader(new InputStreamReader(
                new FileInputStream(modelFile("vocab_cnews.txt").toFile()), "UTF-8"))) {
            String line;
            // Check for EOF before trimming; trimming a null line threw at end of file.
            while ((line = buffer.readLine()) != null) {
                vocab.put(line.trim(), i++);
            }
        } catch (Exception e) {
            logger.error("load vocab failed:", e);
        }
        map = vocab;
        System.out.println("map.size is:" + map.size());
    }

    @Override
    public double getClassify(String text) {
        return TersorflowUtils.getClassifyByBiLSTM(text, sess, map, keep_prob);
    }

    /** Manual smoke test: loads the model and prints the score for one sample sentence. */
    public static void main(String[] args) {
        Abuse abuse = new Abuse();
        abuse.init();
        System.out.println(abuse.getClassify("我操你妈个逼"));
    }
}
结果:
阅读全文
0 0
- java加载tensorflow训练好的模型部署成service
- 将tensorflow训练好的模型部署成sercice服务,并做预测
- 如何用Tensorflow训练模型成pb文件和和如何加载已经训练好的模型文件
- tensorflow训练好的模型中java调用
- tensorflow将训练好的模型
- C++调用tensorflow 训练好的模型
- 将caffe训练好的模型转换为tensorflow模型
- tensorflow 加载预训练模型
- 将tensorflow训练好的模型移植到android
- TensorFlow 训练好模型参数的保存和恢复代码
- 将tensorflow训练好的模型移植到android
- 将tensorflow训练好的模型移植到android
- Tensorflow 05: 导入预训练好的图模型
- 将tensorflow训练好的模型移植到android
- 将tensorflow训练好的模型移植到android
- TensorFlow使用C++加载使用训练好的模型,.cc文件代码实现的相关类及方法总结
- TensorFlow保存和加载训练模型
- tensorflow保存加载模型查看训练参数
- java 异常
- Spark组件介绍
- 悬镜安全丨企业应急响应浅析,遇到网络攻击怎么办?
- 最短路径
- 20171023memo
- java加载tensorflow训练好的模型部署成service
- C#事件
- imresize
- JAVA面向对象练习02
- Qt的setMouseTracking使用
- 二叉树中和为某一值的路径
- 「游戏引擎Mojoc」(3)C面向对象编程
- Scientific Toolworks Understand 4.0.909 Win32_64 2CD
- JVM类加载机制(类加载过程和类加载器)