m维空间里n个点每点最近的第k个点的距离

来源:互联网 发布:java工厂模式概念 编辑:程序博客网 时间:2024/06/08 03:05

题目(原文为图片):给定m维空间中的n个点,对每个点求出它与距其第k近的点之间的距离。

m=4,即点均为四维空间的点。

n数目不定,可以理解为几万,几十万甚至上千万。

使用spark计算。资源配置为:executor-cores:6,executor-memory:10G。

 

解法一:

首先将点的矩阵弄成dataframe(dataframe里每一个Row的内容均为:[uuid,double1,double2,double3,double4])

然后dataframe join自身,然后再map求每一行(每点)与其他点的距离,并返回JavaPairRDD,key为点的uuid,value为欧几里得距离。

然后再groupByKey,得到的JavaPairRDD的key为点的uuid,value为Iterable,即该点与其他所有点的距离集合。

然后再map,sort每个点的距离iterable,从而得到每个点最近的第k个点与它的距离。

代码如下:


性能:n=10000的情况下,大概需要运行3分钟

瓶颈分析:

1.      计算复杂度为n^2,10000的情况下达到亿级别

2.      uuid为字符串,增加了分布式shuffle时移动的数据量

3.      欧几里得距离需要乘方和开方,浮点数计算比较消耗性能


解法2:

1.      将欧几里得距离换为曼哈顿距离

2.      uuid使用zipWithUniqueId生成,是long

3.      不使用join,使用RDD的cartesian方法生成

性能:与解法一差异不大



综上可知,主要耗时在于求每个点与其他所有点的距离(即解法一的join和解法二的cartesian),这一步本身就是O(n^2)的。

如果要在本质上解决问题,需要将【求每个点与其他点的距离】这一步给剪掉。


k-d树是一种二叉树,它按层轮流选择分割维度(第depth层使用第depth % m维)来划分左右子树。

而搜索最近的k个点时,只需维护一个容量为k的优先队列(大顶堆,堆顶为当前第k近的距离),自顶向下二叉式地搜索,并在队列已满时剪掉包围盒到目标点的距离已不小于堆顶距离的子树即可。这样就省去了计算每个点与其他所有点距离这一步。

代码如下:

HPoint.java

public class HPoint implements Serializable {    protected double[] coord;    protected HPoint(int n) {        coord = new double[n];    }    protected HPoint(double[] x) {        coord = new double[x.length];        for (int i = 0; i < x.length; ++i) {            coord[i] = x[i];        }    }    protected Object clone() {        return new HPoint(coord);    }    protected boolean equals(HPoint p) {        for (int i = 0; i < coord.length; ++i) {            if (coord[i] != p.coord[i]) {                return false;            }        }        return true;    }    //曼哈顿距离    protected static double manhanttandist(HPoint x,HPoint y){        double dist = 0;        for(int i=0; i < x.coord.length; ++i){            dist += Math.abs(x.coord[i] - y.coord[i]);        }        return dist;    }    //平方距离    protected static double sqrdist(HPoint x, HPoint y) {        double dist = 0;        for (int i = 0; i < x.coord.length; ++i) {            double diff = (x.coord[i] - y.coord[i]);            dist += (diff * diff);        }        return dist;    }    //欧几里得距离    protected static double eucdist(HPoint x, HPoint y) {        return Math.sqrt(sqrdist(x, y));    }    public String toString() {        String s = "";        for (int i = 0; i < coord.length; ++i) {            s = s + coord[i] + " ";        }        return s;    }}

HRect.java

public class HRect implements Serializable {    protected HPoint min;    protected HPoint max;    protected HRect(int ndims) {        min = new HPoint(ndims);        max = new HPoint(ndims);    }    protected HRect(HPoint vmin, HPoint vmax) {        min = (HPoint) vmin.clone();        max = (HPoint) vmax.clone();    }    protected Object clone() {        return new HRect(min, max);    }    //返回区域里距离HPoint距离最近的点    protected HPoint closest(HPoint t) {        HPoint p = new HPoint(t.coord.length);        for (int i = 0; i < t.coord.length; ++i) {            if (t.coord[i] <= min.coord[i]) {                p.coord[i] = min.coord[i];            } else if (t.coord[i] >= max.coord[i]) {                p.coord[i] = max.coord[i];            } else {                p.coord[i] = t.coord[i];            }        }        return p;    }    //初始化d维度的区域    protected static HRect infiniteHRect(int d) {        HPoint vmin = new HPoint(d);        HPoint vmax = new HPoint(d);        for (int i = 0; i < d; ++i) {            vmin.coord[i] = Double.NEGATIVE_INFINITY;            vmax.coord[i] = Double.POSITIVE_INFINITY;        }        return new HRect(vmin, vmax);    }    public String toString() {        return min + "\n" + max + "\n";    }}

KDNode.java

//kd树节点public class KDNode<T> implements Serializable {    protected HPoint k;    protected KDNode left, right;    protected boolean deleted;    T v;    private KDNode(HPoint key, T val) {        k = key;        v = val;        left = null;        right = null;        deleted = false;    }    //插入节点    protected static KDNode ins(HPoint key, Object val, KDNode t, int lev, int K) {        if (t == null) {            t = new KDNode(key, val);        } else if (key.equals(t.k)) {            //插入的值与该节点重复;如果该节点被标记为已删除,则将此节点恢复为未删除状态            if (t.deleted) {                t.deleted = false;                t.v = val;            }        } else if (key.coord[lev] > t.k.coord[lev]) {            t.right = ins(key, val, t.right, (lev + 1) % K, K);        } else {            t.left = ins(key, val, t.left, (lev + 1) % K, K);        }        return t;    }    //搜索节点    protected static KDNode srch(HPoint key, KDNode t, int K) {        for (int lev = 0; t != null; lev = (lev + 1) % K) {            if (!t.deleted && key.equals(t.k)) {                return t;            } else if (key.coord[lev] > t.k.coord[lev]) {                t = t.right;            } else {                t = t.left;            }        }        return null;    }    protected static void rsearch(HPoint lowk, HPoint uppk, KDNode t, int lev, int K, Vector<KDNode> v) {        if (t == null) {            return;        }        if (lowk.coord[lev] <= t.k.coord[lev]) {            rsearch(lowk, uppk, t.left, (lev + 1) % K, K, v);        }        int j;        for (j = 0; j < K && lowk.coord[j] <= t.k.coord[j] && uppk.coord[j] >= t.k.coord[j]; j++)            ;        if (j == K) {            v.add(t);        }        if (uppk.coord[lev] > t.k.coord[lev]) {            rsearch(lowk, uppk, t.right, (lev + 1) % K, K, v);        }    }    //近邻搜索    protected static void nnbr(KDNode kd, HPoint target, HRect hr, double max_dist_sqd, int lev, int K,                               NearestNeighborList nnl) {        if (kd 
== null) {            return;        }        int s = lev % K;        HPoint pivot = kd.k;        double pivot_to_target = HPoint.manhanttandist(pivot, target);        HRect left_hr = hr;        HRect right_hr = (HRect) hr.clone();        left_hr.max.coord[s] = pivot.coord[s];        right_hr.min.coord[s] = pivot.coord[s];        boolean target_in_left = target.coord[s] < pivot.coord[s];        KDNode nearer_kd;        HRect nearer_hr;        KDNode further_kd;        HRect further_hr;        if (target_in_left) {            nearer_kd = kd.left;            nearer_hr = left_hr;            further_kd = kd.right;            further_hr = right_hr;        }        else {            nearer_kd = kd.right;            nearer_hr = right_hr;            further_kd = kd.left;            further_hr = left_hr;        }        nnbr(nearer_kd, target, nearer_hr, max_dist_sqd, lev + 1, K, nnl);        KDNode nearest = (KDNode) nnl.getHighest();        double dist_sqd;        if (!nnl.isCapacityReached()) {            dist_sqd = Double.MAX_VALUE;        } else {            dist_sqd = nnl.getMaxPriority();        }        max_dist_sqd = Math.min(max_dist_sqd, dist_sqd);        HPoint closest = further_hr.closest(target);        if (Double.valueOf(HPoint.manhanttandist(closest, target)).compareTo(max_dist_sqd) < 0) {            if (pivot_to_target < dist_sqd) {                nearest = kd;                dist_sqd = pivot_to_target;                if (!kd.deleted) {                    nnl.insert(kd, dist_sqd);                }                if (nnl.isCapacityReached()) {                    max_dist_sqd = nnl.getMaxPriority();                } else {                    max_dist_sqd = Double.MAX_VALUE;                }            }            nnbr(further_kd, target, further_hr, max_dist_sqd, lev + 1, K, nnl);            KDNode temp_nearest = (KDNode) nnl.getHighest();            double temp_dist_sqd = nnl.getMaxPriority();            if (temp_dist_sqd < dist_sqd) {                
nearest = temp_nearest;                dist_sqd = temp_dist_sqd;            }        }        else if (pivot_to_target < max_dist_sqd) {            nearest = kd;            dist_sqd = pivot_to_target;        }    }    private static String pad(int n) {        String s = "";        for (int i = 0; i < n; ++i) {            s += " ";        }        return s;    }    private static void hrcopy(HRect hr_src, HRect hr_dst) {        hpcopy(hr_src.min, hr_dst.min);        hpcopy(hr_src.max, hr_dst.max);    }    private static void hpcopy(HPoint hp_src, HPoint hp_dst) {        for (int i = 0; i < hp_dst.coord.length; ++i) {            hp_dst.coord[i] = hp_src.coord[i];        }    }    protected String toString(int depth) {        String s = k + "  " + v + (deleted ? "*" : "");        if (left != null) {            s = s + "\n" + pad(depth) + "L " + left.toString(depth + 1);        }        if (right != null) {            s = s + "\n" + pad(depth) + "R " + right.toString(depth + 1);        }        return s;    }}

KDTree.java

/** * kd树 */public class KDTree<T> implements java.io.Serializable {    //维度    private int m_K;    //根节点    private KDNode m_root;    //树里的节点个数    private int m_count;    //创建一个k维的k-d树    public KDTree(int k) {        m_K = k;        m_root = null;    }    //向kd树里插入一个节点    //key是k维的值    //value是节点的标签    public void insert(double[] key, T value) {        if (key.length != m_K) {            throw new RuntimeException("KDTree: wrong key size!");        } else {            m_root = KDNode.ins(new HPoint(key), value, m_root, 0, m_K);        }        m_count++;    }    //根据key数值,搜索kd树节点    public Object search(double[] key) {        if (key.length != m_K) {            throw new RuntimeException("KDTree: wrong key size!");        }        KDNode kd = KDNode.srch(new HPoint(key), m_root, m_K);        return (kd == null ? null : kd.v);    }    //删除kd树节点    public void delete(double[] key) {        if (key.length != m_K) {            throw new RuntimeException("KDTree: wrong key size!");        } else {            KDNode t = KDNode.srch(new HPoint(key), m_root, m_K);            if (t == null) {                throw new RuntimeException("KDTree: key missing!");            } else {                t.deleted = true;            }            m_count--;        }    }    //搜索距离最近的kd树节点    public T nearest(double[] key) {        List<T> nbrs = nearest(key, 1);        return nbrs.get(0);    }    //搜索最近的n个kd树节点    public List<T> nearest(double[] key, int n) {        if (n < 0 || n > m_count) {            throw new IllegalArgumentException("Number of neighbors (" + n + ") cannot"                    + " be negative or greater than number of nodes (" + m_count + ").");        }        if (key.length != m_K) {            throw new RuntimeException("KDTree: wrong key size!");        }        List<T> nbrs = new ArrayList<T>(n);        NearestNeighborList nnl = new NearestNeighborList(n);        HRect hr = HRect.infiniteHRect(key.length);        double max_dist_sqd = Double.MAX_VALUE;        
HPoint keyp = new HPoint(key);        KDNode.nnbr(m_root, keyp, hr, max_dist_sqd, 0, m_K, nnl);        for (int i = 0; i < n; ++i) {            KDNode<T> kd = (KDNode) nnl.removeHighest();            nbrs.add(kd.v);        }        return nbrs;    }    private double mandist(double[] p1,double[] p2){        double dist = 0.0;        for(int i=0;i<p1.length;i++){            dist += Math.abs(p1[i]-p2[i]);        }        return dist;    }    //搜索最近的n个kd树节点,返回与他们的的距离    public List<Double> nearestDistance(double[] key, int n) {        if (n < 0 || n > m_count) {            throw new IllegalArgumentException("Number of neighbors (" + n + ") cannot"                    + " be negative or greater than number of nodes (" + m_count + ").");        }        if (key.length != m_K) {            throw new RuntimeException("KDTree: wrong key size!");        }        List<Double> nbrs = new ArrayList<Double>(n);        NearestNeighborList nnl = new NearestNeighborList(n);        HRect hr = HRect.infiniteHRect(key.length);        double max_dist_sqd = Double.MAX_VALUE;        HPoint keyp = new HPoint(key);        KDNode.nnbr(m_root, keyp, hr, max_dist_sqd, 0, m_K, nnl);        for (int i = 0; i < n; ++i) {            KDNode<T> kd = (KDNode) nnl.removeHighest();            nbrs.add(mandist(kd.k.coord,key));        }        return nbrs;    }    public String toString() {        return m_root.toString(0);    }}

NearestNeighborList.java

/** * 最近邻居列表,基于优先队列实现 */public class NearestNeighborList implements Serializable {    public static int REMOVE_HIGHEST = 1;    public static int REMOVE_LOWEST = 2;    PriorityQueue m_Queue = null;    int m_Capacity = 0;    //只保存最近的capacity个邻居    public NearestNeighborList(int capacity) {        m_Capacity = capacity;        m_Queue = new PriorityQueue(m_Capacity, Double.POSITIVE_INFINITY);    }    public double getMaxPriority() {        if (m_Queue.length() == 0) {            return Double.POSITIVE_INFINITY;        }        return m_Queue.getMaxPriority();    }    public boolean insert(Object object, double priority) {        if (m_Queue.length() < m_Capacity) {            //如果尚未达到capacity个,则直接放入队列            m_Queue.add(object, priority);            return true;        }        if (priority > m_Queue.getMaxPriority()) {            //如果优先级比队列里的其他元素都大,则入不了队列            return false;        }        //移除队列中优先级最大的元素,即队尾元素        m_Queue.remove();        //将新元素插入        m_Queue.add(object, priority);        return true;    }    public boolean isCapacityReached() {        return m_Queue.length() >= m_Capacity;    }    public Object getHighest() {        return m_Queue.front();    }    public boolean isEmpty() {        return m_Queue.length() == 0;    }    public int getSize() {        return m_Queue.length();    }    public Object removeHighest() {        return m_Queue.remove();    }}

PriorityQueue.java

/**
 * A binary max-heap of objects ordered by a {@code double} priority. {@link #front()} and
 * {@link #remove()} return the element with the HIGHEST priority.
 *
 * <p>Elements live in slots {@code 1..count} of the backing arrays, using 1-based heap
 * indexing (children of {@code i} are {@code 2i} and {@code 2i+1}). Slot 0 holds
 * {@code maxPriority} and is otherwise unused. The arrays grow automatically.
 */
public class PriorityQueue implements Serializable {
    /** Value stored in the unused slot 0. */
    private double maxPriority = Double.MAX_VALUE;
    private Object[] data;
    private double[] value;
    private int count;
    private int capacity;

    /** Creates a queue with an initial capacity of 20. */
    public PriorityQueue() {
        init(20);
    }

    /** Creates a queue with the given initial capacity. */
    public PriorityQueue(int capacity) {
        init(capacity);
    }

    /** Creates a queue with the given initial capacity and slot-0 value {@code maxPriority}. */
    public PriorityQueue(int capacity, double maxPriority) {
        this.maxPriority = maxPriority;
        init(capacity);
    }

    private void init(int size) {
        capacity = size;
        data = new Object[capacity + 1];
        value = new double[capacity + 1];
        value[0] = maxPriority;
        data[0] = null;
    }

    /** Adds {@code element} with {@code priority}, growing the arrays when full. */
    public void add(Object element, double priority) {
        count++;
        if (count > capacity) {
            expandCapacity();
        }
        value[count] = priority;
        data[count] = element;
        bubbleUp(count);
    }

    /** Removes and returns the element with the highest priority, or {@code null} if empty. */
    public Object remove() {
        if (count == 0) {
            return null;
        }
        Object element = data[1];
        // Move the last element to the root, then restore the heap property downwards.
        data[1] = data[count];
        value[1] = value[count];
        data[count] = null;
        value[count] = 0L;
        count--;
        bubbleDown(1);
        return element;
    }

    /** Returns (without removing) the element with the highest priority, or {@code null} if empty. */
    public Object front() {
        return count == 0 ? null : data[1];
    }

    /** Returns the highest priority held. The result is meaningless when the queue is empty. */
    public double getMaxPriority() {
        return value[1];
    }

    /** Sifts the element at {@code pos} down until neither child has a larger priority. */
    private void bubbleDown(int pos) {
        Object element = data[pos];
        double priority = value[pos];
        int child;
        for (; pos * 2 <= count; pos = child) {
            child = pos * 2;
            // Pick the larger of the two children (if the right child exists).
            if (child != count && value[child] < value[child + 1]) {
                child++;
            }
            if (priority < value[child]) {
                value[pos] = value[child];
                data[pos] = data[child];
            } else {
                break;
            }
        }
        value[pos] = priority;
        data[pos] = element;
    }

    /** Sifts the element at {@code pos} up while its parent has a smaller priority. */
    private void bubbleUp(int pos) {
        Object element = data[pos];
        double priority = value[pos];
        // Stop at the root explicitly instead of relying on the slot-0 sentinel: a priority
        // above maxPriority (e.g. +Infinity with the default constructor) used to move past
        // the root and loop forever at pos 0.
        while (pos > 1 && value[pos / 2] < priority) {
            value[pos] = value[pos / 2];
            data[pos] = data[pos / 2];
            pos /= 2;
        }
        value[pos] = priority;
        data[pos] = element;
    }

    /** Grows the backing arrays to hold {@code 2 * count} elements. */
    private void expandCapacity() {
        capacity = count * 2;
        Object[] elements = new Object[capacity + 1];
        double[] priorities = new double[capacity + 1];
        System.arraycopy(data, 0, elements, 0, data.length);
        System.arraycopy(value, 0, priorities, 0, value.length);
        data = elements;
        value = priorities;
    }

    /** Removes all elements, dropping the references they held. */
    public void clear() {
        // Elements occupy slots 1..count inclusive; the old "i < count" kept the last one alive.
        for (int i = 1; i <= count; i++) {
            data[i] = null;
        }
        count = 0;
    }

    /** Returns the number of elements held. */
    public int length() {
        return count;
    }
}


请注意,为了性能,使用的距离均是曼哈顿距离。
我们简单的写个程序测一下性能,测试代码如下:

/** * KD树k近邻搜索 */public class KDTreeTest {    public double mandist(Double[] p1,Double[] p2){        double dist = 0.0;        for(int i=0;i<p1.length;i++){            dist += Math.abs(p1[i]-p2[i]);        }        return dist;    }    public void test(){        List<Double> list1 = new ArrayList<Double>();        List<Double> list2 = new ArrayList<Double>();        int k=100;        System.out.println("intput size:"+k);        int dimension = 4;        //初始化100个4维变量        List<Double[]> list = new ArrayList<Double[]>();        for(int i=0;i<k;i++){            Double[] arr = new Double[dimension];            for(int j=0;j<4;j++){                arr[j] = Math.random();            }            list.add(arr);        }        System.out.println("****************kdtree********************");        long time1 = System.currentTimeMillis();        //计算每个点最近的第3个点        KDTree<Integer> kdTree = new KDTree<Integer>(dimension);        for(int i=0;i<k;i++){            double[] curDouble = new double[dimension];            int index = -1;            for(Double item:list.get(i)){                curDouble[++index] = item;            }            kdTree.insert(curDouble,i);        }        int nearest = 8;        for(int i=0;i<k;i++){            double[] curDouble = new double[dimension];            int index = -1;            for(Double item:list.get(i)){                curDouble[++index] = item;            }            List<Double> distance = kdTree.nearestDistance(curDouble, nearest + 1);            Collections.sort(distance);            //System.out.println(distance.get(distance.size() - 1));            list1.add(distance.get(distance.size()-1));        }        long time2 = System.currentTimeMillis();        System.out.println((time2-time1)+"ms");        System.out.println("****************kdtree********************");        System.out.println("****************normal********************");        long time3 = System.currentTimeMillis();        //计算每个点最近的第3个点        for(int 
i=0;i<k;i++){            List<Double> distance = new ArrayList<Double>();            for(int j=0;j<k;j++){                distance.add(mandist(list.get(i), list.get(j)));            }            Collections.sort(distance);            //System.out.println(distance.get(nearest));            list2.add(distance.get(nearest));        }        long time4 = System.currentTimeMillis();        System.out.println((time4-time3)+"ms");        System.out.println("****************normal********************");        boolean same = true;        for(int i=0;i<list1.size();i++){            if(!list1.get(i).equals(list2.get(i))){                same = false;                break;            }        }        if(same){            System.out.println("result is same");        }else{            System.out.println("result not same");        }    }    public static void main(String[] args){        new KDTreeTest().test();    }}

各个数据量级别的比对如下(kdtree和硬查的方法):




10万级别时,普通的暴力(硬查)方法运行几分钟仍无法算出结果。


60万级别时,我们单看kdtree:


那么如何用在spark程序里呢?用法如下:

double eps = Double.MAX_VALUE;        final KDTree<Integer> kdtree = new KDTree<Integer>(4);        scala.reflect.ClassTag<KDTree<Integer>> curClassTag = scala.reflect.ClassTag$.MODULE$.apply(KDTree.class);        try{            JavaRDD<Row> rowRdd = df.select("fwd_ppf", "fwd_bpp", "recv_ppf", "recv_bpp").toJavaRDD();            List<Row> list = rowRdd.collect();            for(int i=0;i<list.size();i++){                kdtree.insert(new double[]{                        Double.valueOf(String.valueOf(list.get(i).get(0))),                        Double.valueOf(String.valueOf(list.get(i).get(1))),                        Double.valueOf(String.valueOf(list.get(i).get(2))),                        Double.valueOf(String.valueOf(list.get(i).get(3)))                },i);            }            //生成kdtree广播变量            final Broadcast<KDTree<Integer>> broadCast = sqlContext.sparkContext().broadcast(kdtree,curClassTag);            JavaRDD<Double> sortRDD = rowRdd.map(new Function<Row, Double>() {                public Double call(Row row) throws Exception {                    double distance = Double.MAX_VALUE;                    try {                        //找到第minPts近的节点                        List<Double> list = broadCast.getValue().nearestDistance(new double[]{                                Double.valueOf(String.valueOf(row.get(0))),                                Double.valueOf(String.valueOf(row.get(1))),                                Double.valueOf(String.valueOf(row.get(2))),                                Double.valueOf(String.valueOf(row.get(3)))                        }, minPts + 1);                        Collections.sort(list);                        distance = list.get(list.size() - 1);                    } catch (Exception ex) {                        logger.error("", ex);                    }                    return distance;                }            });            sortRDD.persist(StorageLevel.MEMORY_AND_DISK());


spark运行结果如下:



目前的弊端为,构造树时需要collect数据,我们算一下driver需要承担的数据量:

每个点4个double(4×8=32字节),6000万个点时,仅原始坐标数据所占字节数为:32 × 6×10^7 = 1.92×10^9 字节,约1.79 GiB。


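所以driver端的内存最好配得多一些,我觉得4G比较保险。不过要注意,上面只是原始double数据的大小;collect得到的Row列表,以及kd树中每个KDNode、HPoint、double[]、Integer对象的对象头和引用开销,会使实际内存占用成倍增加,6000万个点时4G很可能不够,应按实际对象开销评估(或先采样测量)。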

0 0
原创粉丝点击