m维空间里n个点每点最近的第k个点的距离
来源:互联网 发布:java工厂模式概念 编辑:程序博客网 时间:2024/06/08 03:05
题目:给定m维空间中的n个点,求每个点与距它第k近的点之间的距离(原文此处为题目截图)。
m=4,即点均为四维空间的点。
n数目不定,可以理解为几万,几十万甚至上千万。
使用spark计算。资源配置为:executor-cores:6,executor-memory:10G。
解法一:
首先将点的矩阵弄成dataframe(dataframe里每一个Row的内容均为:[uuid,double1,double2,double3,double4])
然后dataframe join自身,然后再map求每一行(每点)与其他点的距离,并返回JavaPairRDD,key为点的uuid,value为欧几里得距离。
然后再groupByKey,这样得到的JavaPairRDD的key为点的uuid,value为Iterable,即该点与其他所有点的距离集合。
然后再map,sort每个点的距离iterable,从而得到每个点最近的第k个点与它的距离。
代码如下:
性能:n=10000的情况下,大概需要运行3分钟
瓶颈分析:
1. 计算复杂度为n^2,10000的情况下达到亿级别
2. uuid为字符串,增加了分布式shuffle时移动的数据量
3. 欧几里得距离需要乘方和开方,浮点数计算比较消耗性能
解法二:
1. 将欧几里得距离换为曼哈顿距离
2. uuid使用zipWithUniqueId生成,是long
3. 不使用join,使用RDD的cartesian方法生成
性能:与解法一差异不大
综上可知,主要耗时在于求每个点与其他所有点的距离(即解法一的join和解法二的cartesian,这一步非常耗时)。
如果要在本质上解决问题,需要将【求每个点与其他点的距离】这一步给剪掉。
解法三:k-d树。k-d树是一种二叉树,它随树的深度循环地选取切分维度(curDimension % DimensionNum)来划分左右子树。
搜索最近的k个点时,只需维护一个容量为k的优先队列,沿树进行二分式搜索并剪掉不可能更近的子树即可,从而省去了计算每个点与所有其他点距离这一步。
代码如下:
HPoint.java
public class HPoint implements Serializable { protected double[] coord; protected HPoint(int n) { coord = new double[n]; } protected HPoint(double[] x) { coord = new double[x.length]; for (int i = 0; i < x.length; ++i) { coord[i] = x[i]; } } protected Object clone() { return new HPoint(coord); } protected boolean equals(HPoint p) { for (int i = 0; i < coord.length; ++i) { if (coord[i] != p.coord[i]) { return false; } } return true; } //曼哈顿距离 protected static double manhanttandist(HPoint x,HPoint y){ double dist = 0; for(int i=0; i < x.coord.length; ++i){ dist += Math.abs(x.coord[i] - y.coord[i]); } return dist; } //平方距离 protected static double sqrdist(HPoint x, HPoint y) { double dist = 0; for (int i = 0; i < x.coord.length; ++i) { double diff = (x.coord[i] - y.coord[i]); dist += (diff * diff); } return dist; } //欧几里得距离 protected static double eucdist(HPoint x, HPoint y) { return Math.sqrt(sqrdist(x, y)); } public String toString() { String s = ""; for (int i = 0; i < coord.length; ++i) { s = s + coord[i] + " "; } return s; }}
HRect.java
public class HRect implements Serializable { protected HPoint min; protected HPoint max; protected HRect(int ndims) { min = new HPoint(ndims); max = new HPoint(ndims); } protected HRect(HPoint vmin, HPoint vmax) { min = (HPoint) vmin.clone(); max = (HPoint) vmax.clone(); } protected Object clone() { return new HRect(min, max); } //返回区域里距离HPoint距离最近的点 protected HPoint closest(HPoint t) { HPoint p = new HPoint(t.coord.length); for (int i = 0; i < t.coord.length; ++i) { if (t.coord[i] <= min.coord[i]) { p.coord[i] = min.coord[i]; } else if (t.coord[i] >= max.coord[i]) { p.coord[i] = max.coord[i]; } else { p.coord[i] = t.coord[i]; } } return p; } //初始化d维度的区域 protected static HRect infiniteHRect(int d) { HPoint vmin = new HPoint(d); HPoint vmax = new HPoint(d); for (int i = 0; i < d; ++i) { vmin.coord[i] = Double.NEGATIVE_INFINITY; vmax.coord[i] = Double.POSITIVE_INFINITY; } return new HRect(vmin, vmax); } public String toString() { return min + "\n" + max + "\n"; }}
KDNode.java
//kd树节点public class KDNode<T> implements Serializable { protected HPoint k; protected KDNode left, right; protected boolean deleted; T v; private KDNode(HPoint key, T val) { k = key; v = val; left = null; right = null; deleted = false; } //插入节点 protected static KDNode ins(HPoint key, Object val, KDNode t, int lev, int K) { if (t == null) { t = new KDNode(key, val); } else if (key.equals(t.k)) { //插入的值与该节点重复;如果该节点被标记为已删除,则将此节点恢复为未删除状态 if (t.deleted) { t.deleted = false; t.v = val; } } else if (key.coord[lev] > t.k.coord[lev]) { t.right = ins(key, val, t.right, (lev + 1) % K, K); } else { t.left = ins(key, val, t.left, (lev + 1) % K, K); } return t; } //搜索节点 protected static KDNode srch(HPoint key, KDNode t, int K) { for (int lev = 0; t != null; lev = (lev + 1) % K) { if (!t.deleted && key.equals(t.k)) { return t; } else if (key.coord[lev] > t.k.coord[lev]) { t = t.right; } else { t = t.left; } } return null; } protected static void rsearch(HPoint lowk, HPoint uppk, KDNode t, int lev, int K, Vector<KDNode> v) { if (t == null) { return; } if (lowk.coord[lev] <= t.k.coord[lev]) { rsearch(lowk, uppk, t.left, (lev + 1) % K, K, v); } int j; for (j = 0; j < K && lowk.coord[j] <= t.k.coord[j] && uppk.coord[j] >= t.k.coord[j]; j++) ; if (j == K) { v.add(t); } if (uppk.coord[lev] > t.k.coord[lev]) { rsearch(lowk, uppk, t.right, (lev + 1) % K, K, v); } } //近邻搜索 protected static void nnbr(KDNode kd, HPoint target, HRect hr, double max_dist_sqd, int lev, int K, NearestNeighborList nnl) { if (kd == null) { return; } int s = lev % K; HPoint pivot = kd.k; double pivot_to_target = HPoint.manhanttandist(pivot, target); HRect left_hr = hr; HRect right_hr = (HRect) hr.clone(); left_hr.max.coord[s] = pivot.coord[s]; right_hr.min.coord[s] = pivot.coord[s]; boolean target_in_left = target.coord[s] < pivot.coord[s]; KDNode nearer_kd; HRect nearer_hr; KDNode further_kd; HRect further_hr; if (target_in_left) { nearer_kd = kd.left; nearer_hr = left_hr; further_kd = kd.right; further_hr = 
right_hr; } else { nearer_kd = kd.right; nearer_hr = right_hr; further_kd = kd.left; further_hr = left_hr; } nnbr(nearer_kd, target, nearer_hr, max_dist_sqd, lev + 1, K, nnl); KDNode nearest = (KDNode) nnl.getHighest(); double dist_sqd; if (!nnl.isCapacityReached()) { dist_sqd = Double.MAX_VALUE; } else { dist_sqd = nnl.getMaxPriority(); } max_dist_sqd = Math.min(max_dist_sqd, dist_sqd); HPoint closest = further_hr.closest(target); if (Double.valueOf(HPoint.manhanttandist(closest, target)).compareTo(max_dist_sqd) < 0) { if (pivot_to_target < dist_sqd) { nearest = kd; dist_sqd = pivot_to_target; if (!kd.deleted) { nnl.insert(kd, dist_sqd); } if (nnl.isCapacityReached()) { max_dist_sqd = nnl.getMaxPriority(); } else { max_dist_sqd = Double.MAX_VALUE; } } nnbr(further_kd, target, further_hr, max_dist_sqd, lev + 1, K, nnl); KDNode temp_nearest = (KDNode) nnl.getHighest(); double temp_dist_sqd = nnl.getMaxPriority(); if (temp_dist_sqd < dist_sqd) { nearest = temp_nearest; dist_sqd = temp_dist_sqd; } } else if (pivot_to_target < max_dist_sqd) { nearest = kd; dist_sqd = pivot_to_target; } } private static String pad(int n) { String s = ""; for (int i = 0; i < n; ++i) { s += " "; } return s; } private static void hrcopy(HRect hr_src, HRect hr_dst) { hpcopy(hr_src.min, hr_dst.min); hpcopy(hr_src.max, hr_dst.max); } private static void hpcopy(HPoint hp_src, HPoint hp_dst) { for (int i = 0; i < hp_dst.coord.length; ++i) { hp_dst.coord[i] = hp_src.coord[i]; } } protected String toString(int depth) { String s = k + " " + v + (deleted ? "*" : ""); if (left != null) { s = s + "\n" + pad(depth) + "L " + left.toString(depth + 1); } if (right != null) { s = s + "\n" + pad(depth) + "R " + right.toString(depth + 1); } return s; }}
KDTree.java
/** * kd树 */public class KDTree<T> implements java.io.Serializable { //维度 private int m_K; //根节点 private KDNode m_root; //树里的节点个数 private int m_count; //创建一个k维的k-d树 public KDTree(int k) { m_K = k; m_root = null; } //向kd树里插入一个节点 //key是k维的值 //value是节点的标签 public void insert(double[] key, T value) { if (key.length != m_K) { throw new RuntimeException("KDTree: wrong key size!"); } else { m_root = KDNode.ins(new HPoint(key), value, m_root, 0, m_K); } m_count++; } //根据key数值,搜索kd树节点 public Object search(double[] key) { if (key.length != m_K) { throw new RuntimeException("KDTree: wrong key size!"); } KDNode kd = KDNode.srch(new HPoint(key), m_root, m_K); return (kd == null ? null : kd.v); } //删除kd树节点 public void delete(double[] key) { if (key.length != m_K) { throw new RuntimeException("KDTree: wrong key size!"); } else { KDNode t = KDNode.srch(new HPoint(key), m_root, m_K); if (t == null) { throw new RuntimeException("KDTree: key missing!"); } else { t.deleted = true; } m_count--; } } //搜索距离最近的kd树节点 public T nearest(double[] key) { List<T> nbrs = nearest(key, 1); return nbrs.get(0); } //搜索最近的n个kd树节点 public List<T> nearest(double[] key, int n) { if (n < 0 || n > m_count) { throw new IllegalArgumentException("Number of neighbors (" + n + ") cannot" + " be negative or greater than number of nodes (" + m_count + ")."); } if (key.length != m_K) { throw new RuntimeException("KDTree: wrong key size!"); } List<T> nbrs = new ArrayList<T>(n); NearestNeighborList nnl = new NearestNeighborList(n); HRect hr = HRect.infiniteHRect(key.length); double max_dist_sqd = Double.MAX_VALUE; HPoint keyp = new HPoint(key); KDNode.nnbr(m_root, keyp, hr, max_dist_sqd, 0, m_K, nnl); for (int i = 0; i < n; ++i) { KDNode<T> kd = (KDNode) nnl.removeHighest(); nbrs.add(kd.v); } return nbrs; } private double mandist(double[] p1,double[] p2){ double dist = 0.0; for(int i=0;i<p1.length;i++){ dist += Math.abs(p1[i]-p2[i]); } return dist; } //搜索最近的n个kd树节点,返回与他们的的距离 public List<Double> nearestDistance(double[] 
key, int n) { if (n < 0 || n > m_count) { throw new IllegalArgumentException("Number of neighbors (" + n + ") cannot" + " be negative or greater than number of nodes (" + m_count + ")."); } if (key.length != m_K) { throw new RuntimeException("KDTree: wrong key size!"); } List<Double> nbrs = new ArrayList<Double>(n); NearestNeighborList nnl = new NearestNeighborList(n); HRect hr = HRect.infiniteHRect(key.length); double max_dist_sqd = Double.MAX_VALUE; HPoint keyp = new HPoint(key); KDNode.nnbr(m_root, keyp, hr, max_dist_sqd, 0, m_K, nnl); for (int i = 0; i < n; ++i) { KDNode<T> kd = (KDNode) nnl.removeHighest(); nbrs.add(mandist(kd.k.coord,key)); } return nbrs; } public String toString() { return m_root.toString(0); }}
NearestNeighborList.java
/** * 最近邻居列表,基于优先队列实现 */public class NearestNeighborList implements Serializable { public static int REMOVE_HIGHEST = 1; public static int REMOVE_LOWEST = 2; PriorityQueue m_Queue = null; int m_Capacity = 0; //只保存最近的capacity个邻居 public NearestNeighborList(int capacity) { m_Capacity = capacity; m_Queue = new PriorityQueue(m_Capacity, Double.POSITIVE_INFINITY); } public double getMaxPriority() { if (m_Queue.length() == 0) { return Double.POSITIVE_INFINITY; } return m_Queue.getMaxPriority(); } public boolean insert(Object object, double priority) { if (m_Queue.length() < m_Capacity) { //如果尚未达到capacity个,则直接放入队列 m_Queue.add(object, priority); return true; } if (priority > m_Queue.getMaxPriority()) { //如果优先级比队列里的其他元素都大,则入不了队列 return false; } //移除队列中优先级最大的元素,即队尾元素 m_Queue.remove(); //将新元素插入 m_Queue.add(object, priority); return true; } public boolean isCapacityReached() { return m_Queue.length() >= m_Capacity; } public Object getHighest() { return m_Queue.front(); } public boolean isEmpty() { return m_Queue.length() == 0; } public int getSize() { return m_Queue.length(); } public Object removeHighest() { return m_Queue.remove(); }}
PriorityQueue.java
/**
 * Binary max-heap of (element, priority) pairs: {@link #front()} and {@link #remove()}
 * return the element with the LARGEST priority. Slots are 1-based and the backing arrays
 * grow automatically when full.
 */
public class PriorityQueue implements Serializable {

    /** Value placed in the unused slot 0; the sift-up loop no longer depends on it. */
    private double maxPriority = Double.MAX_VALUE;
    private Object[] data;
    private double[] value;
    /** Number of stored elements; they occupy slots 1..count. */
    private int count;
    private int capacity;

    public PriorityQueue() {
        init(20);
    }

    public PriorityQueue(int capacity) {
        init(capacity);
    }

    public PriorityQueue(int capacity, double maxPriority) {
        this.maxPriority = maxPriority;
        init(capacity);
    }

    private void init(int size) {
        capacity = size;
        data = new Object[capacity + 1];
        value = new double[capacity + 1];
        value[0] = maxPriority;
        data[0] = null;
    }

    /** Adds {@code element} with the given priority, growing the storage if needed. */
    public void add(Object element, double priority) {
        if (count++ >= capacity) {
            expandCapacity();
        }
        value[count] = priority;
        data[count] = element;
        bubbleUp(count);
    }

    /** Removes and returns the element with the largest priority, or null if empty. */
    public Object remove() {
        if (count == 0) {
            return null;
        }
        Object element = data[1];
        data[1] = data[count];
        value[1] = value[count];
        data[count] = null;
        value[count] = 0L;
        count--;
        bubbleDown(1);
        return element;
    }

    /** Returns the element with the largest priority without removing it, or null if empty. */
    public Object front() {
        return count == 0 ? null : data[1];
    }

    /** Largest stored priority; only meaningful when {@code length() > 0}. */
    public double getMaxPriority() {
        return value[1];
    }

    /** Sifts the entry at {@code pos} down until neither child has a larger priority. */
    private void bubbleDown(int pos) {
        Object element = data[pos];
        double priority = value[pos];
        int child;
        for (; pos * 2 <= count; pos = child) {
            child = pos * 2;
            // Pick the larger of the two children.
            if (child != count) {
                if (value[child] < value[child + 1]) {
                    child++;
                }
            }
            if (priority < value[child]) {
                value[pos] = value[child];
                data[pos] = data[child];
            } else {
                break;
            }
        }
        value[pos] = priority;
        data[pos] = element;
    }

    /** Sifts the entry at {@code pos} up while its parent has a smaller priority. */
    private void bubbleUp(int pos) {
        Object element = data[pos];
        double priority = value[pos];
        // Stop at the root explicitly instead of relying on the slot-0 sentinel: a priority
        // above the sentinel (e.g. +Infinity with the default Double.MAX_VALUE) would
        // otherwise never terminate.
        while (pos > 1 && value[pos / 2] < priority) {
            value[pos] = value[pos / 2];
            data[pos] = data[pos / 2];
            pos /= 2;
        }
        value[pos] = priority;
        data[pos] = element;
    }

    /** Doubles the storage relative to the current element count. */
    private void expandCapacity() {
        capacity = count * 2;
        Object[] elements = new Object[capacity + 1];
        double[] prioritys = new double[capacity + 1];
        System.arraycopy(data, 0, elements, 0, data.length);
        System.arraycopy(value, 0, prioritys, 0, data.length);
        data = elements;
        value = prioritys;
    }

    /** Removes all elements and releases the references to them. */
    public void clear() {
        // Slots 1..count (inclusive) are in use.
        for (int i = 1; i <= count; i++) {
            data[i] = null;
        }
        count = 0;
    }

    public int length() {
        return count;
    }
}
我们简单的写个程序测一下性能,测试代码如下:
/** * KD树k近邻搜索 */public class KDTreeTest { public double mandist(Double[] p1,Double[] p2){ double dist = 0.0; for(int i=0;i<p1.length;i++){ dist += Math.abs(p1[i]-p2[i]); } return dist; } public void test(){ List<Double> list1 = new ArrayList<Double>(); List<Double> list2 = new ArrayList<Double>(); int k=100; System.out.println("intput size:"+k); int dimension = 4; //初始化100个4维变量 List<Double[]> list = new ArrayList<Double[]>(); for(int i=0;i<k;i++){ Double[] arr = new Double[dimension]; for(int j=0;j<4;j++){ arr[j] = Math.random(); } list.add(arr); } System.out.println("****************kdtree********************"); long time1 = System.currentTimeMillis(); //计算每个点最近的第3个点 KDTree<Integer> kdTree = new KDTree<Integer>(dimension); for(int i=0;i<k;i++){ double[] curDouble = new double[dimension]; int index = -1; for(Double item:list.get(i)){ curDouble[++index] = item; } kdTree.insert(curDouble,i); } int nearest = 8; for(int i=0;i<k;i++){ double[] curDouble = new double[dimension]; int index = -1; for(Double item:list.get(i)){ curDouble[++index] = item; } List<Double> distance = kdTree.nearestDistance(curDouble, nearest + 1); Collections.sort(distance); //System.out.println(distance.get(distance.size() - 1)); list1.add(distance.get(distance.size()-1)); } long time2 = System.currentTimeMillis(); System.out.println((time2-time1)+"ms"); System.out.println("****************kdtree********************"); System.out.println("****************normal********************"); long time3 = System.currentTimeMillis(); //计算每个点最近的第3个点 for(int i=0;i<k;i++){ List<Double> distance = new ArrayList<Double>(); for(int j=0;j<k;j++){ distance.add(mandist(list.get(i), list.get(j))); } Collections.sort(distance); //System.out.println(distance.get(nearest)); list2.add(distance.get(nearest)); } long time4 = System.currentTimeMillis(); System.out.println((time4-time3)+"ms"); System.out.println("****************normal********************"); boolean same = true; for(int i=0;i<list1.size();i++){ 
if(!list1.get(i).equals(list2.get(i))){ same = false; break; } } if(same){ System.out.println("result is same"); }else{ System.out.println("result not same"); } } public static void main(String[] args){ new KDTreeTest().test(); }}
各个数据量级别的比对如下(kdtree和硬查的方法):
10万级别时,普通的暴力计算方法几分钟内已经算不出结果了。
60万级别时,我们单看kdtree:
那么如何用在spark程序里呢?用法如下:
// Spark usage: collect four feature columns to the driver, build a single KDTree there,
// broadcast it, then for every row look up the Manhattan distance to its minPts-th
// nearest other point.
// NOTE(review): df, sqlContext, minPts and logger are defined outside this snippet, and
// the outer try{ below is never closed here -- its catch/finally appears to be cut from
// the excerpt.
// NOTE(review): eps is not used anywhere in this snippet.
double eps = Double.MAX_VALUE;
final KDTree<Integer> kdtree = new KDTree<Integer>(4);
// ClassTag passed to SparkContext.broadcast, which expects one when called from Java.
scala.reflect.ClassTag<KDTree<Integer>> curClassTag = scala.reflect.ClassTag$.MODULE$.apply(KDTree.class);
try{
    JavaRDD<Row> rowRdd = df.select("fwd_ppf", "fwd_bpp", "recv_ppf", "recv_bpp").toJavaRDD();
    // Brings every row to the driver, so driver memory must hold the whole data set
    // in addition to the tree built from it.
    List<Row> list = rowRdd.collect();
    for(int i=0;i<list.size();i++){
        // Each column value is converted through its string form -- presumably so that any
        // numeric column type is accepted; confirm against the DataFrame schema.
        kdtree.insert(new double[]{
            Double.valueOf(String.valueOf(list.get(i).get(0))),
            Double.valueOf(String.valueOf(list.get(i).get(1))),
            Double.valueOf(String.valueOf(list.get(i).get(2))),
            Double.valueOf(String.valueOf(list.get(i).get(3)))
        },i);
    }
    // Create the kd-tree broadcast variable.
    final Broadcast<KDTree<Integer>> broadCast = sqlContext.sparkContext().broadcast(kdtree,curClassTag);
    JavaRDD<Double> sortRDD = rowRdd.map(new Function<Row, Double>() {
        public Double call(Row row) throws Exception {
            // Value returned for this row if the lookup below throws.
            double distance = Double.MAX_VALUE;
            try {
                // Find the distance to the minPts-th nearest node. minPts + 1 neighbours are
                // requested because this row's own point was inserted into the tree above
                // (distance 0).
                List<Double> list = broadCast.getValue().nearestDistance(new double[]{
                    Double.valueOf(String.valueOf(row.get(0))),
                    Double.valueOf(String.valueOf(row.get(1))),
                    Double.valueOf(String.valueOf(row.get(2))),
                    Double.valueOf(String.valueOf(row.get(3)))
                }, minPts + 1);
                // Ascending sort: the last element is the largest of the minPts + 1 distances.
                Collections.sort(list);
                distance = list.get(list.size() - 1);
            } catch (Exception ex) {
                // Errors are only logged; the row then yields Double.MAX_VALUE.
                logger.error("", ex);
            }
            return distance;
        }
    });
    sortRDD.persist(StorageLevel.MEMORY_AND_DISK());
spark运行结果如下:
目前的弊端为,构造树时需要collect数据,我们算一下driver需要承担的数据量:
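每个点4个double(每个8字节),6000万个点时,仅原始数值所占字节数为 $4 \times 8 \times 6\times10^{7} = 1.92\times10^{9}$ 字节,约1.8 GB;这还不包括Row、HPoint、KDNode等JVM对象的额外开销,实际占用会更高。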
所以driver端的内存最好配得多一些,我觉得4G比较保险。
- m维空间里n个点每点最近的第k个点的距离
- LintCode:M-K个最近的点
- 计算二维空间某点的最近k 个点
- K个最近的点
- K个最近的点
- SCU 4313 把一棵树切成每段K个点 (n%k)剩下的点不管
- 【K-D树 K维最近距离的t个点】HDU
- lintcode[612]:k个最近的点
- Lintcode-K个最近的点(#612)
- LintCode K个最近的点
- LintCode: K个最近的点
- k个最近的点(Leetcode答案)
- K个最近的点-LintCode
- 612. K个最近的点[LintCode]
- Lintcode 612. K个最近的点
- [算法] 已知在平面坐标系内有N个点,求离开给定坐标距离最近的10个点
- (阶段四1.4)LA 3708 Graveyard(一个圆圈上有n个点,新加入m个点,求每个点的最小移动距离)
- 找n个点相距最远的k个
- 关于PIXI引擎制作页面小游戏的几个总结
- recyclerview
- uboot整体介绍
- nodejs开发中间件connect-flash
- Android检测升级并下载安装工具类
- m维空间里n个点每点最近的第k个点的距离
- kafka server.properties配置
- jQuery 动画效果
- (3)虚拟机字节码执行引擎
- c++primer 6.15while循环习题!
- 嵌入式技术行业知识
- NAT,Bridge,HostOnly
- html5中的新增知识点
- android app 接入第三方SDK接口层实现思考