xgboost和lightgbm学习

来源:互联网 发布:海关信息进出口数据 编辑:程序博客网 时间:2024/05/17 09:13

demo basic_walkthrough.py
加载输入

dtrain = xgb.DMatrix('../data/agaricus.txt.train')dtest = xgb.DMatrix('../data/agaricus.txt.test')## 其中dtest就包括了lable和数据,lable可以通过labels = dtest.get_label()获取。 dtrain = xgb.DMatrix(csr, label = labels)# csr可以是'scipy.sparse.csr.csr_matrix‘# labels可以是list类型csc = scipy.sparse.csc_matrix((dat, (row,col)))或csc矩阵

参数设置,具体参数什么意思参考传送门

param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }

训练,直接导入参数,训练数据,训练轮数和watchlist

watchlist  = [(dtest,'eval'), (dtrain,'train')] # 'eval'是watchlist的名字num_round = 2bst = xgb.train(param, dtrain, num_round, watchlist)

预测

preds = bst.predict(dtest) # 输出概率是[sample,1],每个样本输出一个概率值

保存模型和保存数据

# save dmatrix into binary bufferdtest.save_binary('dtest.buffer')# save modelbst.save_model('xgb.model')# load model and data inbst2 = xgb.Booster(model_file='xgb.model')dtest2 = xgb.DMatrix('dtest.buffer')preds2 = bst2.predict(dtest2)

binary:logistic 就是错误等误差eval_metric[默认值取决于objective参数的取值]

原创粉丝点击