LeetCode 4. Median of Two Sorted Arrays

来源:互联网 发布:影视后期网络培训学校 编辑:程序博客网 时间:2024/06/07 14:43

            题意大概是找出两组已经有序的数组合并成的数组的中位数。

            本题最简单的做法是利用归并排序的思想,合并两个数组,然后直接找到中间那个数即可。设第一个数组为a,长度为n,第二个数组为b,长度为m,则时间复杂度为O(n+m)。

            但这样做是不够的,题目要求时间复杂度至少要达到O(logn)级别才可以。

            一开始我想到的做法是,每次把数组对半分,通过相关的判断将每个数组中不存在中位数的那一半舍弃。这样时间复杂度就能达到O(n+m)。但“相关的判断”其实不容易做到,我并没有找到可行的方法。

            经过课上老师的讲解和查阅资料,我发现这一题可以转换成第k小的数的问题。先考虑(n+m)为奇数的情况。找到中位数的问题转换成找到第(n+m+1)/2小的数即可。

            而解决第k小的数的问题有这样的算法:每次取a数组的第x个数,取b数组的第y个数,这里我们取x = y = k/2 。如果某个数组的元素个数不足,则由另一个数组补齐,要使得x + y = k 。

            接下来比较这两个数字,如果两个数相等,则他们的值就是我们要找的第k小的数。这并不难理解,a的第x个数字后面的数字和b的第y个数字后面的数字不会小于这个值,所以a的前x个数字和b的前y个数字一起构成了最小的k个数。如果a数组的第x个数字大于b数组的第y个数字,则将b数组的前y个值全部舍弃。因为这y个值都不可能是第k小的数,第k小的数只可能是a的第x个数字,或者b的第y个数字后面的数字。反之亦然。判断完成之后,修改k值(需要减去舍弃的数目,因为舍弃的数字全是小于第k小的数字的),重复判断。

            而n+m为偶数的情况下,还需算出第(n+m)/2+1小的数字,然后两个数字求出算数平均值即可。

            代码使用递归来实现。在最好的情况下每次会有k/2 个数字被舍弃,而k = (n+m)/2。则递推公式为 T(n+m) = O(1) + T( 3/4*(n+m) 。使用大师定理,算法时间复杂度为O(log(n+m)) 。代码如下:

            

# include <iostream># include <vector># include <cstring># include <cstdio>using namespace std ;class Solution {public:    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {        int n = nums1.size(), m = nums2.size() ;                double ans = this_dfs(0, n, 0, m, (n+m+1)/2, nums1, nums2) ;         if ((n+m)%2==0) {        double x = this_dfs(0, n, 0, m, (n+m)/2+1, nums1, nums2 ) ;        ans = (x + ans)/2 ;        }        return ans ;    }    double this_dfs(int x0, int y0, int x1, int y1, int k, vector<int>& nums1, vector<int>& nums2) {            if (y0 <= x0) return 1.0*nums2[x1+k-1] ;    if (y1 <= x1) return 1.0*nums1[x0+k-1] ;        if (k == 1) return min(nums1[x0], nums2[x1]) ;            int z0 = x0 + k/2 - 1 ;    int z1 = x1 + (k-k/2) - 1 ;        if (z0 >= y0) {    z1 += z0 - y0 + 1 ;    z0 = y0 - 1 ;    }    if (z1 >= y1) {    z0 += z1 - y1 + 1 ;    z1 = y1 - 1 ;    }         if (nums1[z0] == nums2[z1]) return 1.0*nums1[z0] ;    if (nums1[z0] > nums2[z1]) {    return this_dfs(x0, y0, z1+1, y1, k-(z1-x1+1), nums1, nums2) ;    }    if (nums1[z0] < nums2[z1]) {    return this_dfs(z0+1, y0, x1, y1, k-(z0-x0+1), nums1, nums2) ;    }    return 0;    } };

0 0
原创粉丝点击