逻辑回归实战 — Kaggle_Titanic
来源:互联网 发布:php优化方案 编辑:程序博客网 时间:2024/06/11 22:04
数据来源:https://www.kaggle.com/c/titanic
Training
import time

import matplotlib.pyplot as plt
import numpy
import pandas

# Notebook-only setup: ask IPython to render figures inline. Outside of
# IPython/Jupyter `get_ipython` is undefined, so the request is skipped and
# the file stays importable as a plain Python module.
try:
    get_ipython().run_line_magic('matplotlib', 'inline')
except NameError:
    pass


def _min_max_scale(column):
    """Rescale a numeric pandas Series to [0, 1]; a constant column maps to 0."""
    low = column.min()
    span = column.max() - low
    if span == 0:
        # Every value is identical: avoid dividing by zero.
        return column * 0.0
    return (column - low) / span


def prepareData(filename):
    """Load a Titanic CSV file and turn it into numeric model features.

    Steps performed on the loaded frame:

    * ``Sex`` is encoded as female=0, male=1.
    * ``Embarked`` is encoded as S=1, C=2, Q=3; missing or unknown ports are
      replaced by the most frequent known port (0 if no port is known).
    * ``Age`` and ``Fare`` have missing values replaced by the mean of the
      known values and are then min-max scaled to [0, 1].
    * A constant column ``ones`` is inserted first (the intercept term).

    Identifier/text columns (``PassengerId``, ``Name``, ``Ticket``,
    ``Cabin``) are left in place; callers drop them before building the
    feature matrix.

    Args:
        filename: Path of the CSV file to read.

    Returns:
        pandas.DataFrame with the transformed columns.
    """
    data = pandas.read_csv(filename)

    data['Sex'] = data['Sex'].map({'female': 0, 'male': 1})

    embarked = data['Embarked'].map({'S': 1, 'C': 2, 'Q': 3})
    known_ports = embarked.dropna()
    # The mode is computed once, and only over known ports, so the
    # "missing" marker can never be chosen as the replacement value.
    most_common_port = known_ports.mode().iloc[0] if not known_ports.empty else 0
    data['Embarked'] = embarked.fillna(most_common_port)

    # Series.mean() skips NaN, so missing ages no longer drag the average
    # towards zero before it is used as the fill value.
    data['Age'] = _min_max_scale(data['Age'].fillna(data['Age'].mean()))
    # Fare can be missing as well; fill it so no NaN reaches the features.
    data['Fare'] = _min_max_scale(data['Fare'].fillna(data['Fare'].mean()))

    data.insert(0, 'ones', 1)
    return data
def run(X, Y, theta, alpha, steps):
    """Run batch gradient descent and log every iteration to a CSV file.

    The log is written to ``titanic/model.txt``: a header row
    ``theta_0,...,theta_k,cost`` followed by one row per step holding the
    weights and the cost after that step.

    Args:
        X: Feature rows, passed through to ``getCost`` / ``getGradient``.
        Y: Labels, passed through to ``getCost`` / ``getGradient``.
        theta: Starting weights. Updated with ``-=``, so a NumPy array
            passed in is modified in place.
        alpha: Learning rate multiplied with the gradient.
        steps: Number of gradient-descent updates to perform.

    Returns:
        Tuple ``(costs, theta, time_spent)``: the cost before the first
        update followed by the cost after each step, the final weights, and
        the elapsed wall-clock seconds.
    """
    started_at = time.time()
    costs = [getCost(X, Y, theta)]

    with open('titanic/model.txt', 'w') as log:
        header = ''.join('theta_' + str(index) + ',' for index in range(len(theta)))
        log.write(header + 'cost\n')

        completed = 0
        while completed < steps:
            theta -= alpha * getGradient(X, Y, theta)
            current_cost = getCost(X, Y, theta)
            costs.append(current_cost)

            weights_csv = ''.join(str(weight) + ',' for weight in theta)
            log.write(weights_csv + str(current_cost) + '\n')
            completed += 1

    elapsed = time.time() - started_at
    return costs, theta, elapsed
def getGradient(X, Y, theta):
    """Return the gradient of the mean logistic-regression loss.

    Computes ``X.T @ (sigmoid(X @ theta) - Y) / len(Y)`` in one vectorized
    pass, so each sample's prediction is evaluated once instead of once per
    weight.

    Args:
        X: 2-D array-like of feature rows, one row per sample.
        Y: 1-D array-like of 0/1 labels, one per row of ``X``.
        theta: 1-D array-like of weights, one per feature column.

    Returns:
        numpy.ndarray with one partial derivative per entry of ``theta``.

    Raises:
        ValueError: If ``Y`` is empty (the mean would be undefined).
    """
    features = numpy.asarray(X, dtype=float)
    labels = numpy.asarray(Y, dtype=float)
    weights = numpy.asarray(theta, dtype=float)
    if labels.size == 0:
        raise ValueError('getGradient requires at least one sample')

    scores = numpy.dot(features, weights)
    # sigmoid(z) written as 0.5 * (1 + tanh(z / 2)): same value, but it
    # saturates cleanly instead of overflowing exp() for large |z|.
    predicted = 0.5 * (1.0 + numpy.tanh(0.5 * scores))
    return numpy.dot(features.T, predicted - labels) / labels.size
def getCost(X, Y, theta):
    """Return the mean negative log-likelihood of a logistic model.

    For each sample with score ``z = theta . x`` the loss is
    ``log(1 + exp(z)) - y * z``; the result is the average over all samples.

    Args:
        X: 2-D array-like of feature rows, one row per sample.
        Y: 1-D array-like of 0/1 labels, one per row of ``X``.
        theta: 1-D array-like of weights, one per feature column.

    Returns:
        The mean loss as a NumPy float.

    Raises:
        ValueError: If ``Y`` is empty (the mean would be undefined).
    """
    features = numpy.asarray(X, dtype=float)
    labels = numpy.asarray(Y, dtype=float)
    if labels.size == 0:
        raise ValueError('getCost requires at least one sample')

    scores = numpy.dot(features, numpy.asarray(theta, dtype=float))
    # logaddexp(0, z) == log(1 + exp(z)) but never overflows, so a large
    # score yields a finite loss instead of inf.
    return numpy.mean(numpy.logaddexp(0.0, scores) - labels * scores)
def getAccuracy(train_X, train_Y, theta):
    """Return the fraction of rows whose thresholded prediction equals the label.

    Each row is scored with the logistic function of ``theta . x`` and
    classified as 1 when that probability is at least 0.5, otherwise 0.

    Args:
        train_X: Iterable of feature rows.
        train_Y: Iterable of expected labels, aligned with ``train_X``.
        theta: Weight vector applied to every row.

    Returns:
        Number of matching predictions divided by the number of rows, as a
        float.
    """
    predictions = [
        1 if 1 / (1 + numpy.exp(-numpy.dot(theta, row))) >= 0.5 else 0
        for row in train_X
    ]
    hits = float(sum(1 for guess, truth in zip(predictions, train_Y) if guess == truth))
    return hits / len(predictions)
# --- Training: fit the weights on the training set and plot the cost curve ---
train_data = prepareData('titanic/train.csv')
train_data.head(5)

# Drop the identifier/text columns once, then split label and features.
train_features = train_data.drop(['PassengerId', 'Name', 'Ticket', 'Cabin'], axis=1)
train_Y = train_features['Survived'].values
train_X = train_features.drop(['Survived'], axis=1).values

# One random starting weight per feature column (taken from the matrix
# shape rather than from the length of one particular row).
theta = numpy.random.random(train_X.shape[1])
alpha = 0.001
steps = 10000
costs, theta, time_spent = run(train_X, train_Y, theta, alpha, steps)
accuracy = getAccuracy(train_X, train_Y, theta)

fig = plt.figure(figsize=(18, 5))

# Left panel: cost over the whole run (steps + 1 points, including the
# cost measured before the first update).
ax1 = fig.add_subplot(121)
ax1.plot(range(steps + 1), costs)
ax1.set_title('Logistic Regression for Titanic Problem -- Time spent: %f\nAccuracy: %f' % (time_spent, accuracy))
ax1.set_xlabel('steps')
ax1.set_ylabel('cost')

# Right panel: zoom on the last 1000 points, including the final step.
ax2 = fig.add_subplot(122)
ax2.plot(range(steps + 1)[-1000:], costs[-1000:])
ax2.set_xlabel('steps')
ax2.set_ylabel('cost')
Testing
# --- Testing: predict survival for the test set and write the results file ---
test_data = prepareData('titanic/test.csv')
test_data.head(5)

# Same column drop as for training, so the remaining columns line up with theta.
test_X = test_data.drop(['PassengerId', 'Name', 'Ticket', 'Cabin'], axis=1).values

# Classify each passenger: logistic probability >= 0.5 means survived (1).
# NOTE: a row containing NaN produces a NaN probability, which fails the
# comparison and is therefore predicted as 0.
Y_hat = []
for x in test_X:
    y_hat = 1 / (1 + numpy.exp(-numpy.dot(theta, x)))
    Y_hat.append(1 if y_hat >= 0.5 else 0)

results = pandas.DataFrame(Y_hat, columns=['Survived'])
results.insert(0, 'PassengerId', test_data['PassengerId'])
# index=False keeps the file to exactly the PassengerId and Survived
# columns; otherwise pandas prepends the row index as an unnamed column.
results.to_csv('titanic/results.csv', index=False)
阅读全文
0 0
- 逻辑回归实战 — Kaggle_Titanic
- 逻辑回归实战 — Kaggle_Titanic 2
- 机器学习实战—逻辑回归
- 机器学习实战-逻辑回归
- 线性回归与逻辑回归实战
- 机器学习实战(四)——logisticRegression逻辑回归
- PYTHON机器学习实战——逻辑回归
- 机器学习理论与实战:逻辑回归
- Python机器学习实战之逻辑回归
- 机器学习实战【4】(逻辑回归)
- python之实战----逻辑回归战iris
- MXnet代码实战之多类逻辑回归
- 机器学习实战(5)逻辑回归
- 逻辑回归 — Logistic Regression
- 机器学习理论与实战(四)逻辑回归
- 机器学习理论与实战(四)逻辑回归
- 机器学习实战逻辑回归的java实现
- 机器学习实战3:逻辑logistic回归:病马实例
- 人工智能入门
- 实验6:图的实验1——图的邻接矩阵存储实现
- php-app开发接口加密范例
- 绝对中位差Median Absolute Deviation
- 提高你开发效率的十五个 Visual Studio 使用技巧
- 逻辑回归实战 — Kaggle_Titanic
- mysql的sql注入介绍
- 下载Eclipse,以及安装
- 浅谈C++类
- eclipse :SVN E175002报错解决
- 进程间通信(IPC)
- 深入了解java虚拟机!
- java面试题全集(下)
- 建议3:三元操作符的类型务必一致