Tensorflow之Basic word2vec代码详解(上)

来源:互联网 发布:javascript字符串替换 编辑:程序博客网 时间:2024/06/09 18:33

Tensorflow上关于Vector Representations of Words里给出了word2vec两个源代码,本文解析基础的代码,代码地址为:https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/tutorials/word2vec/word2vec_basic.py

上篇为代码step1-3:讲解数据下载处理,与训练数据的生成。

from __future__ import absolute_import  from __future__ import division  from __future__ import print_function    import collections  import math  import os  import random  import zipfile    import numpy as np  from six.moves import urllib  from six.moves import xrange  # pylint: disable=redefined-builtin  import tensorflow as tf    # Step 1: 下载数据  url = 'http://mattmahoney.net/dc/'      def maybe_download(filename, expected_bytes):    """Download a file if not present, and make sure it's the right size."""    if not os.path.exists(filename):      filename, _ = urllib.request.urlretrieve(url + filename, filename)    statinfo = os.stat(filename)    if statinfo.st_size == expected_bytes:      print('Found and verified', filename)    else:      print(statinfo.st_size)      raise Exception(          'Failed to verify ' + filename + '. Can you get to it with a browser?')    return filename  #  filename = maybe_download('text8.zip', 31344016)      # 解压缩并读取数据转化到数组中.  def read_data(filename):    """Extract the first file enclosed in a zip file as a list of words."""    with zipfile.ZipFile(filename) as f:      data = tf.compat.as_str(f.read(f.namelist()[0])).split() #split(分割成序列)   return data    vocabulary = read_data(filename)  print('Data size', len(vocabulary))  #建立字典  # Step 2: Build the dictionary and replace rare words with UNK token.  vocabulary_size = 50000  def build_dataset(words, n_words):    """Process raw inputs into a dataset."""    count = [['UNK', -1]]    count.extend(collections.Counter(words).most_common(n_words - 1))#计数,取词频前50000个词,其余的为unk,    dictionary = dict()    for word, _ in count:      dictionary[word] = len(dictionary)#高频词排序给编号    data = list()    unk_count = 0    for word in words:      if word in dictionary:        index = dictionary[word]#给高频词一个索引      else:        index = 0  # 低频词索引为0        unk_count += 1  #统计低频词的个数    data.append(index)    count[0][1] = unk_count    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys())) #逆词汇,键和值与dictionary相反   return data, count, dictionary, reversed_dictionary    data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,                                                              vocabulary_size)  del vocabulary  # Hint to reduce memory.  print('Most common words (+UNK)', count[:5])  print('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]])  #如vocabulary(daefbmc……)其中:1.a词频:600;2.b词频:500;3.c词频:400;4.d词频:300;5.e词频:200;6.f词频:100;……unk:4148#count([UNK,4148],[a,600],[b,500],[c,400],[d,300],[e,200,[f,100])#dictionary([a;1],[b:2],[c:3],[d:4],[e;5],[f;6])#data:{4,1,5,6,2,0,3}#reversed_dictionary([1;a],[2:b],[3:c],[4:d],[5;e],[6;f])#生成训练数据#从文本总体的第二次开始,每个单词一次作为输入,输出可以是上下文范围内的单词中的任何一个(一般不是取全部而是随机抽取其中几组,增加随机性)  data_index = 0   # Step 3: Function to generate a training batch for the skip-gram model.  def generate_batch(batch_size, num_skips, skip_window):  #batch_size:每次训练的词长度;num_skips:每个输入词重复的次数(一个输入产生多少个标签数据),skip_window:向左右取多少词  global data_index  #global:全局变量  assert batch_size % num_skips == 0  #assert:断言  assert num_skips <= 2 * skip_window      batch = np.ndarray(shape=(batch_size), dtype=np.int32)  #一维  labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)  #二维  span = 2 * skip_window + 1  # 左右两边各取skip_window,一共span个   buffer = collections.deque(maxlen=span)  #防止超长,挤出前面的数据,确保span个训练数据  if data_index + span > len(data):  #依次取span个词    data_index = 0    buffer.extend(data[data_index:data_index + span])    data_index += span    for i in range(batch_size // num_skips):      target = skip_window  #butter[skip_window]为输入数据    targets_to_avoid = [skip_window] #去除输入词自己本身     for j in range(num_skips):  #输入词重复num_skips次      while target in targets_to_avoid:          target = random.randint(0, span - 1) #随机生成 (0, span - 1)之间整数      targets_to_avoid.append(target)        batch[i * num_skips + j] = buffer[skip_window] #训练输入的序列       labels[i * num_skips + j, 0] = buffer[target]  #训练输出的序列(标签,对应词频的排序)    if data_index == len(data):  #超长时回到开始      buffer[:] = data[:span]        data_index = span      else:        buffer.append(data[data_index]) #挤掉开始几个,换一组词训练       data_index += 1    # Backtrack a little bit to avoid skipping words in the end of a batch    data_index = (data_index + len(data) - span) % len(data)    return batch, labels    batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1)  for i in range(8):    print(batch[i], reverse_dictionary[batch[i]],          '->', labels[i, 0], reverse_dictionary[labels[i, 0]])   #如:vocabulary(m|daefbm|c……)取batch_size=6,num_skips=2,skip_window=1#batch = [4,4,1,1,5,5]#labels= [0,1,4,5,1,6]



原创粉丝点击