Python机器学习库sklearn里利用感知机进行三分类(多分类)的原理

来源:互联网 发布:rmvb转mp4 mac版 编辑:程序博客网 时间:2024/06/14 08:35

感知机的理论可参考：http://blog.csdn.net/cymy001/article/details/77992416

import numpy as np


def _decision_scores(X, model):
    """Return the raw linear scores ``X @ model.coef_.T + model.intercept_``.

    One row per sample and one column per row of ``model.coef_``. This is the
    same expression scikit-learn's linear classifiers evaluate in
    ``decision_function``/``predict``, so values compare exactly.
    """
    # atleast_2d also lets a single 1-D sample (or a plain list) be scored.
    X = np.atleast_2d(np.asarray(X, dtype=float))
    return X @ np.asarray(model.coef_).T + np.asarray(model.intercept_)


def prelabmax(X_test_std, model=None):
    """Return, for every sample, its largest per-class score w_k . x + b_k.

    Parameters
    ----------
    X_test_std : array-like, shape (n_samples, n_features)
        Samples to score (standardised the same way as the training data).
    model : object with ``coef_`` and ``intercept_``, optional
        Fitted linear classifier. Defaults to the module-level ``ppn`` trained
        in the script section below, which is what this helper always used.

    Returns
    -------
    list
        One maximum score per sample, in sample order.
    """
    if model is None:
        model = ppn  # original behaviour: use the perceptron trained below
    return list(_decision_scores(X_test_std, model).max(axis=1))


def prelabindex(X_test_std, pym, model=None):
    """Return, for every sample, the index of the class whose score is pym[i].

    With ``pym = prelabmax(X_test_std)`` this is the arg-max class index. That
    is how ``Perceptron.predict`` picks a class when there are 3+ classes. For
    binary problems scikit-learn keeps a single row in ``coef_`` and predicts by
    the sign of that score, so this reconstruction does not apply there.

    Parameters
    ----------
    X_test_std : array-like, shape (n_samples, n_features)
        Samples to score.
    pym : sequence of float, length n_samples
        Per-sample maximum scores, normally the output of ``prelabmax``.
    model : object with ``coef_`` and ``intercept_``, optional
        Fitted linear classifier. Defaults to the module-level ``ppn``.

    Returns
    -------
    numpy.ndarray of int, shape (n_samples,)
        Exactly one class index per sample. On ties the lowest index wins,
        the same tie-break as ``numpy.argmax``.

    Raises
    ------
    ValueError
        If ``pym`` does not have one entry per sample, or if some ``pym[i]`` is
        not one of sample ``i``'s class scores. The old loop silently dropped
        or duplicated entries in these cases.
    """
    if model is None:
        model = ppn  # original behaviour: use the perceptron trained below
    scores = _decision_scores(X_test_std, model)
    target = np.asarray(pym, dtype=float).reshape(-1, 1)
    if target.shape[0] != scores.shape[0]:
        raise ValueError(
            f"pym has {target.shape[0]} entries but there are "
            f"{scores.shape[0]} samples"
        )
    matches = scores == target  # works for any number of classes
    found = matches.any(axis=1)
    if not found.all():
        bad = int(np.flatnonzero(~found)[0])
        raise ValueError(f"pym[{bad}] is not one of sample {bad}'s class scores")
    # argmax over a boolean row returns the first True: one index per sample,
    # lowest class index among tied maxima.
    return matches.argmax(axis=1)


if __name__ == "__main__":
    # ``%matplotlib inline`` from the original notebook is an IPython magic and
    # is a syntax error in a plain .py file. Nothing below plots, so it is kept
    # only as a comment; run it in a notebook cell if needed.
    # %matplotlib inline
    from IPython.display import Image  # noqa: F401  (unused in this snippet)

    from sklearn import datasets
    from sklearn.linear_model import Perceptron
    from sklearn.preprocessing import StandardScaler

    # EAFP import instead of a distutils LooseVersion check (distutils was
    # removed in Python 3.12). sklearn.cross_validation only exists in old
    # releases (before 0.20).
    try:
        from sklearn.model_selection import train_test_split
    except ImportError:
        from sklearn.cross_validation import train_test_split

    # http://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html
    iris = datasets.load_iris()
    X = iris.data[:, [2, 3]]  # feature columns 2 and 3 (petal length, width)
    y = iris.target  # species column, i.e. the class label

    # Split the data set: 30% held out for testing.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=0
    )

    # Fit the scaler on the training split only and store its parameters, then
    # apply the same transformation to both splits.
    sc = StandardScaler()
    sc.fit(X_train)
    X_train_std = sc.transform(X_train)
    X_test_std = sc.transform(X_test)

    # http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Perceptron.html
    # Default hyper-parameters, as in the original post. A tuned alternative is
    # Perceptron(max_iter=40, eta0=0.1, random_state=0); the older ``n_iter``
    # argument no longer exists in current scikit-learn.
    ppn = Perceptron()  # one score y = w.x + b per class
    ppn.fit(X_train_std, y_train)

    # Verify how the perceptron predicts: the arg-max of the per-class scores
    # reproduces ppn.predict.
    pym = prelabmax(X_test_std)
    same = prelabindex(X_test_std, pym) == ppn.predict(X_test_std)
    print(same)
    # Output recorded in the original post: an array of 45 ``True`` values.

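即：对每个测试样本 $x$，感知机（多分类时采用一对多 OvR 策略）为每个类别 $k$ 计算得分 $y_k = w_k \cdot x + b_k$，并把得分最大的类别作为预测结果，即 $\hat{y} = \arg\max_k \,(w_k \cdot x + b_k)$。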


阅读全文
0 0
原创粉丝点击