利用Spark mllib识别点阵文本

来源:互联网 发布:2万工资的java程序员 编辑:程序博客网 时间:2024/06/05 21:16

Step 1

准备手写字体,生成图片;
总共写了10个字:你、我、他、分、布、式、计、算、框、架,每个写了10遍
然后写了5个待识别的字:你、我、好、世、界、框、架

图片如下(手机上写的,字丑见谅!)
这里写图片描述

Step 2

切割图片(抠图),对齐大小至64*64,输出二值化(0-1)点阵,参考了网上的部分代码,java源码如下:

import java.awt.Color;import java.awt.image.BufferedImage;import java.io.File;import java.io.IOException;import java.util.ArrayList;import java.util.HashMap;import javax.imageio.ImageIO;public class ImageTest{    static int NORMAL_WIDTH = 64;    static int NORMAL_HEIGHT = 64;    static String FILE_DIR = "/Users/bluejoe/testdata/pics";    public static BufferedImage validateArea(File file) throws IOException    {        BufferedImage bi = ImageIO.read(file);        // 获取当前图片的高,宽,ARGB        int h = bi.getHeight();        int w = bi.getWidth();        int arr[][] = new int[w][h];        // 获取图片每一像素点的灰度值        for (int i = 0; i < w; i++)        {            for (int j = 0; j < h; j++)            {                // getRGB()返回默认的RGB颜色模型(十进制)                arr[i][j] = getImageRgb(bi.getRGB(i, j));// 该点的灰度值            }        }        int left = w - 1, top = h - 1, right = 0, bottom = 0;        int FZ = 130;        for (int i = 0; i < w; i++)        {            for (int j = 0; j < h; j++)            {                if (getGray(arr, i, j, w, h) > FZ)                {                    if (i > right)                        right = i;                    if (j > bottom)                        bottom = j;                    if (i < left)                        left = i;                    if (j < top)                        top = j;                }            }        }        BufferedImage croped = bi.getSubimage(left, top, right - left + 1,                bottom - top + 1);        BufferedImage resized = new BufferedImage(NORMAL_WIDTH, NORMAL_HEIGHT,                BufferedImage.TYPE_INT_ARGB);        resized.getGraphics().drawImage(croped, 0, 0, NORMAL_WIDTH,                NORMAL_HEIGHT, null);        /*         * File file1 = new File(file.getPath() + ".1_.jpg");         * ImageIO.write(resized, "png", file1);         *          * return ImageIO.read(file1);         */        return resized;    }    public static void main(String[] args) throws IOException    {        File dir = new File(FILE_DIR);        HashMap<String, ArrayList<Integer>> filePoints = new HashMap<String, ArrayList<Integer>>();        for (File file : dir.listFiles())        {            if (!file.isFile() || file.isHidden()                    || file.getPath().endsWith("_.jpg"))                continue;            try            {                BufferedImage bi = validateArea(file);                // 获取当前图片的高,宽,ARGB                int h = bi.getHeight();                int w = bi.getWidth();                int arr[][] = new int[w][h];                // 获取图片每一像素点的灰度值                for (int i = 0; i < w; i++)                {                    for (int j = 0; j < h; j++)                    {                        // getRGB()返回默认的RGB颜色模型(十进制)                        arr[i][j] = getImageRgb(bi.getRGB(i, j));// 该点的灰度值                    }                }                BufferedImage bufferedImage = new BufferedImage(w, h,                        BufferedImage.TYPE_BYTE_BINARY);// 构造一个类型为预定义图像类型之一的                                                        // BufferedImage,TYPE_BYTE_BINARY(表示一个不透明的以字节打包的                                                        // 1、2 或 4 位图像。)                // ArrayList<ArrayList<Integer>> arr2 = new                // ArrayList<ArrayList<Integer>>();                int FZ = 130;                // System.err.println(file.getPath());                ArrayList<Integer> points = new ArrayList<Integer>();                for (int i = 0; i < h; i++)                {                    for (int j = 0; j < w; j++)                    {                        if (getGray(arr, j, i, w, h) > FZ)                        {                            int black = new Color(255, 255, 255).getRGB();                            bufferedImage.setRGB(j, i, black);                            points.add(0);                        }                        else                        {                            int white = new Color(0, 0, 0).getRGB();                            bufferedImage.setRGB(j, i, white);                            points.add(1);                        }                    }                }                filePoints.put(file.getName(), points);                System.err.println(String.format("(%s,%s)",                        file.getName().charAt(0) - '0', points));                /*                 * File file2 = new File(file.getPath() + ".2_.jpg");                 * ImageIO.write(bufferedImage, "jpg", file2);                 */            }            catch (Throwable e)            {                e.printStackTrace();            }        }    }    private static int getImageRgb(int i)    {        String argb = Integer.toHexString(i);// 将十进制的颜色值转为十六进制        // argb分别代表透明,红,绿,蓝 分别占16进制2位        int r = Integer.parseInt(argb.substring(2, 4), 16);// 后面参数为使用进制        int g = Integer.parseInt(argb.substring(4, 6), 16);        int b = Integer.parseInt(argb.substring(6, 8), 16);        int result = (int) ((r + g + b) / 3);        return result;    }    // 自己加周围8个灰度值再除以9,算出其相对灰度值    public static int getGray(int gray[][], int x, int y, int w, int h)    {        int rs = gray[x][y] + (x == 0 ? 255 : gray[x - 1][y])                + (x == 0 || y == 0 ? 255 : gray[x - 1][y - 1])                + (x == 0 || y == h - 1 ? 255 : gray[x - 1][y + 1])                + (y == 0 ? 255 : gray[x][y - 1])                + (y == h - 1 ? 255 : gray[x][y + 1])                + (x == w - 1 ? 255 : gray[x + 1][y])                + (x == w - 1 || y == 0 ? 255 : gray[x + 1][y - 1])                + (x == w - 1 || y == h - 1 ? 255 : gray[x + 1][y + 1]);        return rs / 9;    }}

抠完的图很多很多,见下图:
这里写图片描述

Step 3

将如上输出分别存入2个文件,一个points.txt(10个汉字的手写点阵),一个query.txt(待识别的汉字点阵);

文件下载地址:文本点阵文件

Step 4

启动spark-shell,加载并采用LogisticRegressionWithLBFGS算法识别:

val data = MLUtils.loadLabeledPoints(sc,"file:///Users/bluejoe/testdata/points.txt");val query = MLUtils.loadLabeledPoints(sc,"file:///Users/bluejoe/testdata/query.txt").collect.map(_.features)val model = new LogisticRegressionWithLBFGS().setNumClasses(10).run(data)

识别第1、5个字(我,框)看看:

scala> model.predict(query.collect()(1))res91: Double = 1.0scala> model.predict(query.collect()(5))res92: Double = 8.0

全部识别出来看看:

scala> model.predict(query).collectres82: Array[Double] = Array(0.0, 1.0, 8.0, 8.0, 4.0, 8.0, 9.0)

结果是,存在的字被正确识别了,不存在的字识别失败!仔细看了一下源码,LogisticRegressionWithLBFGS没有一个同时输出匹配率的方法,它只是简单的挑选了一个匹配度比较高的分类,所以它总能输出一个分类。