机器学习在线学习算法--迭代器实现

来源:互联网 发布:域名过户流程 编辑:程序博客网 时间:2024/06/04 23:24

在线学习时,我们不会一次性得到所有要训练的数据。数据会随时间而更新。对于这种情况,我们都是先训练已有数据,然后再训练不断得到的数据。类似的做法是,构建一个迭代器,用于每次训练一部分数据,直到所有数据都训练完。(真正的在线学习算法的数据是不会训练完的,会一直更新)

由于每次训练的数据不一样,会影响到分类器的准确度,也就是说,可能会影响到分类器的性能好坏。

import numpy as npimport matplotlib.pyplot as plt%matplotlib inlineimport pandas as pdfrom sklearn import cluster, datasetsfrom sklearn import metricsimport randomnp.random.seed(0)#传输进来的data,target是np.arraydef iter_minibatches(data,target,minibatch):    '''    迭代器    给定文件流(比如一个大文件),每次输出minibatch大小的数据    将输出转化成pandas输出,返回X, y    '''    X = []    y = []    cur_line_num = 0    if type(data)!=pd.core.frame.DataFrame:  #将输入转换为pandas        data=pd.DataFrame(data)        target=pd.DataFrame(target)    try:        b=list(data.index)        np.random.shuffle(b)        data=data.iloc[b,:]        target=target.iloc[b,:]    except:        print(b,data.shape,target.shape)    length=len(data)    start=0    while length-minibatch>0:        X=data.iloc[start:start+minibatch,:]        y=target.iloc[start:start+minibatch]        yield X,y        start=start+minibatch        length-=minibatch

调用方式

from sklearn.datasets import load_irisdata = load_iris()dataset=data['data']target=data['target']minibatch_train_iterators = iter_minibatches(dataset,target,10)for j, (X_train, y_train) in enumerate(minibatch_train_iterators):    print(X_train.shape,y_train.shape)
原创粉丝点击