logistic regression 处理鸢尾花数据集

来源:互联网 发布:三菱plc3u编程手册 编辑:程序博客网 时间:2024/04/20 20:26

logistic regression 处理鸢尾花数据集

——————————————————————keras实现

鸢尾花数据集是一个比较经典的多重变量分类数据集。它最初是埃德加·安德森从加拿大加斯帕半岛上的鸢尾属花朵中提取的地理变异数据,后由罗纳德·费雪作为判别分析的一个例子,运用到统计学中。

鸢尾花数据集所代表的分类问题可以较好的用逻辑斯蒂回归(logistic regression)解决,不过鸢尾花数据集是多分类问题,需要应用logistic regression在多分类问题上的推广:softmax

// to do:softmax问题的理论分析

keras是一个非常好用的high-level的神经网络框架,针对多分类问题,可以用keras构造神经网络实现softmax分类

网络结构:

代码:

import numpy as npfrom tensorflow.contrib.keras.api.keras.models import Sequentialfrom tensorflow.contrib.keras.api.keras.layers import Dense, Activationdef load_data(file_name):    train = np.loadtxt(file_name)    train_x = train[:, :4]    y_ = train[:, 4]    y_train = np.ndarray([len(y_),3])    for i in range(0,len(y_)):        arr = np.array([0,0,0])        arr[(int)(y_[i])] = 1        y_train[i] = arr    return train_x, y_trainmodel = Sequential()model.add(Dense(units=4, input_dim=4))model.add(Activation(activation="linear"))model.add(Dense(units=3))model.add(Activation("softmax"))model.compile(loss="categorical_crossentropy",              optimizer='sgd',              metrics=['accuracy'])x_train, y_train = load_data("flower_train")x_test, y_test = load_data("flower_test")model.fit(x_train, y_train, epochs=20, batch_size=5)loss_and_metrics = model.evaluate(x_test, y_test, batch_size=15)print(loss_and_metrics)correct_count = 0for i in range(0, len(x_test)):    result = model.predict(np.array([x_test[i]]))    if np.argmax(result[0]) == np.argmax(y_test[i]):        correct_count+=1print(correct_count)print(correct_count/len(x_test))

代码对应的数据文件:

原创粉丝点击