# python 利用库 sklearn 中的 grid_search 对 SVM 参数寻优(借鉴)

# 来源:互联网 发布:淘宝助理快速发布宝贝 编辑:程序博客网 时间:2024/06/09 16:50
import time  
from sklearn import metrics  
import numpy as np  
import pickle 
from sklearn.tree import DecisionTreeClassifier
 
def svm_cross_validation(val_x, val_y):
    """Grid-search an RBF-kernel SVC and return the best fitted model.

    The search covers ``C`` in ``[0.1, 1, 10]`` with ``gamma`` fixed at
    ``0.001``. The parameters of the winning estimator are printed, one
    ``name value`` pair per line.

    Args:
        val_x: Feature matrix to fit on. Despite the name, the caller in
            this file passes the training features.
        val_y: Labels aligned with ``val_x``.

    Returns:
        The ``SVC`` that ``GridSearchCV`` refit on all of ``val_x``/``val_y``
        with the best parameter combination found.
    """
    try:
        from sklearn.model_selection import GridSearchCV
    except ImportError:
        # Older scikit-learn releases only ship the legacy module, which
        # was removed in 0.20.
        from sklearn.grid_search import GridSearchCV
    from sklearn.svm import SVC

    model = SVC(kernel='rbf', probability=True)
    param_grid = {'C': [1e-1, 1, 10], 'gamma': [0.001]}
    grid_search = GridSearchCV(model, param_grid, n_jobs=1, verbose=1)

    # Fit on the arguments actually passed in, not on module-level globals.
    grid_search.fit(val_x, val_y)

    # With the default refit=True, best_estimator_ has already been refit on
    # the full data with the best C/gamma, so a second manual fit of an
    # identically configured SVC would only repeat the training work.
    best_model = grid_search.best_estimator_
    for para, val in best_model.get_params().items():
        print(para, val)
    return best_model
  

def read_data(data_file):
    """Load the train/validation/test splits from a gzipped pickle file.

    Args:
        data_file: Path to a gzip-compressed pickle holding a 3-tuple
            ``(train, val, test)``. Element 0 of each split is returned as
            the features and element 1 as the labels.

    Returns:
        ``(train_x, train_y, test_x, test_y, val_x, val_y)``. Note that the
        test split comes before the validation split.
    """
    import gzip

    # SECURITY: unpickling can execute arbitrary code; only load files from
    # a trusted source.
    # The context manager closes the file even if unpickling fails.
    with gzip.open(data_file, "rb") as f:
        # encoding='latin1' is what the pickle docs prescribe for reading
        # Python 2 pickles that contain NumPy arrays.
        train, val, test = pickle.load(f, encoding='latin1')

    train_x = train[0]
    train_y = train[1]
    val_x = val[0]
    val_y = val[1]
    test_x = test[0]
    test_y = test[1]

    return train_x, train_y, test_x, test_y, val_x, val_y


if __name__ == '__main__':
    # Script entry point: load the data, train every classifier listed in
    # test_classifiers, report metrics on the test split, and optionally
    # pickle the trained models.
    data_file = "D:/Users/咖啡豆/Anaconda/Sklearn/Mnist/mnist.pkl.gz"
    thresh = 0.5  # NOTE(review): not referenced anywhere in the visible script
    model_save_file = None  # set to a file path to pickle the trained models
    model_save = {}

    # Names to run, and the factory that trains each one.
    test_classifiers = ['SVMCV']
    classifiers = {
        'SVMCV': svm_cross_validation,
    }

    print('reading training and testing data...')
    train_x, train_y, test_x, test_y, val_x, val_y = read_data(data_file)
    num_train, num_feat = train_x.shape
    num_test, num_feat = test_x.shape
    # Precision/recall are only reported for two-class problems.
    is_binary_class = (len(np.unique(train_y)) == 2)
    print('******************** Data Info *********************')
    print('#training data: %d, #testing_data: %d, dimension: %d' % (num_train, num_test, num_feat))

    for classifier in test_classifiers:
        print('******************* %s ********************' % classifier)
        start_time = time.time()
        model = classifiers[classifier](train_x, train_y)
        print ('training took %fs!' % (time.time() - start_time))
        predict = model.predict(test_x)
        if model_save_file is not None:
            model_save[classifier] = model
        if is_binary_class:
            precision = metrics.precision_score(test_y, predict)
            recall = metrics.recall_score(test_y, predict)
            print('precision: %.2f%%, recall: %.2f%%' % (100 * precision, 100 * recall))
        accuracy = metrics.accuracy_score(test_y, predict)
        print('accuracy: %.2f%%' % (100 * accuracy))

    if model_save_file is not None:
        # Use a context manager so the output file is flushed and closed.
        with open(model_save_file, 'wb') as save_fh:
            pickle.dump(model_save, save_fh)













# 原创粉丝点击