Keras-3 Keras With Otto Group
来源:互联网 发布:抹茶美妆软件 编辑:程序博客网 时间:2024/05/16 18:59
Otto 分类问题
这里,我们将对Otto数据集进行分类。
- 本文主要参考 2.3 Introduction to Keras。个人觉得这是一个很好Keras教程,希望大家也去学习学习。
- 关于Otto,可以在 otto group 找到更多详细的材料
- 本文主要关注代码的实现,具体细节和基本概念不会详细展开
让我们开始吧
就像以前说过的那样,处理一个问题主要分为三个部分:数据准备,模型构建和模型优化
导入模块
这里遇到了新的模块
- StandardScaler 用于归一化,感觉很好使。详见StandardScaler
- LabelEncoder 配合np_utils用于One-hot编码,详见LabelEncoder。注意和OneHotEncoder的区别。
- EarlyStopping 当监测值不再改善时,该回调函数将中止训练。详见EarlyStopping
- ModelCheckpoint 保存模型。详见ModelCheckpoint
import numpy as npimport pandas as pdfrom sklearn.preprocessing import StandardScalerfrom sklearn.preprocessing import LabelEncoderfrom keras.utils import np_utilsfrom keras.models import Sequentialfrom keras.layers.core import Dense, Activation, Dropoutfrom keras.callbacks import EarlyStopping, ModelCheckpoint
## 数据准备读取数据。数据可以在 [otto group](https://www.kaggle.com/c/otto-group-product-classification-challenge/data) 找到train_path = './data/train.csv'test_path = './data/test.csv'df = pd.read_csv(train_path)
观察数据。有93个特征,最后一列是种类,第一列的id对于训练没有任何作用。df.head()
5 rows × 95 columns
导入数据。
- 第一列id对训练没用,所以我们不需要它。
- train 和 test 两个文件有所区别(test中没有给出target)
def load_data(path, train=True): df = pd.read_csv(path) X = df.values.copy() if train: np.random.shuffle(X) X, label = X[:, 1:-1].astype(np.float32), X[:, -1] return X, label else: X, ids = X[:, 1:].astype(np.float32), X[:, 0].astype(str) return X, ids
X_train, y_train = load_data(train_path)X_test, ids = load_data(test_path, train=False)
预处理,训练数据和测试数据一起归一化,以免忘记了
def preprocess_data(X, scaler=None): if not scaler: scaler = StandardScaler() scaler.fit(X) X = scaler.transform(X) return X, scaler
X_train, scaler = preprocess_data(X_train)X_test, _ = preprocess_data(X_test, scaler)
One-hot 编码
def preprocess_label(labels, encoder=None, categorical=True): if not encoder: encoder = LabelEncoder() encoder.fit(labels) y = encoder.transform(labels).astype(np.int32) if categorical: y = np_utils.to_categorical(y) return y, encoder
y_train, encoder = preprocess_label(y_train)
搭建网络模型
dim = X_train.shape[1]print(dim, 'dims')print('Building model')nb_classes = y_train.shape[1]model = Sequential()model.add(Dense(256, input_shape=(dim, )))model.add(Activation('relu'))model.add(Dropout(0.5))model.add(Dense(128))model.add(Activation('relu'))model.add(Dropout(0.5))model.add(Dense(64))model.add(Activation('relu'))model.add(Dropout(0.5))model.add(Dense(nb_classes))model.add(Activation('softmax'))
93 dimsBuilding model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
batch_size = 128epochs = 2
训练,同时保持最佳模型
fBestModel = 'best_model.h5'early_stop = EarlyStopping(monitor='val_acc', patience=5, verbose=1)best_model = ModelCheckpoint(fBestModel, verbose=0, save_best_only=True)model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size, verbose=1, validation_split=0.1, callbacks=[best_model, early_stop])
Train on 55690 samples, validate on 6188 samplesEpoch 1/255690/55690 [==============================] - 2s 42us/step - loss: 0.5256 - acc: 0.7967 - val_loss: 0.5268 - val_acc: 0.7982Epoch 2/255690/55690 [==============================] - 2s 42us/step - loss: 0.5251 - acc: 0.7991 - val_loss: 0.5256 - val_acc: 0.8017<keras.callbacks.History at 0x13551adaef0>
预测并保存结果。将结果保存为Kaggle上要求的格式,然后提交了测试结果,得到了0.5左右的分数,据说大概前50%左右
prediction = model.predict(X_test)
num_pre = prediction.shape[0]columns = ['Class_'+str(post+1) for post in range(9)]df2 = pd.DataFrame({'id' : range(1,num_pre+1)})df3 = pd.DataFrame(prediction, columns=columns)df_pre = pd.concat([df2, df3], axis=1)
df_pre.to_csv('predition.csv', index=False)
阅读全文
0 0
- Keras-3 Keras With Otto Group
- keras
- keras
- keras
- Keras
- keras
- Keras
- keras
- Keras with R (MLP)
- Keras with R (CNN)
- Keras-4 mnist With CNN
- Keras-2 Keras Mnist
- 【Keras】Keras学习框架
- keras:3)Embedding层详解
- 3D CNN in Keras
- Keras 常见问题
- keras浅学
- keras中文翻译
- HDU2036
- windows主机与虚拟机Linux共享文件夹
- ReactNative安卓端的打包发布
- leetcode729. My Calendar I
- 简单电路
- Keras-3 Keras With Otto Group
- jenkins的具体搭建和使用—使用tomcat容器
- 【练习题】构造方法 编写Java程序,模拟简单的计算器。
- Linux中GDB调试
- 三种继承、多态-虚函数
- 分享一篇百度云续命的大法
- struts2的初学
- Debugging Malloc Lab: Detecting Memory-Related Errors解答
- re模块