OpenCV 3.0 机器学习模块使用

来源:互联网 发布:lambda python 编辑:程序博客网 时间:2024/06/08 06:17

using namespace ml;

// Random-trees (random forest) classifier.
//
// Builds a TrainData object with the class helper prepareTrainData (defined
// elsewhere; presumably it restricts training to the first `ntrain_samples`
// rows — confirm), configures an RTrees model with fixed hyper-parameters,
// trains it and returns it upcast to StatModel.
//
// @param data            sample matrix, forwarded to prepareTrainData
// @param responses       class labels, forwarded to prepareTrainData
// @param ntrain_samples  training-sample count, forwarded to prepareTrainData
// @return the trained RTrees model
Ptr<StatModel> lpmlBtnClassify::buildRtreesClassifier(Mat data, Mat responses, int ntrain_samples)
{
    Ptr<TrainData> trainSet = prepareTrainData(data, responses, ntrain_samples);

    Ptr<RTrees> forest = RTrees::create();

    // Per-tree shape limits.
    forest->setMaxDepth(10);
    forest->setMinSampleCount(10);
    forest->setRegressionAccuracy(0);
    forest->setUseSurrogates(false);
    forest->setMaxCategories(15);

    // An empty Mat means no explicit class priors are supplied.
    forest->setPriors(Mat());
    forest->setCalculateVarImportance(false);

    // Termination criteria come from the class helper setIterCondition
    // (arguments 100 and 0.01f).
    forest->setTermCriteria(setIterCondition(100, 0.01f));

    forest->train(trainSet);
    return forest;
}

// AdaBoost classifier trained on an "unrolled" copy of the data.
//
// The multi-class problem is turned into a binary one. Each of the first
// `ntrain_samples` samples is replicated `class_count` times. Every copy gets
// the candidate class index j appended as an extra categorical feature, and its
// response is set to 1 when j equals the sample's true label, otherwise 0.
// NOTE(review): a model trained this way presumably has to be queried with
// inputs unrolled the same way (one query per candidate class) — confirm at
// the prediction site.
//
// @param data            samples, one per row; read via ptr<float>, so
//                        presumably CV_32F — confirm at caller
// @param responses       labels, read via at<int>, so presumably CV_32S class
//                        indices in [0, class_count) — confirm at caller
// @param ntrain_samples  number of leading rows of `data` used for training
//                        (must not exceed data.rows)
// @param param0          number of weak classifiers (setWeakCount)
// @return the trained Boost model
Ptr<StatModel> lpmlBtnClassify::buildAdaboostClassifier(Mat data, Mat responses, int ntrain_samples,int param0)
{
    const int var_count = data.cols;
    const int unrolled_rows = ntrain_samples * class_count;

    // The extra last column holds the candidate class index.
    Mat new_data(unrolled_rows, var_count + 1, CV_32F);
    Mat new_responses(unrolled_rows, 1, CV_32S);

    for (int i = 0; i < ntrain_samples; i++)
    {
        const float* data_row = data.ptr<float>(i);
        // Loop-invariant for the inner loop: read the true label once.
        const int true_label = responses.at<int>(i);

        for (int j = 0; j < class_count; j++)
        {
            const int out_row = i * class_count + j;
            float* new_data_row = new_data.ptr<float>(out_row);

            memcpy(new_data_row, data_row, var_count * sizeof(data_row[0]));
            new_data_row[var_count] = static_cast<float>(j);
            new_responses.at<int>(out_row) = (true_label == j) ? 1 : 0;
        }
    }

    // Variable types: var_count ordered features, then the categorical
    // class-index feature and the categorical response.
    Mat var_type(1, var_count + 2, CV_8U);
    var_type.setTo(Scalar::all(VAR_ORDERED));
    var_type.at<uchar>(var_count) = VAR_CATEGORICAL;
    var_type.at<uchar>(var_count + 1) = VAR_CATEGORICAL;

    Ptr<TrainData> tdata = TrainData::create(new_data, ROW_SAMPLE, new_responses,
                                             noArray(), noArray(), noArray(), var_type);

    Ptr<Boost> model = Boost::create();
    model->setBoostType(Boost::GENTLE);
    model->setWeakCount(param0);
    model->setWeightTrimRate(0.95);
    model->setMaxDepth(5);
    model->setUseSurrogates(false);

    model->train(tdata);
    return model;
}

// Multi-layer perceptron (ANN) classifier.
//
// Labels are one-hot encoded into a float matrix of shape
// ntrain_samples x class_count. The network topology is
// [data.cols, 100, 100, class_count]. The network is trained on the first
// `ntrain_samples` rows of `data`, saved to "myANN.xml" (a path relative to
// the current working directory), and returned.
//
// @param data            samples, one per row
// @param responses       labels, read via at<int>; each is used as a column
//                        index, so presumably in [0, class_count) — confirm
// @param ntrain_samples  number of leading rows used for training
// @return the trained ANN_MLP model
Ptr<StatModel> lpmlBtnClassify::buildMlpClassifier(Mat data, Mat responses, int ntrain_samples)
{
    // Step 1: one-hot encode the labels.
    Mat oneHot = Mat::zeros(ntrain_samples, class_count, CV_32F);
    for (int row = 0; row < ntrain_samples; ++row)
    {
        const int label = responses.at<int>(row);
        oneHot.at<float>(row, label) = 1.f;
    }

    // Step 2: network topology: inputs, two hidden layers of 100, outputs.
    int topology[] = { data.cols, 100, 100, class_count };
    const int depth = (int)(sizeof(topology) / sizeof(topology[0]));
    Mat layer_sizes(1, depth, CV_32S, topology);

    // Compile-time switch between the two training setups.
#if 1
    const int method = ANN_MLP::BACKPROP;
    const double method_param = 0.001;
    const int max_iter = 300;
#else
    const int method = ANN_MLP::RPROP;
    const double method_param = 0.1;
    const int max_iter = 1000;
#endif

    Ptr<TrainData> trainSet =
        TrainData::create(data.rowRange(0, ntrain_samples), ROW_SAMPLE, oneHot);

    Ptr<ANN_MLP> net = ANN_MLP::create();
    net->setLayerSizes(layer_sizes);
    net->setActivationFunction(ANN_MLP::SIGMOID_SYM, 0, 0);
    net->setTermCriteria(setIterCondition(max_iter, 0));
    net->setTrainMethod(method, method_param);

    net->train(trainSet);

    // Persist the trained network so it can later be restored with load().
    net->save("myANN.xml");
    return net;
}

// Load the trained ANN that buildMlpClassifier saves to "myANN.xml".
// NOTE(review): at namespace scope this is a global whose initializer runs
// during static initialization, i.e. before main(). If "myANN.xml" does not
// exist at that moment, the result is presumably an empty/untrained model —
// verify. In the original article this was likely meant as a usage snippet
// inside a function body; confirm.

Ptr<ANN_MLP> ann = Algorithm::load<ANN_MLP>("myANN.xml");

// save() pairs with load(), and write() pairs with read().


// Normal Bayes classifier.
//
// Builds training data with the class helper prepareTrainData (defined
// elsewhere), then trains a NormalBayesClassifier with default settings.
//
// @param data            sample matrix, forwarded to prepareTrainData
// @param responses       class labels, forwarded to prepareTrainData
// @param ntrain_samples  training-sample count, forwarded to prepareTrainData
// @return the trained NormalBayesClassifier
Ptr<StatModel> lpmlBtnClassify::buildNbayesClassifier(Mat data, Mat responses, int ntrain_samples)
{
    Ptr<TrainData> trainingSet = prepareTrainData(data, responses, ntrain_samples);

    Ptr<NormalBayesClassifier> bayes = NormalBayesClassifier::create();
    bayes->train(trainingSet);

    return bayes;
}

// K-nearest-neighbours classifier.
//
// Builds training data with the class helper prepareTrainData (defined
// elsewhere), then trains a KNearest model in classification mode.
//
// @param data            sample matrix, forwarded to prepareTrainData
// @param responses       class labels, forwarded to prepareTrainData
// @param ntrain_samples  training-sample count, forwarded to prepareTrainData
// @param K               default neighbour count (setDefaultK)
// @return the trained KNearest model
Ptr<StatModel> lpmlBtnClassify::buildKnnClassifier(Mat data, Mat responses, int ntrain_samples, int K)
{
    Ptr<TrainData> trainingSet = prepareTrainData(data, responses, ntrain_samples);

    Ptr<KNearest> knn = KNearest::create();
    knn->setDefaultK(K);
    knn->setIsClassifier(true);   // classification rather than regression

    knn->train(trainingSet);
    return knn;
}

// Support-vector-machine classifier.
//
// Builds training data with the class helper prepareTrainData (defined
// elsewhere), then trains a C-SVC with an RBF kernel and C = 1. Other SVM
// parameters are left at their OpenCV defaults.
//
// @param data            sample matrix, forwarded to prepareTrainData
// @param responses       class labels, forwarded to prepareTrainData
// @param ntrain_samples  training-sample count, forwarded to prepareTrainData
// @return the trained SVM
Ptr<StatModel> lpmlBtnClassify::buildSvmClassifier(Mat data, Mat responses, int ntrain_samples)
{
    Ptr<TrainData> trainingSet = prepareTrainData(data, responses, ntrain_samples);

    Ptr<SVM> svm = SVM::create();
    svm->setType(SVM::C_SVC);
    svm->setKernel(SVM::RBF);
    svm->setC(1);

    svm->train(trainingSet);
    return svm;
}

0 0
原创粉丝点击