[LeetCode] 4. Median of Two Sorted Arrays

来源:互联网 发布:java服务器开发与c 编辑:程序博客网 时间:2024/06/06 08:45

题目链接: https://leetcode.com/problems/median-of-two-sorted-arrays/description/

Description

There are two sorted arrays nums1 and nums2 of size m and n respectively.

Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).

Example 1:

nums1 = [1, 3]nums2 = [2]The median is 2.0

Example 2:

nums1 = [1, 2]nums2 = [3, 4]The median is (2 + 3)/2 = 2.5

解题思路

对于一个数组来说,寻找中值时会把奇数和偶数长度分开来考虑:

  • 奇数长度:返回正中间的一个数
  • 偶数长度:返回中间两个数的平均值

对于两个已排序数组来说,计算中值也是如此。

设两个数组分别为 AB,长度分别为 mn,其中 m <= n

  • m + n 为奇数:返回两个数组中第 (m + n) / 2 + 1 小的数
  • m + n 为偶数:返回两个数组中第 (m + n) / 2(m + n) / 2 + 1 小的数

因此,上述问题可以归约寻为找两个已排序数组的第 p 小的数,下面将对这个问题进行详细讨论。

考虑两个数组 AB,数组下标从 0 开始,

A = [1, 3],     m = 2B = [0, 2, 4, 5],  n = 3

找第 4 小的数字,可以先将数组 A 从中间分割为 A1 = [1], A2 = [3],然后将数组 B 从某处分割为 B1B2 两部分,使得 A1B1 的长度之和为 3,即 B1 = [0, 2, 4], B2 = [5],记两个数组分割点的下标为 k1k2,则 k1 = 0, k2 = 2

比较 A[k1]B[k2] 的大小关系,发现 A[k1] = A[0] = 1, B[k2] = B[2] = 4, A[k1] < B[k2],此时是否能判断 B[k2] 就是第 4 小的数字?当然不能,因为可能 A[k1 + 1]B[k2] 小,这样的话 B[k2] 就不可能是目标解,就像这个样例,A[k1 + 1] = A[1] = 3, A[1] < B[2]

此时,我们可以判定 A1 分小了,B2 分大了,所以需要将 k1 调大,对应的 k2 调小,获得新的分割 A1 = [1, 3], A2 = []; B1 = [0, 2], B2 = [4, 5]。再对 A[k1]B[k2] 进行比较,发现 A[k1] = A[1] = 3, B[k2] = B[1] = 2, A[k1] > B[k2],同上此时也不能判断 A[k1] 就是目标解,需要比较 B[k2 + 1]A[k1],发现 B[k2 + 1] = B[2] = 4, B[k2 + 1] > A[k1],现在可以确定 A[k1] = 3 就是第 3 小的数了。


对于上述步骤,归总下来即为先将数组 A 在下标 k1 处分割为两部分 A1A2A[k1] 归到 A1 中,同时将数组 B 在下标 k2 处分割为两部分 B1B2B[k2] 归到 B1 中,使得 A1B1 的长度之和为 p

比较 A[k1]B[k2] 的大小关系,

  • A[k1] >= B[k2]A[k1] <= B[k2 + 1],则 A[k1] 为解;
  • A[k1] >= B[k2]A[k1] > B[k2 + 1],则左移 k1(右移 k2);
  • A[k1] < B[k2]A[k1 + 1] >= B[k2],则 B[k2] 为解;
  • A[k1] < B[k2]A[k1 + 1] < B[k2],则右移 k1(左移 k2)。

这里的左移和右移 k1 就可以通过二分搜索的思想来处理。

由于是在较短的数组中进行二分搜索,所以时间复杂度为 O(log(min(m , n)),空间复杂度为 O(1)

Code

实际的代码写的比较冗余,k1k2 的取值要比较注意,因为可能会出现数组越界的情况。

class Solution {public:    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {        double res, m1, m2;        int m = nums1.size(), n = nums2.size();        int total = m + n;        if (total & 1) {            if (m == 0)                res = nums2[n >> 1];            else if (n == 0)                res = nums1[m >> 1];            else if (m < n)                res = findPthSmall(nums1, nums2, (total >> 1) + 1);            else                res = findPthSmall(nums2, nums1, (total >> 1) + 1);        } else {            if (m == 0) {                m1 = nums2[(n - 1) >> 1];                m2 = nums2[n >> 1];            } else if (n == 0) {                m1 = nums1[(m - 1) >> 1];                m2 = nums1[m >> 1];            } else if (m < n) {                m1 = findPthSmall(nums1, nums2, (total >> 1));                m2 = findPthSmall(nums1, nums2, (total >> 1) + 1);            } else {                m1 = findPthSmall(nums2, nums1, (total >> 1));                m2 = findPthSmall(nums2, nums1, (total >> 1) + 1);            }            res = (m1 + m2) / 2;        }        return res;    }    // 必须满足 nums1.size() <= nums2.size() 此函数才能正常执行,虽然可以用下面注释的语句来完善,但是提交上去发现效率会下降    int findPthSmall(vector<int>& nums1, vector<int>& nums2, int p) {        int m = nums1.size(), n = nums2.size();        // if (m > n) return findPthSmall(nums2, nums1, p);        int left = 0, right = m;        int k1 = min(p >> 1, m >> 1);        int k2 = p - 2 - k1;        // 保证初始 k2 <= n - 1         if (k2 >= n) {            k1 += k2 - (n - 1);            k2 -= k2 - (n - 1);        }        while (true) {            if (k1 < 0) {                if (nums2[k2] <= nums1[0]) {                    return nums2[k2];                } else {                    left = k1 + 1;                    if (k1 < m - 1) {                        k1 = (k1 + right) >> 1;                    } else {                        k1 += 1;                    }                }            } else if (k2 < 0) {                if (nums1[k1] <= nums2[0]) {                    return nums1[k1];                } else {                    right = k1 - 1;                    if (k1 > 0) {                        k1 = (left + k1) >> 1;                    } else {                        k1 -= 1;                    }                }            } else {                if (nums1[k1] >= nums2[k2]) {                    if ((k2 + 1 == n) || (k2 + 1 < n && nums1[k1] <= nums2[k2 + 1])) {                        return nums1[k1];                    } else {                        right = k1 - 1;                        if (k1 > 0) {                            k1 = (left + k1) >> 1;                        } else {                            k1 -= 1;                        }                    }                } else {                    if ((k1 + 1 == m) || (k1 + 1 < m && nums2[k2] <= nums1[k1 + 1])) {                        return nums2[k2];                    } else {                        left = k1 + 1;                        if (k1 < m - 1) {                            k1 = left == right ? left : (k1 + right) >> 1;                        } else {                            k1 += 1;                        }                    }                }            }            k2 = p - 2 - k1;        }    }};