鸢尾花分类算法实现 java

来源:互联网 发布:魏志香 知乎 编辑:程序博客网 时间:2024/05/02 06:10

使用的贝叶斯分类算法实现的,编程语言为java。是我本学期修的数据库与数据挖掘的课程的期末课程作业,算法本身不难,思路理清楚了很简单。

先看看鸢尾花(Iris)数据集(下图为数据集的部分截图),鸢尾花有setosa、Versicolor、Virginica3个类别,数据集中各个类别各50条数据,一共是150条数据记录,每条数据记录的前4个值分别表示鸢尾花的sepalLengthsepalWidth、petalLengthpetalWidth,第5个值是鸢尾花的类型。算法实现过程中将每个类别的前40条记录作为训练数据,进行分类模型的训练,每个类别的后10条数据作为测试数据,对分类模型的准确性进行判断。

算法基本思路:由概率论中先验概率后验概率的转换公式

可以得到:

C为鸢尾花类别,F1F2F3F4表示鸢尾花的4个特征,由于


可见,分母与类别没有关系,在分类时不提供判别信息,不作考虑。因此,分类只与分子有关,,又假设F1F2F3F4是相互独立,所以分子等价于

由于都有P(C)项,所以可以就只计算

即计算各特征分别属于各类别的概率,然后相乘,值最大的则是那个相应的类别。

所以鸢尾花分类的思路为:先分别计算3个类别分别在4个特征上的均值和方差,然后构造相应的概率密度函数,在对测试数据集进行分类的时候,即求4个特征值分别属于类别1的概率然后相乘,属于类别2概率然后相乘,属于类别3的概率然后相乘,再比较这3个概率值的大小,哪个最大则相应的分类哪个类别。


代码涉及到鸢尾花类以及鸢尾花的类型类(鸢尾花类型定义枚举类型),数据的读取类(读取鸢尾花数据集),鸢尾花分类类和概率值的计算类,这里给出鸢尾花分类类的代码:

BayesClassify.java

package com.test;import java.io.IOException;import java.util.ArrayList;import java.util.List;import com.test.DataReader;import com.test.Iris;public class BayesClassify {static List<Iris> irisDataSet; //鸢尾花数据集static List<Iris> testDataSet;//测试数据集static List<Iris> trainingDataSet;//训练数据集public BayesClassify() {irisDataSet = new ArrayList<Iris>();testDataSet = new ArrayList<Iris>();trainingDataSet = new ArrayList<Iris>();}public static void main(String[] args) throws IOException {DataReader reader = new DataReader();BayesClassify bayes = new BayesClassify();irisDataSet = reader.getIrisData();//读取鸢尾花数据集                bayes.prepareTrainingData();//准备训练数据                Calculate calcu = new Calculate();                calcu.CalMV(trainingDataSet);//计算均值、方差                bayes.prepareTestData();//准备测试数据                int n = 0; //分类正确的个数                for(Iris i : testDataSet) {              String type = calcu.CalP(i.getSepalLength(), i.getSepalWidth(), i.getPetalLength(), i.getPetalWidth());//获得分类  System.out.println("原本类别:"+i.getType().getLabel()+"---->最终分类为:"+type);              if(type.equals(i.getType().getLabel())) {        n++;        }        }                //分类的准确率                System.out.println("分类的正确率:"+(double)n/30);}public void prepareTrainingData() {if (irisDataSet.size() == 150) {trainingDataSet.addAll(irisDataSet.subList(0,40));//将Setosa的前40条加入训练集trainingDataSet.addAll(irisDataSet.subList(50,90));//将Versicolor的前40条加入训练集trainingDataSet.addAll(irisDataSet.subList(100,140));//将Virginica的前40条加入训练集}}public void prepareTestData() {if (irisDataSet.size() == 150) {testDataSet.addAll(irisDataSet.subList(40,50));//将Setosa的后10条加入测试集testDataSet.addAll(irisDataSet.subList(90,100));//将Versicolor的后10条加入测试集testDataSet.addAll(irisDataSet.subList(140,150));//将Virginica的后10条加入测试集}}}


运行结果:

从图上可以看出分类算法的准确率为76.7%

1 0
原创粉丝点击