『sklearn学习』不同的 SVM 分类器

来源:互联网 发布:美工工资高吗 编辑:程序博客网 时间:2024/05/02 04:52
#!/usr/bin/env python
# coding:utf-8
"""Compare the decision surfaces of four SVM classifiers on the iris dataset.

Only the first two iris features (sepal length and sepal width) are used so
that each classifier's decision regions can be drawn in a 2-D plot:

* ``SVC`` with a linear kernel
* ``LinearSVC`` (linear kernel)
* ``SVC`` with an RBF kernel (gamma=0.7)
* ``SVC`` with a degree-3 polynomial kernel

__author__ = "LCG22"
__date__ = "2016-12-5"
"""
import numpy as np


def make_meshgrid(x, y, h=0.02, margin=1.0):
    """Return a regular grid covering the range of *x* and *y*.

    The grid extends ``margin`` beyond the minimum and maximum of each
    coordinate and uses spacing ``h`` along both axes.

    Parameters
    ----------
    x, y : array-like
        1-D sample coordinates along the horizontal and vertical axis.
    h : float
        Grid step size.
    margin : float
        Padding added on each side of the data range.

    Returns
    -------
    xx, yy : ndarray
        2-D coordinate arrays as returned by ``np.meshgrid``.
    """
    x_min, x_max = np.min(x) - margin, np.max(x) + margin
    y_min, y_max = np.min(y) - margin, np.max(y) + margin
    return np.meshgrid(np.arange(x_min, x_max, h),
                       np.arange(y_min, y_max, h))


def plot_decision_surface(ax, clf, xx, yy, grid_points, **contour_kwargs):
    """Fill *ax* with the class regions predicted by the fitted *clf*.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    clf : fitted classifier
        Any object with a ``predict`` method.
    xx, yy : ndarray
        Mesh coordinates from :func:`make_meshgrid`.
    grid_points : ndarray
        ``np.c_[xx.ravel(), yy.ravel()]``; passed in so it is built only once
        when several classifiers share the same mesh.
    **contour_kwargs
        Forwarded to ``ax.contourf``.

    Returns
    -------
    The object returned by ``ax.contourf``.
    """
    Z = clf.predict(grid_points).reshape(xx.shape)
    return ax.contourf(xx, yy, Z, **contour_kwargs)


def main(h=0.02, C=1.0):
    """Fit the four SVM models and show their decision surfaces in a 2x2 grid.

    Parameters
    ----------
    h : float
        Step size of the mesh the decision surfaces are evaluated on.
    C : float
        SVM regularization parameter shared by all four models.
    """
    # Imported here rather than at module level so the helpers above can be
    # imported and reused without pulling in the plotting / ML stack.
    import matplotlib.pyplot as plt
    from sklearn import datasets, svm

    iris = datasets.load_iris()
    X = iris.data[:, :2]  # keep only the first two features
    y = iris.target

    # Each title lives next to its model, so reordering cannot mislabel a plot.
    models = (
        ('SVC with linear kernel', svm.SVC(kernel="linear", C=C)),
        ('LinearSVC(linear kernel)', svm.LinearSVC(C=C)),
        ('SVC with RBF kernel', svm.SVC(kernel="rbf", gamma=0.7, C=C)),
        ('SVC with polynomial(degree 3) kernel',
         svm.SVC(kernel="poly", degree=3, C=C)),
    )
    for _, clf in models:
        clf.fit(X, y)

    # The evaluation grid is identical for every model: build it once.
    xx, yy = make_meshgrid(X[:, 0], X[:, 1], h=h)
    grid_points = np.c_[xx.ravel(), yy.ravel()]

    fig, axes = plt.subplots(2, 2)
    fig.subplots_adjust(wspace=0.4, hspace=0.4)

    for ax, (title, clf) in zip(axes.flat, models):
        plot_decision_surface(ax, clf, xx, yy, grid_points,
                              cmap=plt.cm.coolwarm, alpha=0.8)
        ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.coolwarm)
        ax.set_xlabel("Sepal length")
        ax.set_ylabel("Sepal width")
        ax.set_xlim(xx.min(), xx.max())
        ax.set_ylim(yy.min(), yy.max())
        ax.set_xticks(())
        ax.set_yticks(())
        ax.set_title(title)

    plt.show()


if __name__ == "__main__":
    main()

1 0
原创粉丝点击