词向量源码解析:(5.7)ngram2vec源码解析之counts2shuf等

来源:互联网 发布:ask软件 编辑:程序博客网 时间:2024/06/05 08:08

GloVe模型需要的是打乱的counts(共现矩阵)。counts2shuf会把counts打乱。此外,目前的GloVe模型接受的是二进制格式的输入。counts2bin的功能就是把python生成的文本格式的共现矩阵变成C语言可以直接读的二进制形式。首先看看counts2shuf。这个和GloVe中的是一模一样的,用python重写了一遍。shuffle的过程分成两个阶段。第一阶段是局部打乱。

    #shuffle round 1
    memory_size = float(args['--memory_size']) * 1000**3//把三元组尽可能多的读入内存然后打乱
    counts = []
    counts_num_per_file = []
    tmp_id = 0
    with open(args['<counts>'], 'r') as f:
        counts_num = 0
        print str(counts_num/1000**2) + "M counts processed."
        for line in f://对counts文件中的所有三元组进行循环
            print "\x1b[1A" + str(counts_num/1000**2) + "M counts processed."
            counts_num += 1
            word, context, count = line.strip().split()
            counts.append((int(word), int(context), float(count)))
            if getsizeof(counts) + (getsizeof((int(0),int(0),float(0))) + getsizeof(int(0)) * 2 + getsizeof(float(0)) ) * len(counts) > memory_size://内存大小等于列表数据结构占用内存大小加上元组数据结构占用内存大小加上元组中的内容占用的内存大小
                random.shuffle(counts)//打乱列表中的元组
                with open(args['<output>'] + str(tmp_id), 'w') as f://写出
                    for count in counts:
                        f.write(str(count[0]) + ' ' + str(count[1]) + ' ' + str(count[2]) + '\n')
                counts_num_per_file.append(counts_num)//记录一下写出的个数
                counts = []
                tmp_id += 1


    random.shuffle(counts)//把最后剩在内存中的三元组打乱写出
    with open(args['<output>'] + str(tmp_id), 'w') as f:
        for count in counts:
            f.write(str(count[0]) + ' ' + str(count[1]) + ' ' + str(count[2]) + '\n')
        counts = []
        tmp_id += 1


    print "number of tmpfiles: ", tmp_id 

第二阶段,对局部打乱的临时文件中的三元组再一次打乱。从每个文件中读取一部分数据,打乱,最后写出到总的打乱的counts文件中。

    counts_num = 0
    output_file = open(args['<output>'], 'w')
    tmpfiles = []
    for i in xrange(tmp_id)://打开临时文件
        tmpfiles.append(open(args['<output>'] + str(i), 'r'))
    
    tmp_num = counts_num_per_file[0] / tmp_id
    print str(counts_num/1000**2) + "M counts processed."
    for i in xrange(tmp_id - 1)://大概需要从临时文件中读取tmp_id次,也就是在内存中读入打乱写出tmp-1次
        counts = []
        for f in tmpfiles://遍历每个文件
            for j in xrange(tmp_num)://从每个文件中读入一定量的三元组
                line = f.readline()
                if len(line) > 0:
                    print "\x1b[1A" + str(counts_num/1000**2) + "M counts processed."
                    counts_num += 1
                    word, context, count = line.strip().split()
                    counts.append((int(word), int(context), float(count)))
        random.shuffle(counts)//打乱内存中的三元组
        for count in counts://写出
            output_file.write(str(count[0]) + ' ' + str(count[1]) + ' ' + str(count[2]) + '\n')
    counts = []
    for f in tmpfiles://每个临时文件中的三元组数量不一定都一样,所以最后把剩在临时文件中的三元组都读入内存打乱然后写出到最后的打乱的counts中
        for line in f:
            print "\x1b[1A" + str(counts_num/1000**2) + "M counts processed."
            counts_num += 1
            word, context, count = line.strip().split()
            counts.append((int(word), int(context), float(count)))
    random.shuffle(counts)//打乱
    for count in counts://写出
        output_file.write(str(count[0]) + ' ' + str(count[1]) + ' ' + str(count[2]) + '\n')


    for i in xrange(tmp_id):
        tmpfiles[i].close()
    for i in xrange(tmp_id)://删除临时文件
        os.remove(args['<output>'] + str(i))
    output_file.close()    
    print "number of counts: ", counts_num

下面看一下counts2bin,代码就几行,把共现矩阵写成C语言可读的二进制形式(字节流)。glovef和GloVe都是C语言版本的,接受二进制的共现矩阵

def main():
    args = docopt("""
    Usage:
        counts2bin.py <counts> <output>


    """)
    
    print "**********************"
    print "counts2bin"


    bin_file = open(args['<output>'], 'wb')
    with open(args['<counts>'], 'r') as f:
        counts_num = 0
        print str(counts_num/1000**2) + "M tokens processed."
        for line in f:
            print "\x1b[1A" + str(counts_num/1000**2) + "M tokens processed."
            counts_num += 1
            word, context, count = line.strip().split()
            b = struct.pack('iid', int(word), int(context), float(count))//把三元组打包成字节流,写到文件,glovef(GloVe)可以直接读取
            bin_file.write(b)
    print "number of counts: " + str(counts_num)
    bin_file.close()

阅读全文
0 0
原创粉丝点击