A Simple Application of MiniBatchKMeans


MiniBatchKMeans is much faster than KMeans, and the results are still good. Here it is applied to text clustering (a Python 2 script; note the ur'' regex literal and six.moves.xrange):

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function

import logging
import os
import re
from collections import defaultdict
from time import time

import jieba
from gensim.utils import to_utf8
from six.moves import xrange
from sklearn.cluster import MiniBatchKMeans
from sklearn.feature_extraction.text import TfidfVectorizer

logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)


def load_stopwords():
    # path = '/Users/fhqplzj/github/HanLP/data/dictionary/stopwords.txt'
    path = '/data/zhaojun/local_projects/stopwords.txt'
    return frozenset(open(path, 'rb').read().decode('utf-8').splitlines())


# stop word set
stopwords = load_stopwords()
chinese = re.compile(ur'^[\u4e00-\u9fa5]+$')


def chinese_non_stopwords(word):
    # keep only tokens that are entirely Chinese characters and not stopwords
    result = True if re.match(chinese, word) else False
    return result and word not in stopwords


def sentence_tokenizer(sentence):
    # tokenize with jieba, then filter
    try:
        content = sentence.strip().split('\t', 1)[1]
    except IndexError:
        # placeholder token for lines without a tab-separated content field
        content = u'呢'
    return filter(chinese_non_stopwords, jieba.lcut(content))


def load_documents(path):
    # path = '/Users/fhqplzj/Downloads/part-' + path
    path = '/data/zhaojun/part100/part-' + path
    logger.info('processing file: %s' % path)
    return open(path, 'rb').read().decode('utf-8').splitlines()


if __name__ == '__main__':
    file_names = map(lambda i: '{:05d}'.format(i), xrange(100))
    docs = []
    for file_name in file_names:
        docs.extend(load_documents(file_name))

    t0 = time()
    logger.info('TfidfVectorizer...')
    vectorizer = TfidfVectorizer(tokenizer=sentence_tokenizer, min_df=5, max_df=0.1)
    X = vectorizer.fit_transform(docs)
    logger.info('vectorizer: %fs' % (time() - t0))

    t0 = time()
    logger.info('MiniBatchKMeans...')
    km = MiniBatchKMeans(n_clusters=100, batch_size=1000)
    km.fit(X)
    logger.info('kmeans: %fs' % (time() - t0))

    t0 = time()
    logger.info('collecting result')
    pred_labels = km.labels_
    result = defaultdict(list)
    for idx in xrange(len(pred_labels)):
        result[pred_labels[idx]].append(docs[idx])
    for k in result:
        name = 'res-{:05d}'.format(k)
        elems = result[k]
        out_path = os.path.join('/tmp/cluster', name)
        with open(out_path, 'w') as fout:
            logger.info('writing %s' % out_path)
            for elem in elems:
                fout.write(to_utf8(elem) + '\n')
    logger.info('finished: %fs' % (time() - t0))

    # sorted_indices = km.cluster_centers_.argsort()[:, ::-1]
    # id2words = vectorizer.get_feature_names()
    # for i in range(km.n_clusters):
    #     print('cluster: %i' % i)
    #     for idx in sorted_indices[i, :10]:
    #         print(' %s' % id2words[idx])
    #     print()
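The commented-out block at the end of the script is a quick way to eyeball cluster quality: print the highest-weighted terms of each centroid. Here is a minimal standalone sketch of that idea, assuming km and vectorizer from the script above are still in scope (note that get_feature_names() was deprecated in scikit-learn 1.0 and removed in 1.2 in favor of get_feature_names_out()):

# assumes `km` and `vectorizer` from the script above are in scope
sorted_indices = km.cluster_centers_.argsort()[:, ::-1]  # term indices by descending centroid weight
terms = vectorizer.get_feature_names()  # get_feature_names_out() on scikit-learn >= 1.0

for i in range(km.n_clusters):
    top_terms = [terms[idx] for idx in sorted_indices[i, :10]]
    print('cluster %d: %s' % (i, ' '.join(top_terms)))

If the top ten terms of a cluster read like a coherent topic, the clustering is usually doing its job.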
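How much faster is "much faster"? The following timing sketch gives a feel for the gap. It is not part of the original script: make_blobs is only a synthetic stand-in for the sparse TF-IDF matrix above, and the sample counts and timings are illustrative, depending on your machine and scikit-learn version.

from __future__ import print_function

from time import time

from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.datasets import make_blobs

# synthetic stand-in for the corpus: 20k samples, 50 dims, 100 clusters
X, _ = make_blobs(n_samples=20000, n_features=50, centers=100, random_state=42)

t0 = time()
KMeans(n_clusters=100).fit(X)  # full-batch iterations over all samples
print('KMeans:          %.2fs' % (time() - t0))

t0 = time()
MiniBatchKMeans(n_clusters=100, batch_size=1000).fit(X)  # centers updated from 1000-sample mini-batches
print('MiniBatchKMeans: %.2fs' % (time() - t0))

The mini-batch variant trades slightly worse inertia (cluster tightness) for a large speedup, which is usually a good deal on corpora with tens of thousands of documents or more.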

