挑战程序竞赛系列(31):4.5剪枝

来源:互联网 发布:旅行收纳袋 知乎 编辑:程序博客网 时间:2024/05/22 04:40

挑战程序竞赛系列(31):4.5剪枝

详细代码可以fork下Github上leetcode项目,不定期更新。

练习题如下:

  • POJ 1011: Sticks
  • POJ 2046: Gap
  • POJ 3134: Power Calculus

POJ 1011: Sticks

变态的DFS搜索,需要剪枝否则TLE,初始版本如下:

    void solve() {        while (true){            int n = ni();            if (n == 0) break;            int[] sticks = new int[n];            int min = 0;            int sum = 0;            int max = 0;            for (int i = 0; i < n; ++i){                int len = ni();                max = Math.max(max, len);                sum += len;                sticks[i] = len;            }            Arrays.sort(sticks);            Set<Integer> mem = new HashSet<Integer>();            for (int i = n; i >= 1; --i){                if (sum % i == 0){                    min = sum / i;                    if (min >= max && dfs(sticks, min, 0, new boolean[n], mem)) break;                }            }            out.println(min);        }    }    public boolean dfs(int[] sticks, int min, int sum, boolean[] visited, Set<Integer> mem){        if (mem.size() == sticks.length){            return min == sum;        }        if (sum > min) return false;        if (sum == min){            if (dfs(sticks, min, 0, visited, mem)) return true;            else return false;        }        for (int i = sticks.length - 1; i >= 0; --i){            int rem = min - sum;            if (!visited[i] && sticks[i] <= rem){                visited[i] = true;                mem.add(i);                if (dfs(sticks, min, sum + sticks[i], visited, mem)){                    return true;                }                else{                    visited[i] = false;                    mem.remove(i);                }            }        }        return false;    }

代码细节可以忽略,visited和mem可以合并,做了一些简单的剪枝处理,但始终超时。思路是遍历各种组合,且当所有元素被使用后,看是否能够找到所有长度一致的木棒。

遍历超时很大一部分的原因在于dfs中有个for循环,对于重复长度的棒子过滤的不够干净,浪费了大量的搜素资源。我们可以采用map对长度进行统计,这样重复长度的棒子大可不必搜索,省时省力。

依旧遍历,对每种可能的组合进行搜索,搜索时记录拼接完成的棒子个数,个数 * 可能长度 = 总和时,遍历结束。

具体细节可以参考博文:http://www.hankcs.com/program/algorithm/poj-1011-sticks.html,不作赘述。

代码如下:

import java.io.File;import java.io.FileInputStream;import java.io.IOException;import java.io.InputStream;import java.io.PrintWriter;import java.util.Arrays;import java.util.HashSet;import java.util.InputMismatchException;import java.util.Set;public class Main{    InputStream is;    PrintWriter out;    String INPUT = "./data/judge/201707/P1011.txt";    int[] in;    int candicate;    void solve() {        while (true){            int n = ni();            if (n == 0) break;            in = new int[51];            finish = false;            int sum = 0;            candicate = 0;            int max = 0;            for (int i = 0; i < n; ++i){                int len = ni();                max = Math.max(max, len);                sum += len;                in[len] ++;            }            candicate = max;            while (true){                if (sum % candicate == 0){                    check(sum / candicate, candicate, max);                }                if (finish) break;                ++candicate;            }            out.println(candicate);        }    }    boolean finish;    public void check(int count, int len, int plen){        --in[plen];        if (count == 0){            finish = true;        }        if (!finish){            len -= plen; //剩余长度            if (len != 0){                int nextPlen = Math.min(len, plen);                for (; nextPlen > 0; --nextPlen){                    if (in[nextPlen] != 0){                        check(count, len, nextPlen);                    }                }            }            else{                int max = 50;                while (max > 0 && in[max] == 0) --max;                check(count - 1, candicate, max); //当前剩余棒子的最大长度            }        }        ++in[plen];    }    void run() throws Exception {        is = oj ? System.in : new FileInputStream(new File(INPUT));        out = new PrintWriter(System.out);        long s = System.currentTimeMillis();        solve();        out.flush();        tr(System.currentTimeMillis() - s + "ms");    }    public static void main(String[] args) throws Exception {        new Main().run();    }    private byte[] inbuf = new byte[1024];    public int lenbuf = 0, ptrbuf = 0;    private int readByte() {        if (lenbuf == -1)            throw new InputMismatchException();        if (ptrbuf >= lenbuf) {            ptrbuf = 0;            try {                lenbuf = is.read(inbuf);            } catch (IOException e) {                throw new InputMismatchException();            }            if (lenbuf <= 0)                return -1;        }        return inbuf[ptrbuf++];    }    private boolean isSpaceChar(int c) {        return !(c >= 33 && c <= 126);    }    private int skip() {        int b;        while ((b = readByte()) != -1 && isSpaceChar(b))            ;        return b;    }    private double nd() {        return Double.parseDouble(ns());    }    private char nc() {        return (char) skip();    }    private String ns() {        int b = skip();        StringBuilder sb = new StringBuilder();        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '                                    // ')            sb.appendCodePoint(b);            b = readByte();        }        return sb.toString();    }    private char[] ns(int n) {        char[] buf = new char[n];        int b = skip(), p = 0;        while (p < n && !(isSpaceChar(b))) {            buf[p++] = (char) b;            b = readByte();        }        return n == p ? buf : Arrays.copyOf(buf, p);    }    private char[][] nm(int n, int m) {        char[][] map = new char[n][];        for (int i = 0; i < n; i++)            map[i] = ns(m);        return map;    }    private int[] na(int n) {        int[] a = new int[n];        for (int i = 0; i < n; i++)            a[i] = ni();        return a;    }    private int ni() {        int num = 0, b;        boolean minus = false;        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))            ;        if (b == '-') {            minus = true;            b = readByte();        }        while (true) {            if (b >= '0' && b <= '9') {                num = num * 10 + (b - '0');            } else {                return minus ? -num : num;            }            b = readByte();        }    }    private long nl() {        long num = 0;        int b;        boolean minus = false;        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))            ;        if (b == '-') {            minus = true;            b = readByte();        }        while (true) {            if (b >= '0' && b <= '9') {                num = num * 10 + (b - '0');            } else {                return minus ? -num : num;            }            b = readByte();        }    }    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;    private void tr(Object... o) {        if (!oj)            System.out.println(Arrays.deepToString(o));    }}

提供一组变态数据,可以自己测测来玩玩,上述算法需要3s搜索出答案。

6440 40 30 35 35 26 15 40 40 40 40 40 40 40 40 40 40 40 40 40 4040 40 43 42 42 41 10 4 40 40 40 40 40 40 40 40 40 40 40 40 4040 25 39 46 40 10 4 40 40 37 18 17 16 15 40 40 40 40 40 40 40 400ans:454[3503ms]

POJ 只需128ms走完全部测试数据,数据有点水啊。

POJ 2046: Gap

一道模拟题,用BFS广搜就好了,关键抓住填入空格的规则,只有一种情况,只允许填入左侧的下一个数字,所以在当前board下只会出现四种状态,没有什么搜索策略,按照轮次搜即可。

BFS的一个好处在于,能够以最短的距离搜到终止状态,也是此题的关键。不过还需要注意,当我们定义board的状态时,可以从整体出发,需要重写hashCode和equal方法,方便记录状态的访问情况,好题。

步骤:

  • 定义状态
  • 考虑状态的终止条件
  • 考虑状态的切换规则
  • 重写hashCode和equals方法

代码如下:

import java.io.File;import java.io.FileInputStream;import java.io.IOException;import java.io.InputStream;import java.io.PrintWriter;import java.util.Arrays;import java.util.HashSet;import java.util.InputMismatchException;import java.util.LinkedList;import java.util.Queue;import java.util.Set;public class Main{    InputStream is;    PrintWriter out;    String INPUT = "./data/judge/201708/P2046.txt";    class Game{        int[][] board = new int[4][8];        int turn;        public Game(int[][] board){            this.board = board;            this.turn = 0;            int[] y = find(11);            swap(new int[]{0, 0}, y);            y = find(21);            swap(new int[]{1, 0}, y);            y = find(31);            swap(new int[]{2, 0}, y);            y = find(41);            swap(new int[]{3, 0}, y);        }        public Game(Game newGame){            for (int i = 0; i < 4; ++i){                for (int j = 0; j < 8; ++j){                    this.board[i][j] = newGame.board[i][j];                }            }            this.turn = newGame.turn;        }        public boolean canFill(int i, int j){            if (board[i][j] != 0) return false;            if (board[i][j - 1] != 0 && (board[i][j - 1] % 10) != 7) return true;            return false;        }        public boolean done(){            for (int i = 0; i < 4; ++i){                if (board[i][7] != 0) return false;            }            for (int i = 0; i < 4; ++i){                for (int j = 0; j < 7; ++j){                    if (board[i][j] != (i + 1) * 10 + (j + 1)) return false;                }            }            return true;        }        public void fillGap(int i, int j){            int key = board[i][j - 1] + 1;            int[] pos = find(key);            swap(new int[]{i, j}, pos);            this.turn ++;        }        @Override        public boolean equals(Object obj) {            if (obj instanceof Game){                Game that = (Game)obj;                for (int i = 0; i < 4; ++i){                    for (int j = 0; j < 8; ++j){                        if (board[i][j] != that.board[i][j]) return false;                    }                }                return true;            }            else return false;        }        public int[] find(int key){            for (int i = 0; i < 4; ++i){                for (int j = 0; j < 8; ++j){                    if (board[i][j] == key) return new int[]{i, j};                }            }            return new int[]{-1, -1};        }        public void swap(int[] x, int[] y){            int tmp = board[x[0]][x[1]];            board[x[0]][x[1]] = board[y[0]][y[1]];            board[y[0]][y[1]] = tmp;        }        @Override        public String toString() {            StringBuilder sb = new StringBuilder();            for (int i = 0; i < 4; ++i){                for (int j = 0; j < 8; ++j){                    sb.append(board[i][j] + (j + 1 == 8 ? "\n" : " "));                }            }            return sb.toString();        }        @Override        public int hashCode() {            int hash = 0;            for (int i = 0; i < 4; ++i){                for (int j = 1; j < 8; ++j){                    hash += board[i][j];                    hash <<= 1;                }            }            return hash;        }    }    void solve() {        int T = ni();        while (T --> 0){            int[][] board = new int[4][8];            for (int i = 0; i < 4; ++i){                for (int j = 1; j < 8; ++j){                    board[i][j] = ni();                }            }            Game game = new Game(board);            Queue<Game> queue = new LinkedList<Game>();            Set<Game> visited = new HashSet<Game>();            if (game.done()){                out.println(0);                continue;            }            queue.offer(game);            int ans = -1;            boolean end = false;            outer: while (!queue.isEmpty() && !end){                Game gg = queue.poll();                if (visited.contains(gg)) continue;                visited.add(gg);                for (int i = 0; i < 4; ++i){                    for (int j = 1; j < 8; ++j){                        if (gg.canFill(i, j)){                            Game tmp = new Game(gg);                            tmp.fillGap(i, j);                            if (tmp.done()){                                ans = tmp.turn;                                end = true;                                continue outer;                            }                            else queue.offer(tmp);                        }                    }                }            }            out.println(ans);        }    }    void run() throws Exception {        is = oj ? System.in : new FileInputStream(new File(INPUT));        out = new PrintWriter(System.out);        long s = System.currentTimeMillis();        solve();        out.flush();        tr(System.currentTimeMillis() - s + "ms");    }    public static void main(String[] args) throws Exception {        new Main().run();    }    private byte[] inbuf = new byte[1024];    public int lenbuf = 0, ptrbuf = 0;    private int readByte() {        if (lenbuf == -1)            throw new InputMismatchException();        if (ptrbuf >= lenbuf) {            ptrbuf = 0;            try {                lenbuf = is.read(inbuf);            } catch (IOException e) {                throw new InputMismatchException();            }            if (lenbuf <= 0)                return -1;        }        return inbuf[ptrbuf++];    }    private boolean isSpaceChar(int c) {        return !(c >= 33 && c <= 126);    }    private int skip() {        int b;        while ((b = readByte()) != -1 && isSpaceChar(b))            ;        return b;    }    private double nd() {        return Double.parseDouble(ns());    }    private char nc() {        return (char) skip();    }    private String ns() {        int b = skip();        StringBuilder sb = new StringBuilder();        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '                                    // ')            sb.appendCodePoint(b);            b = readByte();        }        return sb.toString();    }    private char[] ns(int n) {        char[] buf = new char[n];        int b = skip(), p = 0;        while (p < n && !(isSpaceChar(b))) {            buf[p++] = (char) b;            b = readByte();        }        return n == p ? buf : Arrays.copyOf(buf, p);    }    private char[][] nm(int n, int m) {        char[][] map = new char[n][];        for (int i = 0; i < n; i++)            map[i] = ns(m);        return map;    }    private int[] na(int n) {        int[] a = new int[n];        for (int i = 0; i < n; i++)            a[i] = ni();        return a;    }    private int ni() {        int num = 0, b;        boolean minus = false;        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))            ;        if (b == '-') {            minus = true;            b = readByte();        }        while (true) {            if (b >= '0' && b <= '9') {                num = num * 10 + (b - '0');            } else {                return minus ? -num : num;            }            b = readByte();        }    }    private long nl() {        long num = 0;        int b;        boolean minus = false;        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))            ;        if (b == '-') {            minus = true;            b = readByte();        }        while (true) {            if (b >= '0' && b <= '9') {                num = num * 10 + (b - '0');            } else {                return minus ? -num : num;            }            b = readByte();        }    }    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;    private void tr(Object... o) {        if (!oj)            System.out.println(Arrays.deepToString(o));    }}

POJ 3134: Power Calculus

第一次遇到迭代加深算法,有点难以理解。刚开始使用BFS,但发现细节处理上有问题。此题有些关键地方,比如同一轮生成的解,不能结合使用,只能使用前几层的解和当前层解的组合,或许可以如此想象,在构造轮次时,只有一条链,这种构造路径难道不是DFS?没错,就是它,但是何时终止呢?

剪枝算法告诉我们,每个给定的n都有一个上界,就拿快速幂的例子来说,举例13,至多也就这些操作:

13-1=1212/2=66/2=33-1=22/2=1

于是我们迭代找解的时候可以根据此上界进行剪枝,对数据预处理下,有上界函数:

    public int upper(int n){        int cnt = 0;        while (n > 0){            if ((n & 1) != 0){                cnt ++;            }            n >>= 1;            cnt ++;        }        return cnt - 2;    }

接着就是构一条生成路径了,采用DFS,巧妙之处在于这种DFS刚好能够模拟这种构造状态,神奇,比如:

1 -> 0表示x需要耗费0次,初始状态那么自然地:2 = 1 + 1表示: x^2 = x * x所以有1生成了2那么3怎么来? 1 + 2 = 3但是当前层还会有 2 + 2 = 4但神奇的是,1 + 2 = 32 + 2 = 4不会同一时刻遍历,而是分为两次dfs调用。这就解决了之前BFS的一个bug,呵呵哒。

代码如下:

import java.io.File;import java.io.FileInputStream;import java.io.IOException;import java.io.InputStream;import java.io.PrintWriter;import java.util.Arrays;import java.util.InputMismatchException;public class Main{    InputStream is;    PrintWriter out;    String INPUT = "./data/judge/201708/P3134.txt";    int MAX_N = 1024;    int MAX_D = 20;    int[] exp = new int[MAX_D];    int[] ans = new int[MAX_N];    void solve(){        Arrays.fill(exp, 1);        for (int i = 2; i < MAX_N; ++i){            ans[i] = upper(i);        }        dfs(0);        while (true){            int n = ni();            if (n == 0) break;            out.println(ans[n]);        }    }    public void dfs(int d) {        if (d > MAX_D) {            return;         }        for (int i = 0; i <= d; i++) {            exp[d + 1] = exp[i] + exp[d]; // 乘法            if (exp[d + 1] < MAX_N && ans[exp[d + 1]] >= d + 1) { //这层的解要是被更新的话,继续更新下下层的解                ans[exp[d + 1]]  = d + 1; //更新解                dfs(d + 1);            }            exp[d + 1] = exp[d] - exp[i]; // 除法            if (exp[d + 1] > 0 && ans[exp[d + 1]] >= d + 1) {                ans[exp[d + 1]] = d + 1;                dfs(d + 1);            }        }    }    public int upper(int n){        int cnt = 0;        while (n > 0){            if ((n & 1) != 0){                cnt ++;            }            n >>= 1;            cnt ++;        }        return cnt - 2;    }    void run() throws Exception {        is = oj ? System.in : new FileInputStream(new File(INPUT));        out = new PrintWriter(System.out);        long s = System.currentTimeMillis();        solve();        out.flush();        tr(System.currentTimeMillis() - s + "ms");    }    public static void main(String[] args) throws Exception {        new Main().run();    }    private byte[] inbuf = new byte[1024];    public int lenbuf = 0, ptrbuf = 0;    private int readByte() {        if (lenbuf == -1)            throw new InputMismatchException();        if (ptrbuf >= lenbuf) {            ptrbuf = 0;            try {                lenbuf = is.read(inbuf);            } catch (IOException e) {                throw new InputMismatchException();            }            if (lenbuf <= 0)                return -1;        }        return inbuf[ptrbuf++];    }    private boolean isSpaceChar(int c) {        return !(c >= 33 && c <= 126);    }    private int skip() {        int b;        while ((b = readByte()) != -1 && isSpaceChar(b))            ;        return b;    }    private double nd() {        return Double.parseDouble(ns());    }    private char nc() {        return (char) skip();    }    private String ns() {        int b = skip();        StringBuilder sb = new StringBuilder();        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '                                    // ')            sb.appendCodePoint(b);            b = readByte();        }        return sb.toString();    }    private char[] ns(int n) {        char[] buf = new char[n];        int b = skip(), p = 0;        while (p < n && !(isSpaceChar(b))) {            buf[p++] = (char) b;            b = readByte();        }        return n == p ? buf : Arrays.copyOf(buf, p);    }    private char[][] nm(int n, int m) {        char[][] map = new char[n][];        for (int i = 0; i < n; i++)            map[i] = ns(m);        return map;    }    private int[] na(int n) {        int[] a = new int[n];        for (int i = 0; i < n; i++)            a[i] = ni();        return a;    }    private int ni() {        int num = 0, b;        boolean minus = false;        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))            ;        if (b == '-') {            minus = true;            b = readByte();        }        while (true) {            if (b >= '0' && b <= '9') {                num = num * 10 + (b - '0');            } else {                return minus ? -num : num;            }            b = readByte();        }    }    private long nl() {        long num = 0;        int b;        boolean minus = false;        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))            ;        if (b == '-') {            minus = true;            b = readByte();        }        while (true) {            if (b >= '0' && b <= '9') {                num = num * 10 + (b - '0');            } else {                return minus ? -num : num;            }            b = readByte();        }    }    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;    private void tr(Object... o) {        if (!oj)            System.out.println(Arrays.deepToString(o));    }}

这就厉害了,类似于BFS,但能够精确的控制每层解的非法组合。

原创粉丝点击