[DLX] [NOIP2009] 靶形数独

来源:互联网 发布:淘宝自动确认多久到账 编辑:程序博客网 时间:2024/05/21 19:11

终于写完了这道题。

我所用的方法是 DLX,即 Dancing Links X algorithm。这是一个如何的算法呢?其所用即:O(1) 恢复链表,完成搜索剪枝。

接下来看“精确覆盖”问题:

对于一个 01 矩阵,选择若干行,使得矩阵的每一列都有且仅有一个 1。

怎么做?很显然这是 NP 问题,方法只有搜索。而在搜索算法中,有一种专为此而设的算法,即 Dancing Links X。

将 01 矩阵用一个双向循环十字链表阵(即 Dancing links)表示,每一列都有一个列头,第一列的前端链有一个表头,所有元素都有指针指向其所在列的列头。

显然,如果选取一行包含于解中,则该行所在列上的其他元素都不能选。正因为此,既然不能选,那么保留其所在行也没有意义(因为选不全),所以可以一并删去。

所以选取一行,则删去许多。这样便达到了剪枝的效果。

关于链表元素的删除,我这里还有一篇关于搜索链表优化的文章,大概是 Dancing Links 的基础。

接下来说 sudoku 一题。如何将其转化为精确覆盖呢?

显然,每个位置,每个元素对每行,每个元素对每列,每个元素对每个九宫格只能选一次。而选完 81 个数之后,每个数都对应了这些条件。

那么,就可以这样来构造精确覆盖模型:

以每个位置,每个元素对每行,每个元素对每列,每个元素对每个九宫格为列(最多共 324 列);

以所选的每个数为行;

双向循环十字链表中的每一个元素对应上所述。

显然,这个问题就解决了。而且根本不需要加什么搜索剪枝,因为 DLX 可以说是自动剪枝加减少分枝数。

不过比较令人头疼的是, DLX 构造模型时所需要的预处理是相当复杂的,但是这实在没有办法。。。。。

Code :

#include <cstdio>#include <cstdlib>#include <cstring>#include <climits>#include <iostream>#include <algorithm>typedef long long int64;typedef unsigned int uint;typedef unsigned long long uint64;#define swap(a, b, t) ({t _ = (a); (a) = (b); (b) = _;})#define MAX(a, b, t) ({t _ = (a), __ = (b); _ > __ ? _ : __;})#define MIN(a, b, t) ({t _ = (a), __ = (b); _ < __ ? _ : __;})#define maintype int#define max(a, b) (MAX(a, b, maintype))#define min(a, b) (MIN(a, b, maintype))#define maxs 12#define maxn 8005#define getpos(i, j) (((i) - 1) / 3 * 3 + ((j) - 1) / 3 + 1)#define abs(a) ({int _____ = a; _____ < 0 ? - _____ : _____;})#define getcost(i, j) ({int ___ = abs(i - 5), ____ = abs(j - 5); 10 - max(___, ____);})int tot, ans1, ans2, total, a[maxs][maxs];struct node{node * l, * r, * u, * d, * c; int s;} vess[maxn];typedef bool visit[maxs][maxs];typedef node * sudoku[maxs][maxs];visit vv, vl, vr, vp;node * head, * tail;sudoku mv, ml, mr, mp;void remove(node * p){    p->l->r = p->r, p->r->l = p->l;    for (node * i = p->d; i != p; i = i->d)        for (node * j = i->r; j != i; j = j->r)            j->u->d = j->d, j->d->u = j->u, -- j->c->s;}void resume(node * p){    for (node * i = p->u; i != p; i = i->u)        for (node * j = i->l; j != i; j = j->l)            j->u->d = j->d->u = j, ++ j->c->s;    p->l->r = p->r->l = p;}void dfs(int now, int ans){    if (head->r == head)    {        if (now == total) return (void)(ans2 = max(ans2, ans));        return;    }    node * p = head->r;    int mini = p->s;    for (node * i = p->r; i != head; i = i->r)        if (i->s < mini)            mini = i->s, p = i;    remove(p);    for (node * i = p->d; i != p; i = i->d)    {        for (node * j = i->r; j != i; j = j->r)            remove(j->c);        dfs(now + 1, ans + i->s);        for (node * j = i->l; j != i; j = j->l)            resume(j->c);    }    resume(p);}node * newnode(node * tar, int s){    node * p = vess + ++ tot;    p->u = tar->u, tar->u->d = p;    p->d = tar, tar->u = p;    p->l = p->r = p, p->c = tar;    return ++ tar->s, p->s = s, p;}node * newhead(){    tail->r = vess + ++ tot;    tail->r->l = tail, tail = tail->r;    tail->u = tail->d = tail->c = tail;    return tail;}void prepare(){    head = tail = vess;    head->u = head->d = head->c = head;    for (int i = 1; i <= 9; ++ i)        for (int j = 1; j <= 9; ++ j)            if (not vv[i][j])                mv[i][j] = newhead();    for (int i = 1; i <= 9; ++ i)        for (int j = 1; j <= 9; ++ j)            if (not vl[i][j])                ml[i][j] = newhead();    for (int i = 1; i <= 9; ++ i)        for (int j = 1; j <= 9; ++ j)            if (not vr[i][j])                mr[i][j] = newhead();    for (int i = 1; i <= 9; ++ i)        for (int j = 1; j <= 9; ++ j)            if (not vp[i][j])                mp[i][j] = newhead();    tail->r = head, head->l = tail;    for (int k = 1; k <= 9; ++ k)        for (int i = 1; i <= 9; ++ i)            for (int j = 1; j <= 9; ++ j)            {                if (vv[i][j] or vl[k][i] or vr[k][j]) continue;                int P = getpos(i, j), s = k * getcost(i, j);                if (vp[k][P]) continue;                node * o = newnode(mv[i][j], s);                node * p = newnode(ml[k][i], s);                node * q = newnode(mr[k][j], s);                node * r = newnode(mp[k][P], s);                o->r = p, p->r = q, q->r = r, r->r = o;                r->l = q, q->l = p, p->l = o, o->l = r;            }}int main(){    freopen("sudoku.in", "r", stdin);    freopen("sudoku.out", "w", stdout);        for (int i = 1; i <= 9; ++ i)        for (int j = 1; j <= 9; ++ j)            if (scanf("%d", & a[i][j]), a[i][j])            {                int k = a[i][j], p = getpos(i, j), c = getcost(i, j);                ans1 += k * c, total += vv[i][j] = vl[k][i] = vr[k][j] = vp[k][p] = 1;            }    prepare(), total = 81 - total, dfs(0, 0);    ans2 ? printf("%d\n", ans1 + ans2) : puts("-1");        return 0;}