xjoj325最大矩形

来源:互联网 发布:java个人简历项目经验 编辑:程序博客网 时间:2024/05/22 00:45

Description

一个N*M的矩阵,每个格子里面有个整数( 绝对值不大与10 ) ,每个子矩阵( 至少包含一个元素 )的价值就是它所包含的格子内的数的和。 现在求两个不相交的子矩阵(不包含相同的格子),使得他们的价值的乘积最大。
例如: N=3 , M=4,矩阵如图所示:
234513244321最大子矩阵值乘积为288。(左边两列的和为16,右边两列的和为18,结果为16*18=288)。

Input

第一行有两个数字n, m ( n, m < 100)。以后的n行,每行有m个整数。

Output

输出文件只有一个数,即两不相交子矩阵价值乘积的最大值。

Sample Input

1 7-9 -9 8 8 1 7 -4

Sample Output

128


分析:首先,可以想到最简单的方法就是穷举所有子矩阵的组合,第一个子矩阵,有O(n^2*m^2)种选择,第二个子矩阵有O(n^2*m^2)种选择,求每个子矩阵中所有元素的和需要O(n*m),这样复杂度是O(n^5*m^5),根据题目中的条件,n,m<100,当n,m在100这个数量级的时候,总的时间将达到10^20数量级,显然是太慢了。

作为改进,可以尝试存储已经计算出来的结果,这样算法分成两步:第一步,计算每个子矩阵各自的元素和,第二步,利用之前计算的结果,穷举每一种可能。第二步的时间复杂度是O(n^4*m^4),而第一步用动态规划算法来实现的话,其时间复杂度为O(n^2*m^2).因此最终的时间复杂度为O(n^4*m^4)+O(n^2*m^2)=O(n^4*m^4).这种情况下需要多大的计算量呢?如果n和m都在100级别,那么有10^16数量级,虽然缩小了10^4,但是这个规模依然还是太大了。

进一步改进,由于题目中只求“最大”,那么其实只要找到最大的子矩阵和最小的子矩阵就行了,将来的结果必定是这些最大值或最小值之间的乘积。同时,还想到了“分治"思想:任意两个不相交的子矩阵,必定能用一条垂直线或一条水平线分开来,于是我们只需要求每一条垂直线左边的最大和最小子矩阵值,以及垂直线右边的最大和最小子矩阵值,和每一条水平线两边的相应的值。对于每一条线,它左边的最值和右边的最值的乘积的最大值,就是以它作为分界线时能得到的最大乘积。而分界线有多少呢?只有O(m+n)条,每条分界线两边的最大/最小值的计算是可以重用的,这样每条分界线我们要计算O(n^3+m^3)次,从而得到总的计算量是O((m+n)*(n^3+m^3))次,当m,n在100级别时,其计算时间是在4 0000 0000,即只有亿级别,这个对于计算机来说就不难了。

事实上,我们还可以再做优化,使得时间复杂度达到O((m+n)*(n^2+m^2))级别,只是由于上述方法得到的代码已经Accept了,所以也没有继续改进。

代码如下:


import java.util.Scanner;/** * Problem : xjoj 325 最大矩形 * Accepted * 思路:分治和动态规划结合。 * @author wwf * */public class Main {static final long INT_MIN = -25000000001L;static final long INT_MAX = 25000000001L;/** * @param args */public static void main(String[] args) {// TODO Auto-generated method stubScanner in = new Scanner(System.in);int n = in.nextInt();int m = in.nextInt();int[][] matrix = new int[n][m];for (int i = 0; i < n; i++) {for (int j = 0; j < m; j++) {matrix[i][j] = in.nextInt();}}long[] right_max = new long[n];long[] right_min = new long[n];long[] left_max = new long[n];long[] left_min = new long[n];long[] up_max = new long[m];long[] up_min = new long[m];long[] down_max = new long[m];long[] down_min = new long[m];long[][] max_pre = new long[m][m];long[][] min_pre = new long[m][m];for (int i = n - 1; i >= 0; i--) {if (i == n - 1) {right_max[i] = INT_MIN;right_min[i] = INT_MAX;} else {right_max[i] = right_max[i + 1];right_min[i] = right_min[i + 1];}for (int k = 0; k < m; k++) {for (int l = k; l < m; l++) {long sum = 0;for (int index = k; index <= l; index++) {sum += matrix[i][index];}long local_max, local_min;if (i < n - 1 && max_pre[k][l] > 0) {local_max = sum + max_pre[k][l];} else {local_max = sum;}if (i < n - 1 && min_pre[k][l] < 0) {local_min = sum + min_pre[k][l];} else {local_min = sum;}right_max[i] = max(local_max, right_max[i]);right_min[i] = min(local_min, right_min[i]);max_pre[k][l] = local_max;min_pre[k][l] = local_min;}}}for (int i = 0; i < n; i++) {if (i == 0) {left_max[i] = INT_MIN;left_min[i] = INT_MAX;} else {left_max[i] = left_max[i - 1];left_min[i] = left_min[i - 1];}for (int k = 0; k < m; k++) {for (int l = k; l < m; l++) {long sum = 0;for (int index = k; index <= l; index++) {sum += matrix[i][index];}long local_max, local_min;if (i > 0 && max_pre[k][l] > 0) {local_max = sum + max_pre[k][l];} else {local_max = sum;}if (i > 0 && min_pre[k][l] < 0) {local_min = sum + min_pre[k][l];} else {local_min = sum;}//System.out.println("local_max="+local_max);left_max[i] = max(local_max, left_max[i]);left_min[i] = min(local_min, left_min[i]);max_pre[k][l] = local_max;min_pre[k][l] = local_min;}}}max_pre = new long[n][n];min_pre = new long[n][n];for (int k = 0; k < m; k++) {if (k == 0) {up_max[k] = INT_MIN;up_min[k] = INT_MAX;} else {up_max[k] = up_max[k - 1];up_min[k] = up_min[k - 1];}for (int i = 0; i < n; i++) {for (int j = i; j < n; j++) {long sum = 0;for (int index = i; index <= j; index++) {sum += matrix[index][k];}long local_max, local_min;if (k > 0 && max_pre[i][j] > 0) {local_max = sum + max_pre[i][j];} else {local_max = sum;}if (k > 0 && min_pre[i][j] < 0) {local_min = sum + min_pre[i][j];} else {local_min = sum;}up_max[k] = max(local_max, up_max[k]);up_min[k] = min(local_min, up_min[k]);max_pre[i][j] = local_max;min_pre[i][j] = local_min;}}}for (int k = m - 1; k >= 0; k--) {if (k == m - 1) {down_max[k] = INT_MIN;down_min[k] = INT_MAX;} else {down_max[k] = down_max[k + 1];down_min[k] = down_min[k + 1];}for (int i = 0; i < n; i++) {for (int j = i; j < n; j++) {long sum = 0;for (int index = i; index <= j; index++) {sum += matrix[index][k];}long local_max, local_min;if (k < m - 1 && max_pre[i][j] > 0) {local_max = sum + max_pre[i][j];} else {local_max = sum;}if (k < m - 1 && min_pre[i][j] < 0) {local_min = sum + min_pre[i][j];} else {local_min = sum;}down_max[k] = max(local_max, down_max[k]);down_min[k] = min(local_min, down_min[k]);max_pre[i][j] = local_max;min_pre[i][j] = local_min;}}}long result = INT_MIN;for (int i = 0; i < n - 1; i++) {result = max(result, left_max[i] * right_max[i + 1]);result = max(result, left_max[i] * right_min[i + 1]);result = max(result, left_min[i] * right_max[i + 1]);result = max(result, left_min[i] * right_min[i + 1]);}for (int k = 0; k < m - 1; k++) {result = max(result, up_max[k] * down_max[k + 1]);result = max(result, up_max[k] * down_min[k + 1]);result = max(result, up_min[k] * down_max[k + 1]);result = max(result, up_min[k] * down_min[k + 1]);}System.out.println(result);}static long max(long a, long b) {return a > b ? a : b;}static long min(long a, long b) {return a < b ? a : b;}static void print_2_dim_long_array(long[][]array){for(int i=0;i<array.length;i++){for(int j=0;j<array[i].length;j++){System.out.print(array[i][j]+"\t\t\t");}System.out.println();}}static void print_long_array(long[]array){for(long i:array){System.out.print(i+"\t");}System.out.println();}}