八数码 poj1077 Eight(A*、IDA*)

来源:互联网 发布:mysql insert触发器 编辑:程序博客网 时间:2024/05/19 03:20

八数码主要参考:刘亚宁学长的八数码八境界 等

A*主要参考:百度百科,初识A*算法,深入A*算法 等

IDA*主要参考:http://blog.csdn.net/urecvbnkuhbh_54245df/article/details/5856756,http://blog.csdn.net/nomad2/article/details/6562140 等


人工智能课的作业,借此研究下八数码,及相关的搜索算法

八数码(九宫格)是一个空间搜索的问题,而且空间只有9!(362 880)的大小,只是10e5的数量级。

八数码的解决的方案由很多:

(1)非启发性的搜索:暴力广搜,双向广搜,迭代加深的深搜;

(2)启发性搜索:A*,IDA*;(启发性算法有局部择优搜索和最好优先搜索,A*和IDA*都属于最好优先搜索,局部择优搜索有爬山算法,不过还没有研究

(3)基于概率的随机搜索;遗传算法(比较有名的还有蚁群算法,模拟退火算法等智能搜索,不过没有研究


本文主要说明A*、IDA*的搜索解法。

下面是搜索可能用到的问题的分析处理

(1)状态表示每一个状态用一个二维数组记录当前的八数码位置信息,并用记录其它属性


(2)判重方法,hash方法。

1)比较直接的是hash成字符串在,set中判重,不过字符串的操作和set的处理,效率可能要低些

2)将九宫格的hash成十进制(或九进制的表示),不过表示的实际的数据范围要大于真实空间(9! ),因为有一些数据是不可能出现在九宫格中的,

比如数据:(111111111) 是不可能出现在九宫格中

3)利用康托展开,康拓展开计算了,一个排列在所有的全排列的名次。比如4个数的一个排列3 2 1 4,其在4的全排列的名次是 2 * (3!)+ 1 * (2!) + 0 * (1!) + 0 *(0!);其中第i个数是ai:则康拓展开的ai处的值 = (排列中ai后面比ai小的数的个数) * (ai后面数的个数的阶乘);这样就完全没有空间的浪费,所以采用此法。


(3)关于A*算法。A*算法是一个启发性算法,最好优先搜索算法。关键在于其估价函数,

状态n 的估价值为:f‘(n)= g'(n)+ h'(n);

g’(n)表示从初始状态开始到状态n的最短路径值(最小消耗值),h'(n)表示重状态n到终止状态的最短路经值的启发值。由于f‘(n)实现不可知,所以用f(n)来代替 

f(n)= g(n)+ h(n);

g(n)表示初始状态到状态n的实际路径值(消耗值)(代表搜索历史信息),是已知的,h(n)表示状态n到终止状态的消耗估计值,包含着搜索的启发信息。

满足条件:1)g(n)> =g'(n)通常是满足的

    2)h(n) <= h'(n),满足此条件时,可以保证最优解

h(n)的取值可以是,状态n和终止状态中不同位置数字的个数,也可以是没有就位数字到相应位置的曼哈顿距离之和(选择此法)。

关于估价函数的理解,还有待加深。。。

关于A*的理解。A*和之前的广度优先搜索,和优先队列的搜索(最短路的dij算法)很相似。它们搜索的区别就在于估价函数不同而已。

A*既包含历史信息,也包含启发信息,广搜和优先对列的搜索则只有历史信息,没有启发信息。

广搜的估价函数值就是其搜索的深度(或实际的路径值,只不过边值是1)(历史信息,也可以说是最短路径值,其实仔细考虑也是状态图所决定的,可以说是状态的属性),是A*的特例;

优先对列的搜索的估价函数则是从初始状态到n状态的实际路径值,只不过边值不一定是1而已(历史信息,也可以说最短路径值),也可以说是A*的特例。

当然还有待加深理解。。。


(4)A*解法的过程见:百度百科

数据结构Open表,Close表

Open表即优先队列或最小堆,用priority_queue时不能实现修改值,则直接插入即可;自己实现的最小堆,可以实现修改值,效率更高。(还没有实现最小堆的方法

Close表即hash_tab[],利用康托展开后,hash到一个数组中,实现判重。


(5)IDA*,迭代加深的A*,或者说是迭代加深的深度优先搜索。此法不用判重,不用优先队列。只是普通的迭代加深的深度优先搜索+利用股价函数的剪枝,枚举估价值MaxH,利用估价函数剪枝,当节点估价函数的值大于MaxH。关键在于剪枝。

不需要判重的原因是,估价函数中包含了历史信息(深度信息),当超过规定的枚举深度MaxH,即剪去。

还有的一个剪枝是,不走相邻的重复路(不如,走left后走right)

值得注意的是,相对于A*,其不仅方法简单,而且,空间使用要小很多

(6)优先对列的使用是该注意下了。


具体代码:

A*:

#include <iostream>#include <cstdio>#include <cstring>#include <string>#include <cmath>#include <algorithm>#include <cstdlib>#include <vector>#include <queue>using namespace std;///Data!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!struct node {    int tab[3][3];///table    int r, c;///0的位置    int hash_val;///hash值    node* pre;    int op;///path    int f, g;///估价函数    bool operator<(const node &a) const {        return f > a.f;    }}st, ed;node t[370000];int tot;int hash_tab[370000];///362880priority_queue<node>open;int fn[10];int ed_map[10][2];int dir_i[4] = {0, 0, 1, -1};///r, l, u, dint dir_j[4] = {1, -1, 0, 0};char print_op[4] = {'r', 'l', 'd', 'u'};///!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!void out(node a);///Calc!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!int get_hash(node a) {    int ret;    ret = 0;    int num = 8;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++) {            int x = 0;            for (int jj = j + 1; jj < 3; jj++)                if (a.tab[i][jj] < a.tab[i][j]) x++;///            for (int ii = i + 1; ii < 3; ii++)                for (int jj = 0; jj < 3; jj++)                    if (a.tab[ii][jj] < a.tab[i][j]) x++;            ret += fn[num] * x;            num--;        }    return ret;}int get_f(node a, int g) {///f = g + h;    int h = 0;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++)            if (a.tab[i][j])///不考虑零                h += abs(i - ed_map[a.tab[i][j]][0]) + abs(j - ed_map[a.tab[i][j]][1]);    return g + h;}void change(node &tmp, node a, int nextr, int nextc, int i, int idx) {    for (int i = 0; i < 3; i++) for (int j = 0; j < 3; j++)        tmp.tab[i][j] = a.tab[i][j];    swap(tmp.tab[nextr][nextc], tmp.tab[a.r][a.c]);    tmp.hash_val = get_hash(tmp);    tmp.r = nextr; tmp.c = nextc;    tmp.pre = &t[idx];///!!!    tmp.op = i;    tmp.g = a.g + 1;    tmp.f = get_f(tmp, tmp.g);}int check(int i, int j) {    if (i > 2 || i < 0 || j > 2 || j < 0) return 0;    return 1;}///!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!///Input()Init()void input(node &st) {    char ch;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++) {            scanf(" %c", &ch);            if (ch <= '9' && ch > '0') st.tab[i][j] = ch - '0';            else { st.r = i; st.c = j; st.tab[i][j] = 0; }        }}void In(char s[], node &st){    char ch;    int next = 0;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++) {            while (s[next] == ' ') next++;            ch = s[next++];            if (ch <= '9' && ch > '0') st.tab[i][j] = ch - '0';            else { st.r = i; st.c = j; st.tab[i][j] = 0; }        }}void pre(){    fn[0] = 1;    for (int i = 1; i < 9; i++) fn[i] = i * fn[i - 1];    for (int i = 0; i < 3; i++) for (int j = 0; j < 3; j++)        ed.tab[i][j] = (i * 3) + j + 1;    ed.tab[2][2] = 0;    ed.hash_val = get_hash(ed);    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++)        if (ed.tab[i][j]) {                ed_map[ed.tab[i][j]][0] = i; ed_map[ed.tab[i][j]][1] = j;        }}void init() {    memset(hash_tab, 0, sizeof(hash_tab));    while (!open.empty()) open.pop();    tot = 0;    st.f = get_f(st, 0);    st.g = 0;    st.hash_val = get_hash(st);    open.push(st);    hash_tab[st.hash_val] = 1;}///!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!void path(node* a) {    if (a->hash_val == st.hash_val) return ;    path(a->pre);    printf("%c", print_op[a->op]);}void out_table(node a) {    for (int i = 0; i < 3; i++) {        for (int j = 0; j < 3; j++)        cout << a.tab[i][j] << ' ';        cout <<endl;    }    cout <<endl;}///astarvoid astar() {    int s = 0;    int nextr, nextc;    node ans = st;    int fla = 0;    if (st.hash_val != ed.hash_val)    while (!open.empty()) {        node a = open.top(); open.pop();        t[++tot] = a;///        for (int i = 0; i < 4; i++) {            nextr = a.r + dir_i[i];            nextc = a.c + dir_j[i];            if (check(nextr, nextc)) {                node tmp;                change(tmp, a, nextr, nextc, i, tot);                if (hash_tab[tmp.hash_val]) continue;                if (tmp.hash_val == ed.hash_val) { fla = 1; ans = tmp; break; };                open.push(tmp);            }        }        if (fla) break;        hash_tab[a.hash_val] = 1;    }    path(&ans);    puts("");}///main!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!int get_preval(node a) {    int ret = 0;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++) {            if (!a.tab[i][j]) continue;            int x = 0;            for (int jj = j + 1; jj < 3; jj++)                    if (a.tab[i][jj] && a.tab[i][jj] < a.tab[i][j]) x++;///            for (int ii = i + 1; ii < 3; ii++)                for (int jj = 0; jj < 3; jj++)                    if (a.tab[ii][jj] && a.tab[ii][jj] < a.tab[i][j]) x++;            ret += x;        }    return ret % 2;}bool pre_solve() {    return (get_preval(st) ^ get_preval(ed));}///!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!int main() {    char ss[12];    while (gets(ss))    {        pre();        In(ss, st);//        input(st);        init();    //    input(ed);    //    out_table(st);    //    out_table(ed);        if (pre_solve()) puts("unsolvable");        else astar();    }    return 1;}/*1 2 3 x 4 6 7 5 8x 2 3 1 4 6 7 5 81 2 3 x 4 6 7 5 81 2 3 4 x 6 7 5 81 2 3 4 5 x 7 8 61 2 3 4 x 8 7 6 5*/

另种优先队列的使用和路径记录法:

#include <iostream>#include <cstdio>#include <cstring>#include <string>#include <cmath>#include <algorithm>#include <cstdlib>#include <vector>#include <queue>using namespace std;///Data!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!struct node {    int tab[3][3];///table    int r, c;///0的位置    int hash_val;///hash值    int pre;    int op;///path    int f, g;///估价函数    bool operator<(const node &a) const {        return f > a.f;    }};node t[370000];int tot;struct cmp{    bool operator()(int x, int y)    {        return t[x].f > t[y].f;    }};int st, ed;int hash_tab[370000];///362880priority_queue<int, vector<int>, cmp>open;int fn[10];int ed_map[10][2];int dir_i[4] = {0, 0, 1, -1};///r, l, u, dint dir_j[4] = {1, -1, 0, 0};char print_op[4] = {'r', 'l', 'd', 'u'};///!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!void out(node a);///Calc!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!int get_hash(node a) {    int ret;    ret = 0;    int num = 8;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++) {            int x = 0;            for (int jj = j + 1; jj < 3; jj++)                if (a.tab[i][jj] < a.tab[i][j]) x++;///            for (int ii = i + 1; ii < 3; ii++)                for (int jj = 0; jj < 3; jj++)                    if (a.tab[ii][jj] < a.tab[i][j]) x++;            ret += fn[num] * x;            num--;        }    return ret;}int get_f(node a, int g) {///f = g + h;    int h = 0;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++)            if (a.tab[i][j])///不考虑零                h += abs(i - ed_map[a.tab[i][j]][0]) + abs(j - ed_map[a.tab[i][j]][1]);    return g + h;}void change(int tmp, int a, int nextr, int nextc, int i) {    for (int i = 0; i < 3; i++) for (int j = 0; j < 3; j++)        t[tmp].tab[i][j] = t[a].tab[i][j];    swap(t[tmp].tab[nextr][nextc], t[tmp].tab[t[a].r][t[a].c]);    t[tmp].hash_val = get_hash(t[tmp]);    t[tmp].r = nextr; t[tmp].c = nextc;    t[tmp].pre = a;///!!!    t[tmp].op = i;    t[tmp].g = t[a].g + 1;    t[tmp].f = get_f(t[tmp], t[tmp].g);}int check(int i, int j) {    if (i > 2 || i < 0 || j > 2 || j < 0) return 0;    return 1;}///!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!///Input()Init()void input(node &a) {    char ch;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++) {            scanf(" %c", &ch);            if (ch <= '9' && ch > '0') a.tab[i][j] = ch - '0';            else { a.r = i; a.c = j; a.tab[i][j] = 0; }        }}void In(char s[], node &st){    char ch;    int next = 0;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++) {            while (s[next] == ' ') next++;            ch = s[next++];            if (ch <= '9' && ch > '0') st.tab[i][j] = ch - '0';            else { st.r = i; st.c = j; st.tab[i][j] = 0; }        }}void pre(){    tot = 0;    st = tot++;    ed = tot++;    fn[0] = 1;    for (int i = 1; i < 9; i++) fn[i] = i * fn[i - 1];    for (int i = 0; i < 3; i++) for (int j = 0; j < 3; j++)        t[ed].tab[i][j] = (i * 3) + j + 1;    t[ed].tab[2][2] = 0;    t[ed].hash_val = get_hash(t[ed]);    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++)        if (t[ed].tab[i][j]) {                ed_map[t[ed].tab[i][j]][0] = i; ed_map[t[ed].tab[i][j]][1] = j;        }}void init() {    memset(hash_tab, 0, sizeof(hash_tab));    while (!open.empty()) open.pop();    t[st].f = get_f(t[st], 0);    t[st].g = 0;    t[st].hash_val = get_hash(t[st]);    open.push(st);    hash_tab[t[st].hash_val] = 1;}///!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!void path(int x) {    if (t[x].hash_val == t[st].hash_val) return ;    path(t[x].pre);    printf("%c", print_op[t[x].op]);}void out_table(node a) {    for (int i = 0; i < 3; i++) {        for (int j = 0; j < 3; j++)        cout << a.tab[i][j] << ' ';        cout <<endl;    }    cout <<endl;}///astarvoid astar() {    int s = 0;    int nextr, nextc;    int ans = st;    int fla = 0;    if (t[st].hash_val != t[ed].hash_val)    while (!open.empty()) {        int a = open.top(); open.pop();        for (int i = 0; i < 4; i++) {            nextr = t[a].r + dir_i[i];            nextc = t[a].c + dir_j[i];            if (check(nextr, nextc)) {                int tmp = tot++;                change(tmp, a, nextr, nextc, i);                if (hash_tab[t[tmp].hash_val]) continue;                if (t[tmp].hash_val == t[ed].hash_val) { fla = 1; ans = tmp; break; };                open.push(tmp);            }        }        if (fla) break;        hash_tab[t[a].hash_val] = 1;    }    path(ans);    puts("");}///main!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!int get_preval(node a) {    int ret = 0;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++) {            if (!a.tab[i][j]) continue;            int x = 0;            for (int jj = j + 1; jj < 3; jj++)                    if (a.tab[i][jj] && a.tab[i][jj] < a.tab[i][j]) x++;///            for (int ii = i + 1; ii < 3; ii++)                for (int jj = 0; jj < 3; jj++)                    if (a.tab[ii][jj] && a.tab[ii][jj] < a.tab[i][j]) x++;            ret += x;        }    return ret % 2;}bool pre_solve() {    return (get_preval(t[st]) ^ get_preval(t[ed]));}///!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!int main() {    char ss[12];    while(gets(ss))    {        pre();//        input(t[st]);        In(ss, t[st]);        init();    //    input(ed);    //    out_table(st);    //    out_table(ed);        if (pre_solve()) puts("unsolvable");        else astar();    }    return 1;}/*1 2 3 x 4 6 7 5 8x 2 3 1 4 6 7 5 81 2 3 x 4 6 7 5 81 2 3 4 x 6 7 5 81 2 3 4 5 x 7 8 61 2 3 4 x 8 7 6 5*/

IDA*法:

#include <iostream>#include <cstdio>#include <cstring>#include <string>#include <cmath>#include <algorithm>#include <cstdlib>#include <vector>#include <queue>#include <stack>using namespace std;///Data!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!struct node {    int tab[3][3];///table    int r, c;///0的位置    int f, g;///估价函数}st, ed;int ed_map[10][2];int dir_i[4] = {0, 0, 1, -1};///r, l, u, dint dir_j[4] = {1, -1, 0, 0};char print_op[4] = {'r', 'l', 'd', 'u'};int MaxH;//stack<int>sk;int ans[5000000];int ansn;///!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!///Calc!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!int get_f(node a, int g) {///f = g + h;    int h = 0;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++)            if (a.tab[i][j])///不考虑零                h += abs(i - ed_map[a.tab[i][j]][0]) + abs(j - ed_map[a.tab[i][j]][1]);    return g + 2 * h;}void change(node &tmp, node a, int nextr, int nextc, int i) {    for (int i = 0; i < 3; i++) for (int j = 0; j < 3; j++)        tmp.tab[i][j] = a.tab[i][j];    swap(tmp.tab[nextr][nextc], tmp.tab[a.r][a.c]);    tmp.r = nextr; tmp.c = nextc;    tmp.g = a.g + 1;    tmp.f = get_f(tmp, tmp.g);}int check(int i, int j) {    if (i > 2 || i < 0 || j > 2 || j < 0) return 0;    return 1;}///!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!///Input()Init()void input(node &st) {    char ch;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++) {            scanf(" %c", &ch);            if (ch <= '9' && ch > '0') st.tab[i][j] = ch - '0';            else { st.r = i; st.c = j; st.tab[i][j] = 0; }        }}void In(char s[], node &st){    char ch;    int next = 0;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++) {            while (s[next] == ' ') next++;            ch = s[next++];            if (ch <= '9' && ch > '0') st.tab[i][j] = ch - '0';            else { st.r = i; st.c = j; st.tab[i][j] = 0; }        }}void init(){    for (int i = 0; i < 3; i++) for (int j = 0; j < 3; j++)        ed.tab[i][j] = (i * 3) + j + 1;    ed.tab[2][2] = 0;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++)        if (ed.tab[i][j]) {                ed_map[ed.tab[i][j]][0] = i; ed_map[ed.tab[i][j]][1] = j;        }    st.f = get_f(st, 0);    st.g = 0;}///!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!void out_table(node a) {    for (int i = 0; i < 3; i++) {        for (int j = 0; j < 3; j++)        cout << a.tab[i][j] << ' ';        cout <<endl;    }    cout <<endl;}bool isfind(node a){    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++)        if (a.tab[i][j] != ed.tab[i][j]) return false;    return true;}bool dfs(node a, int d){    if (a.f > MaxH) return false;    if (isfind(a)) return true;    for (int i = 0; i < 4; i++)    {        if (d != i && i / 2 == d / 2) continue;///相邻不往返        int nextr = a.r + dir_i[i];        int nextc = a.c + dir_j[i];        if (check(nextr, nextc))        {            node tmp;            change(tmp, a, nextr, nextc, i);            if (dfs(tmp, i)) {                    ans[ansn++] = i;//                sk.push(i);                return true;            }        }    }    return false;}void ida_star(){//    while (!sk.empty()) sk.pop();    ansn = 0;    MaxH = st.f;    while (!dfs(st, 6)) MaxH+=1;    for (int i = ansn - 1; i >= 0; i--) putchar(print_op[ans[i]]);//    while(!sk.empty()) printf("%c", print_op[sk.top()]), sk.pop();    puts("");}///main!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!int get_preval(node a) {    int ret = 0;    for (int i = 0; i < 3; i++)        for (int j = 0; j < 3; j++) {            if (!a.tab[i][j]) continue;            int x = 0;            for (int jj = j + 1; jj < 3; jj++)                    if (a.tab[i][jj] && a.tab[i][jj] < a.tab[i][j]) x++;///            for (int ii = i + 1; ii < 3; ii++)                for (int jj = 0; jj < 3; jj++)                    if (a.tab[ii][jj] && a.tab[ii][jj] < a.tab[i][j]) x++;            ret += x;        }    return ret % 2;}bool pre_solve() {    return (get_preval(st) ^ get_preval(ed));}///!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!int main() {    char ss[12];    while (gets(ss))    {        In(ss, st);//        input(st);    //    input(ed);        init();    //    out_table(st);    //    out_table(ed);        if (pre_solve()) puts("unsolvable");        else ida_star();    }    return 1;}/*1 2 3 x 4 6 7 5 8x 2 3 1 4 6 7 5 81 2 3 x 4 6 7 5 81 2 3 4 x 6 7 5 81 2 3 4 5 x 7 8 61 2 3 4 x 8 7 6 5*/




1 0
原创粉丝点击