tf.nn.embedding_lookup中关于partition_strategy参数详解

来源:互联网 发布:网络随身听收音机 编辑:程序博客网 时间:2024/04/28 03:51

  • tfnnembedding_lookup
    • 数学上的原理
    • API介绍
    • 简单示例
      • 程序
      • 注解
    • partition_strategy参数的示例
      • mod案例1
      • mod案例2
      • div案例1
      • div案例2
  • 参考资料

tf.nn.embedding_lookup

embedding_lookup常用于NLP中将one-hot编码转换我对应的向量编码。

数学上的原理

数学上的原理

假设一共有m个物体,每个物体有自己唯一的id,那么从物体的集合到Rm有一个trivial的嵌入,就是把它映射到Rm中的标准基,这种嵌入叫做One-hot embedding/encoding.

应用中一般将物体嵌入到一个低维空间Rn(nm),只需要再compose上一个从RmRn的线性映射就好了。每一个n×m的矩阵M都定义了RmRn的一个线性映射: xMx。当x是一个标准基向量的时候,Mx对应矩阵M中的一列,这就是对应id的向量表示。这个概念用神经网络图来表示如下:

这里写图片描述

从id(索引)找到对应的One-hot encoding,然后红色的weight就直接对应了输出节点的值(注意这里没有activation function),也就是对应的embedding向量。


API介绍

API介绍

依据inputs_ids来寻找embedding_params中对应的元素.

 embedding_lookup(     params,   # embedding_params 对应的转换向量     ids,      # inputs_ids,标记着要查询的id     partition_strategy='mod',   #分割方式      name=None,     validate_indices=True, # deprecated     max_norm=None )
参数 description 注解 params A single tensor representing the complete embedding tensor, or a list of P tensors all of same shape except for the first dimension, representing sharded embedding tensors. Alternatively, a PartitionedVariable, created by partitioning along dimension 0. Each element must be appropriately sized for the given partition_strategy. params是由一个tensor或者多个tensor组成的列表(多个tensor组成时,每个tensor除了第一个维度其他维度需相等) ids A Tensor with type int32 or int64 containing the ids to be looked up in params. ids是一个整型的tensor,ids的每个元素代表要在params中取的每个元素的第0维的逻辑index. partition_strategy A string specifying the partitioning strategy, relevant if len(params) > 1. Currently “div” and “mod” are supported. Default is “mod”. 逻辑index是由partition_strategy指定,partition_strategy用来设定ids的切分方式,目前有两种切分方式’div’和’mod’. 返回值 The results of the lookup are concatenated into a dense tensor. The returned tensor has shape shape(ids) + shape(params)[1:]. 返回值是一个dense tensor.返回的shape为shape(ids)+shape(params)[1:]

embedding_lookup中的partition_strategy参数比较难理解(this function is hard to understand, until you get the point!),下面会有特别的解释。


简单示例

简单示例

下面我们通过一个常见的案例来解释embedding_lookup的用法:

程序

# coding:utf8import tensorflow as tfimport numpy as npinput_ids = tf.placeholder(dtype=tf.int32, shape=[None])_input_ids = tf.placeholder(dtype=tf.int32, shape=[3, 2])embedding_param = tf.Variable(np.identity(8, dtype=np.int32))   # 生成一个8x8的单位矩阵input_embedding = tf.nn.embedding_lookup(embedding_param, input_ids)_input_embedding = tf.nn.embedding_lookup(embedding_param, _input_ids)sess = tf.InteractiveSession()sess.run(tf.global_variables_initializer())print('embedding:')print(embedding_param.eval())var1 = [1, 2, 6, 4, 2, 5, 7]print('\n var1:')print(var1)print('\nprojecting result:')print(sess.run(input_embedding, feed_dict={input_ids: var1}))var2 = [[1, 4], [6, 3], [2, 5]]print('\n _var2:')print(var2)print('\n _projecting result:')print(sess.run(_input_embedding, feed_dict={_input_ids: var2}))'''输出:embedding:[[1 0 0 0 0 0 0 0] [0 1 0 0 0 0 0 0] [0 0 1 0 0 0 0 0] [0 0 0 1 0 0 0 0] [0 0 0 0 1 0 0 0] [0 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 1]] var1:[1, 2, 6, 4, 2, 5, 7]projecting result:[[0 1 0 0 0 0 0 0] [0 0 1 0 0 0 0 0] [0 0 0 0 0 0 1 0] [0 0 0 0 1 0 0 0] [0 0 1 0 0 0 0 0] [0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 1]] _var2:[[1, 4], [6, 3], [2, 5]] _projecting result:[[[0 1 0 0 0 0 0 0]  [0 0 0 0 1 0 0 0]] [[0 0 0 0 0 0 1 0]  [0 0 0 1 0 0 0 0]] [[0 0 1 0 0 0 0 0]  [0 0 0 0 0 1 0 0]]]'''

注解

  • embedding_param参数是一个8*8的单位矩阵(这个这是由一个tensor构成的params,即len(params)=1,partition_strategy只在len(params)>1时才作用)。
embedding_param=          # embedding_param只由一个tensor组成  故len(embedding_param) = 1[[1 0 0 0 0 0 0 0] [0 1 0 0 0 0 0 0] [0 0 1 0 0 0 0 0] [0 0 0 1 0 0 0 0] [0 0 0 0 1 0 0 0] [0 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 1]]
  • 我们ids为var1,照着此id从embedding_param取对应的行元素.
    var1 = [1, 2, 6, 4, 2, 5, 7]    # 1即取第2行  --> [0 1 0 0 0 0 0 0]    # 2即取第3行  --> [0 0 1 0 0 0 0 0]    # etc.
  • 我们ids为var2,照着此id从embedding_param取对应的行元素
    var2 = [[1, 4], [6, 3], [2, 5]]    '''    [1, 4] 即取2,5行     [[0 1 0 0 0 0 0 0]     [0 0 0 0 1 0 0 0]]    后面同理    ''' 

partition_strategy参数的示例

关于partition_strategy参数的示例 api描述 注解 If len(params) > 1, each element id of ids is partitioned between the elements of params according to the partition_strategy. In all strategies, if the id space does not evenly divide the number of partitions, each of the first (max_id + 1) % len(params) partitions will be assigned one more id. 如果len(params) > 1,params的元素分割方式是依据partition_strategy的。如果分段不能整分的话,则前(max_id + 1) % len(params)多分一个id. If partition_strategy is “mod”, we assign each id to partition p = id % len(params). For instance, 13 ids are split across 5 partitions as: [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]] 例如,如果partition_strategy =’mod’.如果我们的params是由5个tensor组成,他们的第一个维度相加为13,则分割策略为[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]] If partition_strategy is “div”, we assign ids to partitions in a contiguous manner. In this case, 13 ids are split across 5 partitions as: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]] 例如,如果partition_strategy =’div’.如果我们的params是由5个tensor组成,他们的第一个维度相加为13,则分割策略为[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]

看api迷迷糊糊的,就看下面的四个例子,就会明白这个函数的操作方法了~


‘mod’案例1

   # coding:utf8   import tensorflow as tf   import numpy as np   def test_embedding_lookup():       a = np.arange(12).reshape(3, 4)       b = np.arange(12, 16).reshape(1, 4)       c = np.arange(16, 28).reshape(3, 4)       print(a)       print('\n')       print(b)       print('\n')       print(c)       print('\n')       a = tf.Variable(a)       b = tf.Variable(b)       c = tf.Variable(c)       t = tf.nn.embedding_lookup([a, b, c],           partition_strategy='mod', ids=[0, 3, 6, 1, 2, 5, 8])       init = tf.global_variables_initializer()       sess = tf.Session()       sess.run(init)       m = sess.run(t)       print(m)   test_embedding_lookup()
        '''        分析:         这里我们注意到params是由[a, b, c]这三个tensor组成。即len(params)=3,且a,b,c这三个tensor的第一维度分别为3,1,3。         在把这个三个tensor组合过程中,我们按照partition_strategy='mod'策略分割。即每个tensor的元素之间相差len(params).这里分割方式为[a, b, c]  == [[0,3,6], [1,4,7], [2,5,8]]           这里程序还不知道4和7是找不到对应的元素的,在获取元素时候会报错        a=[[ 0  1  2  3]     = [0, 3, 6]  -->  [0  1  2  3]  = 0           [ 4  5  6  7]                  -->  [4  5  6  7]  = 3           [ 8  9 10 11]]                 -->  [8  9 10 11]  = 6        b=[[12 13 14 15]]    = [1, 4, 7]  -->  [12 13 14 15] = 1                                          -->  运行时报错  = 4                                          -->  运行时报错  = 7        c = etc..              输出:        [[ 0  1  2  3]         [ 4  5  6  7]         [ 8  9 10 11]]        [[12 13 14 15]]        [[16 17 18 19]         [20 21 22 23]         [24 25 26 27]]        [[ 0  1  2  3]  # 0         [ 4  5  6  7]  # 3         [ 8  9 10 11]  # 6         [12 13 14 15]  # 1         [16 17 18 19]  # 2         [20 21 22 23]  # 5         [24 25 26 27]] # 8        '''

‘mod’案例2

   # coding:utf8   import tensorflow as tf   import numpy as np   def test_embedding_lookup():       a = np.arange(12).reshape(3, 4)       b = np.arange(12, 16).reshape(1, 4)       c = np.arange(16, 28).reshape(3, 4)       print(a)       print('\n')       print(b)       print('\n')       print(c)       print('\n')       a = tf.Variable(a)       b = tf.Variable(b)       c = tf.Variable(c)       t = tf.nn.embedding_lookup([a, c, b],           partition_strategy='mod', ids=[0, 3, 6, 1, 4, 7, 2])       init = tf.global_variables_initializer()       sess = tf.Session()       sess.run(init)       m = sess.run(t)       print(m)   test_embedding_lookup()
        '''        分析:         这里我们把params从[a, b, c]改为[a, c, b]这三个tensor组成。a,c,b这三个tensor的第一维度分别为3,3,1。         在把这个三个tensor组合过程中,依旧是每个tensor的元素之间相差len(params).这里分割方式为[a, c, b]  == [[0,3,6], [1,4,7], [2,5,8]]           这里程序还不知道4和7是找不到对应的元素的,在获取元素时候会报错        a=[[ 0  1  2  3]     = [0, 3, 6]  -->  [0  1  2  3]  = 0           [ 4  5  6  7]                  -->  [4  5  6  7]  = 3           [ 8  9 10 11]]                 -->  [8  9 10 11]  = 6        c=[[16 17 18 19]     = [1, 4, 7]  -->  [16 17 18 19]  = 1           [20 21 22 23]                  -->  [20 21 22 23]  = 4           [24 25 26 27]]                -->  [24 25 26 27]  = 7        b=[[12 13 14 15]]    = [2, 5, 8]  -->  [12 13 14 15] = 2                                          -->  运行时报错  = 5                                          -->  运行时报错  = 8              输出:        [[ 0  1  2  3]         [ 4  5  6  7]         [ 8  9 10 11]]        [[12 13 14 15]]        [[16 17 18 19]         [20 21 22 23]         [24 25 26 27]]        [[ 0  1  2  3]  # 0         [ 4  5  6  7]  # 3         [ 8  9 10 11]  # 6         [16 17 18 19]  # 1         [20 21 22 23]  # 4         [24 25 26 27]  # 7         [12 13 14 15]] # 2        '''

‘div’案例1

   # coding:utf8   import tensorflow as tf   import numpy as np   def test_embedding_lookup():       a = np.arange(12).reshape(3, 4)       b = np.arange(12, 16).reshape(1, 4)       c = np.arange(16, 28).reshape(3, 4)       print(a)       print('\n')       print(b)       print('\n')       print(c)       print('\n')       a = tf.Variable(a)       b = tf.Variable(b)       c = tf.Variable(c)       t = tf.nn.embedding_lookup([a, b, c],           partition_strategy='div', ids=[0, 1, 2, 3, 5, 6])       init = tf.global_variables_initializer()       sess = tf.Session()       sess.run(init)       m = sess.run(t)       print(m)   test_embedding_lookup()
        '''        分析:         这里我们把params依旧是[a, b, c],三个tensor的第一维度分别为3,1,3。         在把这个三个tensor组合过程中,这我们按照partition_strategy='div'策略分割。即每个tensor的元素之间相差1.如果不够等分的话,前面(max_id+1)%len(params)多分一个元素。这里一共7个元素,分为3组,即3、2、2分配。         这里分割方式为[a, b, c]  == [[0,1,2], [3,4], [5,6]]           这里程序还不知道4和7是找不到对应的元素的,在获取元素时候会报错        a=[[ 0  1  2  3]     = [0, 1, 2]  -->  [0  1  2  3]  = 0           [ 4  5  6  7]                  -->  [4  5  6  7]  = 1           [ 8  9 10 11]]                 -->  [8  9 10 11]  = 2        b=[[12 13 14 15]]    = [3, 4]  -->  [12 13 14 15] = 3                                          -->  运行时报错  = 4        c=[[16 17 18 19]     = [5, 6]  -->  [16 17 18 19]  = 5           [20 21 22 23]                  -->  [20 21 22 23]  = 6           [24 25 26 27]]                -->  [24 25 26 27]  = 这个是找不到的了           输出:        [[ 0  1  2  3]         [ 4  5  6  7]         [ 8  9 10 11]]        [[12 13 14 15]]        [[16 17 18 19]         [20 21 22 23]         [24 25 26 27]]        [[ 0  1  2  3]  # 0         [ 4  5  6  7]  # 1         [ 8  9 10 11]  # 2         [12 13 14 15]  # 3         [16 17 18 19]  # 5         [20 21 22 23]] # 6        '''

‘div’案例2

   # coding:utf8   import tensorflow as tf   import numpy as np   def test_embedding_lookup():       a = np.arange(12).reshape(3, 4)       b = np.arange(12, 16).reshape(1, 4)       c = np.arange(16, 28).reshape(3, 4)       print(a)       print('\n')       print(b)       print('\n')       print(c)       print('\n')       a = tf.Variable(a)       b = tf.Variable(b)       c = tf.Variable(c)       t = tf.nn.embedding_lookup([a, c, b],           partition_strategy='div', ids=[0, 1, 2, 3, 4, 5])       init = tf.global_variables_initializer()       sess = tf.Session()       sess.run(init)       m = sess.run(t)       print(m)   test_embedding_lookup()
        '''        分析:         这里我们把params改为[a, c, b],三个tensor的第一维度分别为3,3,1。         在把这个三个tensor组合过程中,这我们按照partition_strategy='div'策略分割。这里一共7个元素,分为3组,即3、2、2分配。         这里分割方式为[a, c, b]  == [[0,1,2], [3,4], [5,6]]           这里程序还不知道4和7是找不到对应的元素的,在获取元素时候会报错        a=[[ 0  1  2  3]     = [0, 1, 2]  -->  [0  1  2  3]  = 0           [ 4  5  6  7]                  -->  [4  5  6  7]  = 1           [ 8  9 10 11]]                 -->  [8  9 10 11]  = 2        c=[[16 17 18 19]     = [3, 4]  -->  [16 17 18 19]  = 3           [20 21 22 23]                  -->  [20 21 22 23]  = 4           [24 25 26 27]]                -->  [24 25 26 27]  = 这个是找不到的了           b=[[12 13 14 15]]    = [5, 6]  -->  [12 13 14 15] = 5                                          -->  运行时报错  = 6        输出:        [[ 0  1  2  3]         [ 4  5  6  7]         [ 8  9 10 11]]        [[12 13 14 15]]        [[16 17 18 19]         [20 21 22 23]         [24 25 26 27]]        [[ 0  1  2  3]  # 0         [ 4  5  6  7]  # 1         [ 8  9 10 11]  # 2         [16 17 18 19]  # 3         [20 21 22 23]  # 4         [16 17 18 19]] # 5        '''


参考资料

https://stackoverflow.com/questions/34870614/what-does-tf-nn-embedding-lookup-function-do/41922877#41922877?newreg=5119f86ea49b43aa8988a833294ceb3e

https://www.zhihu.com/question/52250059

原创粉丝点击