简单的遗传算法(Genetic algorithms)-吃豆人

来源:互联网 发布:无主之地2淘宝 编辑:程序博客网 时间:2024/05/21 23:33

遗传算法简介:

一直都在收听卓老板聊科技这个节目,最近播出了一期关于人工智能的节目,主要讲的是由霍兰提出的遗传算法,在节目中详细阐述了一个有趣的小实验:吃豆人。

首先简单介绍下遗传算法:
1:为了解决某个具体的问题,先随机生成若干个解决问题的实体,每个实体解决问题的方式都用“基因”来表示,也就是说,不同的实体拥有不同的基因,那么也对应着不同的解决问题的方案。
2:有了若干实体之后,接下来就是让这些实体来完成这个任务,根据任务的完成情况用相同标准打分。
3:接下来是进化环节,按照得分的高低,得出每个个体被选出的概率,得分越高越容易被选出,先选出两个个体,对其基因进行交叉,再按照设定的概率对其基因进行突变,来生成新个体,不停重复直到生成足够数量的新个体,这便是一次进化过程。按照这个方法不停的进化,若干代之后就能得到理想的个体。

下面简单介绍下吃豆人实验:

吃豆人首先生存在一个10*10个格子组成的矩形空间中,将50个豆子随机放在这100个格子中,每个格子要嘛为空,要嘛就有一颗豆子。吃豆人出生的时候随机出现在一个任意方格中,接下来吃豆人需要通过自己的策略来吃豆子,一共只有200步,吃到一颗+10分,撞墙-5分,发出吃豆子的动作却没吃到豆子-1分。另外吃豆人只能看到自己所在格子和上下左右一共5个格子的情况。

整理一下
吃豆人的所有动作:上移、下移、左移、右移、吃豆、不动、随机移动,一共7种
吃豆人所能观察到的状态:每个格子有“有豆子”、“无豆子”、“墙”3种状态,而一共能看到5个格子,那就是3^5=243种状态。

为此,吃豆人个体的基因可以用243长度的基因表示,分别对应所处的243种状态,每个基因有7种情况,分别表示所处状态下产生的反应。

代码

Main.java

/**
 * Program entry point: evolves a population of Pac-Man strategies without end
 * and prints population statistics every fifth generation.
 */
public class Main {

    /** Number of individuals in every generation. */
    private static final int POPULATION_SIZE = 1000;

    /** Statistics are printed once every this many generations. */
    private static final int REPORT_INTERVAL = 5;

    public static void main(String[] args) {
        Population current = new Population(POPULATION_SIZE, false);
        System.out.println(current);

        // There is no stop criterion: the loop runs until the process is killed.
        for (long generation = 1; ; generation++) {
            Population next = Algorithm.evolve(current);
            if (generation % REPORT_INTERVAL == 0) {
                System.out.println("The " + generation + "'s evolve");
                System.out.println(next);
            }
            current = next;
        }
    }
}

Individual.java

public class Individual {    //吃豆人一共会有3^5种状态,它能观察的位置一共有上下左右和当前格子,一个共5个,每个格子有墙,豆子,无豆子3种状态。    private static int length = 243;    /*吃豆人一共有7总动作     * 0 :上    4 : 随机移动     * 1 : 左   5 : 吃     * 2 : 下   6 : 不动       * 3 : 右        */    private static byte actionNum = 7;    private byte genes[] = null;    private int fitness = Integer.MIN_VALUE;    public Individual() {        genes = new byte[length];           }    public void generateGenes(){                for (int i = 0; i < length; i++) {            byte gene = (byte) Math.floor(Math.random() * actionNum);            genes[i] = gene;        }    }    public int getFitness() {        if (fitness == Integer.MIN_VALUE) {            fitness = FitnessCalc.getFitnessPall(this);        }        return fitness;    }    public int getLength() {        return length;    }    public byte getGene(int index) {        return genes[index];    }    public void setGene(int index, byte gene) {        this.genes[index] = gene;        fitness = Integer.MIN_VALUE;    }    //状态码的转换:5个3进制位,第一个代表中,第二个代表上,第三个代表右,第四个代表下,第五个代表左    public byte getActionCode(State state) {                int stateCode = (int) (state.getMiddle() * Math.pow(3, 4) + state.getUp() * Math.pow(3, 3) + state.getRight() * Math.pow(3, 2) + state.getDown() * 3 + state.getLeft());        return genes[stateCode];    }    @Override    public String toString() {          StringBuffer bf = new StringBuffer();        for (int i = 0; i < length; i++) {            bf.append(genes[i]);        }        return bf.toString();    }    public static void main(String[] args) {        Individual ind = new Individual();        ind.generateGenes();        System.out.println(ind);        System.out.println(ind.getFitness());        System.out.println(FitnessCalc.getFitnessPall(ind));    }}

State.java

/**
 * What Pac-Man observes at one position: the content of its own cell and of
 * the four neighbouring cells.
 *
 * <p>Each value is a cell code: 0 = wall, 1 = bean, 2 = no bean.
 */
public class State {

    private byte middle;
    private byte up;
    private byte right;
    private byte down;
    private byte left;

    /** Creates a state from the five observed cell codes. */
    public State(byte middle, byte up, byte right, byte down, byte left) {
        this.middle = middle;
        this.up = up;
        this.right = right;
        this.down = down;
        this.left = left;
    }

    // --- the cell Pac-Man stands on ---

    public byte getMiddle() {
        return this.middle;
    }

    public void setMiddle(byte middle) {
        this.middle = middle;
    }

    // --- the cell above ---

    public byte getUp() {
        return this.up;
    }

    public void setUp(byte up) {
        this.up = up;
    }

    // --- the cell to the right ---

    public byte getRight() {
        return this.right;
    }

    public void setRight(byte right) {
        this.right = right;
    }

    // --- the cell below ---

    public byte getDown() {
        return this.down;
    }

    public void setDown(byte down) {
        this.down = down;
    }

    // --- the cell to the left ---

    public byte getLeft() {
        return this.left;
    }

    public void setLeft(byte left) {
        this.left = left;
    }
}

Algorithm.java

public class Algorithm {    /* GA 算法的参数 */    private static final double uniformRate = 0.5; //交叉概率    private static final double mutationRate = 0.0001; //突变概率    private static final int tournamentSize = 3; //淘汰数组的大小    public static Population evolve(Population pop) {        Population newPopulation = new Population(pop.size(), true);        for (int i = 0; i < pop.size(); i++) {        //随机选择两个 优秀的个体            Individual indiv1 = tournamentSelection(pop);            Individual indiv2 = tournamentSelection(pop);                       //进行交叉            Individual newIndiv = crossover(indiv1, indiv2);            newPopulation.saveIndividual(i, newIndiv);          }        // Mutate population  突变        for (int i = 0; i < newPopulation.size(); i++) {            mutate(newPopulation.getIndividual(i));        }           return newPopulation;           }           // 随机选择一个较优秀的个体,用了进行交叉    private static Individual tournamentSelection(Population pop) {        // Create a tournament population        Population tournamentPop = new Population(tournamentSize, true);        //随机选择 tournamentSize 个放入 tournamentPop 中        for (int i = 0; i < tournamentSize; i++) {            int randomId = (int) (Math.random() * pop.size());            tournamentPop.saveIndividual(i, pop.getIndividual(randomId));        }        // 找到淘汰数组中最优秀的        Individual fittest = tournamentPop.getFittest();        return fittest;    }    // 进行两个个体的交叉 。 交叉的概率为uniformRate    private static Individual crossover(Individual indiv1, Individual indiv2) {        Individual newSol = new Individual();        // 随机的从 两个个体中选择         for (int i = 0; i < indiv1.getLength(); i++) {            if (Math.random() <= uniformRate) {                newSol.setGene(i, indiv1.getGene(i));            } else {                newSol.setGene(i, indiv2.getGene(i));            }        }        return newSol;    }    // 突变个体。 突变的概率为 mutationRate    private static void mutate(Individual indiv) {        for (int i = 0; i < 
indiv.getLength(); i++) {            if (Math.random() <= mutationRate) {                // 生成随机的 0-6                byte gene = (byte) Math.floor(Math.random() * 7);                indiv.setGene(i, gene);            }        }    }}

Population.java

public class Population {    private Individual[] individuals;    public Population(int size, boolean lazy) {        individuals = new Individual[size];        if (!lazy) {            for (int i = 0; i < individuals.length; i++) {                Individual ind = new Individual();                ind.generateGenes();                individuals[i] = ind;            }        }    }    public void saveIndividual(int index, Individual ind) {        individuals[index] = ind;    }    public Individual getIndividual(int index) {        return individuals[index];    }    public Individual getFittest() {        Individual fittest = individuals[0];        // Loop through individuals to find fittest        for (int i = 1; i < size(); i++) {            if (fittest.getFitness() <= getIndividual(i).getFitness()) {                fittest = getIndividual(i);            }        }        return fittest;    }    public Individual getLeastFittest() {        Individual ind = individuals[0];        for (int i = 1; i < size(); i++) {            if (ind.getFitness() > getIndividual(i).getFitness()) {                ind = getIndividual(i);            }        }        return ind;    }    public double getAverageFitness() {        double sum = 0;        for (int i = 0; i < size(); i++) {            sum += individuals[i].getFitness();        }        return sum / size();    }    public int size() {        return individuals.length;    }    @Override    public String toString(){        StringBuffer bf = new StringBuffer();        bf.append("Population size: " + size() + "\n");        bf.append("Max Fitnewss: " + getFittest().getFitness() + "\n");        bf.append("Least Fitness: " + getLeastFittest().getFitness() + "\n");        bf.append("Average Fitness: " + getAverageFitness() + "\n");                return bf.toString();    }    public static void main(String[] args) {        Population population = new Population(8000, false);        System.out.println(population);        }}

MapMgr.java

/**
 * Singleton holding a fixed set of pre-generated maps and handing out copies
 * of them.
 */
public class MapMgr {

    private static final int x = 10;
    private static final int y = 10;
    /** Beans placed on every map. */
    private static final int beanNum = 50;
    /** Number of distinct maps kept by the manager. */
    private static final int mapNum = 100;

    private static MapMgr manager = null;

    private Map[] maps = null;

    /** Generates all {@code mapNum} maps up front. */
    private MapMgr() {
        maps = new Map[mapNum];
        for (int slot = 0; slot < maps.length; slot++) {
            Map generated = new Map(x, y);
            generated.setBeans(beanNum);
            maps[slot] = generated;
        }
    }

    /** Returns the shared instance, creating it on first call. */
    public static synchronized MapMgr getInstance() {
        if (manager == null) {
            manager = new MapMgr();
        }
        return manager;
    }

    /**
     * Returns a copy of the map at {@code index % mapNum}, so callers can
     * modify the result without affecting the stored map.
     *
     * @return the copy, or {@code null} if cloning failed
     */
    public Map getMap(int index) {
        try {
            return maps[index % mapNum].clone();
        } catch (CloneNotSupportedException e) {
            e.printStackTrace();
            return null;
        }
    }

    /** Manual smoke test: prints two of the stored maps. */
    public static void main(String[] args) {
        MapMgr mgr = MapMgr.getInstance();
        mgr.getMap(1).print();
        System.out.println("--------------");
        mgr.getMap(2).print();
    }
}

Map.java

import java.awt.Point;public class Map implements Cloneable{    private int x = -1;    private int y = -1;    private int total = -1;    private byte[][] mapGrid = null;    public Map(int x, int y) {        this.x = x;        this.y = y;        mapGrid = new byte[x][y];        total = x * y;    }    public void setBeans(int num) {        //check num         if (num > total) {            num = total;        }        for (int i = 0; i < num; i++) {            int address, xp, yp;            do{                address = (int) Math.floor((Math.random() * total)); //生成0 - (total-1)的随机数                          xp = address / y;                yp = address % y;                   //System.out.println(xp+ ":" + yp + ":" + address + ":" + total);            } while (mapGrid[xp][yp] != 0);            mapGrid[xp][yp] = 1;                    }    }    public boolean isInMap(int x, int y) {              if (x < 0 || x >= this.x) return false;        if (y < 0 || y >= this.y) return false;             return true;    }    public boolean hasBean(int x, int y) {        boolean ret = mapGrid[x][y] == 0 ? 
false : true;        return ret;    }    public boolean eatBean(int x, int y) {        if(hasBean(x, y)) {            mapGrid[x][y] = 0;            return true;        }        return false;    }    public Point getStartPoint() {                      int x = (int) Math.floor(Math.random() * this.x);        int y = (int) Math.floor(Math.random() * this.y);               return new Point(x, y);    }    public State getState(Point p) {                byte middle = stateOfPoint(p);        byte up = stateOfPoint(new Point(p.x, p.y - 1));        byte right = stateOfPoint(new Point(p.x + 1, p.y));        byte down = stateOfPoint(new Point(p.x, p.y + 1));        byte left = stateOfPoint(new Point(p.x - 1, p.y));        return new State(middle, up, right, down, left);    }    //0为墙,1为有豆子,2为无豆子    private byte stateOfPoint(Point p) {        byte ret;        if (!isInMap(p.x, p.y)) ret = 0;                    else if (mapGrid[p.x][p.y] == 0) ret =  2;        else ret = 1;        return ret;    }    @Override    public Map clone() throws CloneNotSupportedException {        Map m = (Map) super.clone();        byte[][] mapGrid = new byte[x][y];        for (int i = 0; i < x; i++) {            for (int j = 0; j < y; j++) {                mapGrid[i][j] = this.mapGrid[i][j];            }        }        m.mapGrid = mapGrid;        return m;           }    public void print() {        for (int i = 0; i < y; i++) {            for (int j = 0; j < x; j++) {                System.out.print(mapGrid[j][i]);            }            System.out.println();        }    }    public static void main(String[] args) {        Map m = new Map(10, 5);        Map m1 = null;        try {            m1 = m.clone();        } catch (CloneNotSupportedException e) {            // TODO Auto-generated catch block            e.printStackTrace();        }        m.setBeans(40);        m.print();        m1.setBeans(15);        m1.print();    }}

FitnessCalc.java

import java.awt.Point;import java.util.concurrent.Callable;import java.util.concurrent.ExecutionException;import java.util.concurrent.FutureTask;public class FitnessCalc {    /*动作结果说明:     * 撞墙:-5分     * 吃到豆子:10分     * 吃空了:-1分     * 其他:0分     */     //模拟进行的场数    private static int DefaultSimTimes = 1000;    //模拟进行的步数    private static int simSteps = 200;    private static int cores = 4;    public static int getFitness(Individual ind) {        return getFitness(ind, DefaultSimTimes);    }    public static int getFitness(Individual ind, int simTimes) {        int fitness = 0;                MapMgr mgr = MapMgr.getInstance();          for (int i = 0; i < simTimes; i++) {            Map map = mgr.getMap(i);            Point point = map.getStartPoint();              for (int j = 0; j < simSteps; j++) {                State state = map.getState(point);                byte actionCode = ind.getActionCode(state);                fitness += action(point, map, actionCode);                //map.print();                //System.out.println("---");            }                                       }               return fitness / simTimes;    }    public static int getFitnessPall(Individual ind) {        int fitness = 0;                if (DefaultSimTimes < 100) {            fitness = getFitness(ind);        } else {                                        FutureTask<Integer>[] tasks = new FutureTask[cores];                        for (int i = 0; i < cores; i++) {                FitnessPall pall = null;                if (i == 0) {                    pall = new FitnessPall(ind, (DefaultSimTimes / cores) + DefaultSimTimes % cores);                } else {                    pall = new FitnessPall(ind, DefaultSimTimes / cores);                   }                               tasks[i] = new FutureTask<Integer>(pall);                Thread thread = new Thread(tasks[i]);                thread.start();            }                   for (int i = 0; i < cores; i++) {                
try {                    fitness += tasks[i].get();                } catch (InterruptedException | ExecutionException e) {                    e.printStackTrace();                }            }            fitness = fitness / cores;        }        return fitness;    }    private static int action(Point point, Map map, int actionCode) {        int sorce = 0;        switch (actionCode) {        case 0:            if (map.isInMap(point.x, point.y - 1)) {                sorce = 0;                point.y = point.y - 1;            } else {                sorce = -5;            }                       break;        case 1:            if (map.isInMap(point.x - 1, point.y)) {                sorce = 0;                point.x = point.x - 1;            } else {                sorce = -5;            }            break;        case 2:            if (map.isInMap(point.x, point.y + 1)) {                sorce = 0;                point.y = point.y + 1;            } else {                sorce = -5;            }            break;        case 3:             if (map.isInMap(point.x + 1, point.y)) {                sorce = 0;                point.x = point.x + 1;            } else {                sorce = -5;            }            break;        case 4:            int randomCode = (int) Math.floor(Math.random() * 4);            sorce = action(point, map, randomCode);                     break;        case 5:            if (map.eatBean(point.x, point.y)) {                sorce = 10;                         } else {                sorce = -1;            }            break;        case 6:             sorce = 0;            break;        }        return sorce;    }}class FitnessPall implements Callable<Integer> {    private int simTimes;    private Individual ind;    public FitnessPall(Individual ind, int simTimes) {        this.ind = ind;        this.simTimes = simTimes;           }    @Override    public Integer call() throws Exception {        return FitnessCalc.getFitness(ind, simTimes);  
         }   }
0 0
原创粉丝点击