sklearn官网学习入门一

来源:互联网 发布:中老年女式毛衣淘宝 编辑:程序博客网 时间:2024/05/20 23:58

1.导入sklearn内置的数据集iris

from sklearn import datasetsiris = datasets.load_iris()

2.导入、初始化svm分类器

from sklearn import svmclf = svm.SVC()  # classifier的缩写,这里缺省了SVC()的参数

3.训练分类器( fit 函数)

X, y = iris.data, iris.target  #iris.data是由n_samples, n_features组成的2D-array,iris.target是sample的label组成的1D arrayclf.fit(X, y) # fit的两个参数,一个是由特征值组成的矩阵,一个是实际label

4.将训练所得的分类器保存下来,也称model persistence (模型持久化)

from sklearn.externals import joblibjoblib.dump(clf,'iris.pkl') #将所学习到的分类模型,保存在iris.pkl中

5.用训练得到的模型,预测未知标签的数据

clf = joblib.load("iris.pkl") #加载训练模型(分类器),并赋给clfclf.predict(X[0:1])  #此处一定要注意,X[0:1],与X[0]不一样,X[0:1]是2D array,X[0]是1D array
原创粉丝点击