Spark Java 用 KMeans算法实现图片压缩

来源:互联网 发布:剑三咩太捏脸数据 编辑:程序博客网 时间:2024/05/16 11:59

压缩前:981 KB
这里写图片描述
压缩后:111 KB
这里写图片描述

思路:
取得图片每一点的像素,组成向量Vector如下:(w,h,R,G,B);
设置目的K值,训练所有点,获得KMeansModel;
此遍历所有的点,利用模型预测每个点属于哪个 中心点,同时改变这个点的R,G,B值使这个点的颜色 与这个点所在的集合相同;
重新利用收集的数据画出图片。

一共需要两个类,一个处理跟图片相关,一个处理KMeans算法,好像pytho语言写就好简单,哎。

处理图片类如下:

import java.awt.Color;  import java.awt.image.BufferedImage;  import java.io.File;  import java.io.FileOutputStream;  import java.io.IOException;  import java.io.OutputStream;import java.util.ArrayList;import java.util.List;import javax.imageio.ImageIO;import scala.Tuple3;  public class AnalyzePicture {      /**     *      * @param oldFilePath:你想压缩的图片的地址     * @param newFilePath:压缩后存放的地址     * @param t:tuple类型(scala里面的,因为在学习spark所以就拿来用),<Integer,Integer,Integer>-><weight,high,RGB>     */    public static void generatePhoto(String oldFilePath,String newFilePath,List<Tuple3<Integer,Integer,Integer>>t)    {        try{            BufferedImage imgOld = ImageIO.read(new File(oldFilePath));            int w=imgOld.getWidth();            int h=imgOld.getHeight();            File out = new File(newFilePath);              if (!out.exists())                  out.createNewFile();              OutputStream output = new FileOutputStream(out);              BufferedImage imgOut = new BufferedImage(w, h,                      BufferedImage.TYPE_3BYTE_BGR);             for(Tuple3<Integer,Integer,Integer> tt:t)            {                imgOut.setRGB( tt._1(), tt._2(),tt._3());            }            ImageIO.write(imgOut, "png", output);              output.close();        }catch(Exception e)        {            e.printStackTrace();        }    }    /**     * 作用是返回图片每一点的RGP值     * @param filePath :想要处理的图片的地址     * @return String: with,high,R,G,P,这里我把RGP独立成三个方向的值     */    public static List<String> getImageGRBStr(String filePath) {        File file  = new File(filePath);        List<String>list=new ArrayList<String>();        if (!file.exists()) {            return null;        }        try {            BufferedImage bufImg = ImageIO.read(file);            int height = bufImg.getHeight();            int width = bufImg.getWidth();            for (int i = bufImg.getMinX(); i < width; i++) {                for (int j = bufImg.getMinY(); j < height; j++) {                    String str=i+","+j+","+((bufImg.getRGB(i, j) & 0xff0000) >> 16)+","+((bufImg.getRGB(i, j) & 0xff00) >> 8)+","+((bufImg.getRGB(i, j) & 0xff));                    list.add(str);                    }            }        } catch (IOException e) {            // TODO Auto-generated catch block            e.printStackTrace();        }        return list;    }}  

KMeans算法类如下:

import java.util.ArrayList;import java.util.List;import org.apache.log4j.Level;import org.apache.log4j.Logger;import org.apache.spark.api.java.JavaRDD;import org.apache.spark.api.java.JavaSparkContext;import org.apache.spark.api.java.function.Function;import org.apache.spark.mllib.clustering.KMeans;import org.apache.spark.mllib.clustering.KMeansModel;import org.apache.spark.mllib.linalg.Vector;import org.apache.spark.mllib.linalg.Vectors;import org.apache.spark.rdd.RDD;import org.apache.spark.sql.SparkSession;import scala.Tuple3;public class PhotoKMeans {    public static void main(String[]args)    {        //屏幕spark多余的log        Logger.getLogger("org.apache.spark").setLevel(Level.WARN);        Logger.getLogger("org.apache.spark").setLevel(Level.OFF);        //获取到到了图片的每个点的情况String->w,h,R,G,P        List<String>list=AnalyzePicture.getImageGRBStr(args[0]);        //设置spark为本地模式,这样就不要老是跑到linux集群上去跑了        SparkSession spark=SparkSession.builder().master("local").appName("PhotoKMeans").getOrCreate();        //目的,初始化非file的数据源        JavaSparkContext javaSpark=new JavaSparkContext(spark.sparkContext());        //首先把List<String>变成RDD<String>,RDD中的String为"w,h,R,G,P",去除","之后,RDD中的String映射为向量 {w,h,R,G,P}        RDD<Vector>rdd=javaSpark.parallelize(list).map(new StringToVector(",")).rdd().cache();        KMeansModel model=KMeans.train(rdd, Integer.parseInt(args[2]), Integer.parseInt(args[3]));//拿去处理        //获得处理后的中心向量         Vector[]vectors=model.clusterCenters();        //获得图片所有点的向量 ,为了下面获取每个点属于哪个中心点而准备        List<Vector>points=rdd.toJavaRDD().collect();        //存储重新调整之后图片每个点的w,h,RGP        List<Tuple3<Integer,Integer,Integer>>tuple=new ArrayList<Tuple3<Integer,Integer,Integer>>();        for(Vector v:points)        {            int cluster=model.predict(v);//start from 0            double[]rgbs=vectors[cluster].toArray();            double[]x=v.toArray();            int ww=(int)x[0];            int hh=(int)x[1];            int rgb=((int) rgbs[2]<<16)+((int) rgbs[3]<<8)+((int) rgbs[4]);//这里是将R,G,P变成RGP            Tuple3<Integer,Integer,Integer>t=new Tuple3<>(ww,hh,rgb);            tuple.add(t);        }        //交给处理类生成新的图片        AnalyzePicture.generatePhoto(args[0], args[1], tuple);    }    public static class StringToVector implements Function<String,Vector>    {        String target="";        public StringToVector(String target)        {            this.target=target;        }        @Override        public Vector call(String a) throws Exception {            // TODO Auto-generated method stub            String[]aa=a.split(target);            double[]aaa=new double[aa.length];            for(int i=0;i<aa.length;i++)            {                aaa[i]=Double.parseDouble(aa[i]);            }            return Vectors.dense(aaa);        }       }}

最用调用KMeans类,参数如下:
args[0] =你想处理的图片地址
args[1] =处理后新图片的地址
args[2] =K值(K越大图片越逼真)
args[3] =允许的最大迭代次数

C:/Users/Administrator/Pictures/fengjing.png C:/Users/Administrator/Pictures/2_564.png 600 5

原创粉丝点击