one_hot的源码实现及其理解

来源:互联网 发布:被上苍诅咒的天才 知乎 编辑:程序博客网 时间:2024/06/06 00:24

以下是one_hot 的源码实现:

def dense_to_one_hot(labels_dense, num_classes):  """Convert class labels from scalars to one-hot vectors."""  num_labels = labels_dense.shape[0]  index_offset = numpy.arange(num_labels) * num_classes  labels_one_hot = numpy.zeros((num_labels, num_classes))  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1  return labels_one_hot
测试如下:

>>> labels_dense = np.array([0,1,2])>>> labels_densearray([0, 1, 2])>>> num_classes = 3>>> labels_densearray([0, 1, 2])>>> num_labels = labels_dense.shape[0]>>> num_labels3>>> index_offset = np.arange(num_labels)*num_classes>>> index_offsetarray([0, 3, 6])>>> labels_one_hot = np.zeros((num_labels, num_classes))>>> labels_one_hotarray([[ 0.,  0.,  0.],       [ 0.,  0.,  0.],       [ 0.,  0.,  0.]])>>> labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1>>> labels_one_hotarray([[ 1.,  0.,  0.],       [ 0.,  1.,  0.],       [ 0.,  0.,  1.]])
其实自己计算一遍就可以理解大概了,就是将原始的[0,1,2]标签转化成了以下的形式来提高效率:

[[ 1.,  0.,  0.], [ 0.,  1.,  0.], [ 0.,  0.,  1.]]


阅读全文
1 0
原创粉丝点击