【Tensorflow】报错的坑--could not broadcast input array from shape (100,784) into shape (100)

来源:互联网 发布:交警网络执法直播流程 编辑:程序博客网 时间:2024/06/05 03:27

用tensorflow实现自编码机的时候,关于batch的选择,一直报错:

ValueError: could not broadcast input array from shape (100,784) into shape (100)

定位代码:

batch_test_x = mnist.train.next_batch(batch_size)
可以看到是因为选取数据集时出了问题
在训练的时候,需要选择合适batch的数据集训练,有两种正确的选取方法:

方法一:使用mnist.train.next_batch(batch_size)


在使用mnist.train.next_batch(batch_size)方法选择batch的时候,一定要注意正确的代码是:
batch_test_x, batch_test_y= mnist.train.next_batch(batch_size)

数据集和标签一定要同时选择,可以只将batch_test_x丢进自编码机中训练:
cost,_ = sess.run([cost,optimizer],feed_dict={x:batch_test_x})


方法二:使用自定义函数随机选择数据集的范围

代码为:
def get_random_block_from_data(data, batch_size):    start_index = np.random.randint(0, len(data) - batch_size )    return data[start_index:(start_index + batch_size)]batch_test_x = get_random_block_from_data(x_train, batch_size)c,_ = sess.run([cost,optimizer],feed_dict={x:batch_test_x})

这时候,就可以将batch_test_x丢进自编码机训练了

方法三:
batch = mnist.train.next_batch(50)avg_acc = sess.run(accuracy, feed_dict={x: batch[0], y: batch[1], keep_prob: 1.0})








原创粉丝点击