Tensorflow-CSV数据
来源:互联网 发布:淘宝网白色运动鞋 编辑:程序博客网 时间:2024/06/06 07:16
数据使用的是titanic
import tensorflow as tfimport osos.environ["CUDA_VISIBLE_DEVICES"] = "1"print(os.getcwd())#读取函数定义def read_data(file_queue): reader = tf.TextLineReader(skip_header_lines=1) # 跳过标题行 key, value = reader.read(file_queue) #定义列 defaults = [ [0], [0.], [''],[''],[0.], [0],[0],[''],[0.0]] #编码 survived,pclass,name,sex,age,sibsp,parch,ticket,fare = tf.decode_csv(value, defaults) #处理 gender=tf.case({tf.equal(sex,tf.constant('female')):lambda: tf.constant(1.), tf.equal(sex, tf.constant('male')): lambda: tf.constant(0.), }, lambda: tf.constant(-1.), exclusive=True) #栈 features=tf.stack([pclass,gender,age]) return features, survived # 返回 X,Ydef create_pipeline(filename, batch_size, num_epochs=None): file_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs) # 放入在文件队列里 example, label = read_data(file_queue) min_after_dequeue = 1000 capacity = min_after_dequeue + batch_size example_batch, label_batch = tf.train.shuffle_batch( [example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue ) return example_batch, label_batch # 返回X,Yglobal_step = tf.Variable(0, trainable=False)# learning_rate = 0.1#tf.train.exponential_decay(0.1, global_step, 100, 0.0)# Input layerx = tf.placeholder(tf.float32, [None, 3])y = tf.placeholder(tf.int32, [None])# Output layerw = tf.Variable(tf.random_normal([3, 2])) # 二分类b = tf.Variable(tf.random_normal([2]))# a = tf.matmul(x, w) + b# prediction = tf.nn.softmax(a)def inference(X): return tf.nn.softmax(tf.matmul(X,w)+b)def loss(X,Y): return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=Y,logits=inference(X)))def inputs(): x_train_batch, y_train_batch = create_pipeline('titanic_dataset.csv', 50, num_epochs=1000) return x_train_batch,y_train_batchdef train(total_loss): learning_rate=0.1 return tf.train.GradientDescentOptimizer(learning_rate).minimize(total_loss, global_step=global_step)def evaluate(X,Y): correct_prediction = tf.equal(tf.argmax(inference(X), 1), tf.cast(y, tf.int64)) return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))x_train_batch, y_train_batch =inputs()cross_entropy=loss(x,y)train_step=train(cross_entropy)accuracy=evaluate(x,y)init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) #with tf.Session() as sess: init.run() #只初始化tf.global_variables_initializer() 会报错,必须还初始化tf.local_variables_initializer() coord = tf.train.Coordinator() # threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 线程 try: print("Training: ") count = 0 # curr_x_test_batch, curr_y_test_batch = sess.run([x_test, y_test]) while not coord.should_stop(): # Run training steps or whatever curr_x_train_batch, curr_y_train_batch = sess.run([x_train_batch, y_train_batch]) # 必须将队列中的值取出,才能放入到feed_dict进行传递 sess.run(train_step, feed_dict={ x: curr_x_train_batch, y: curr_y_train_batch }) count += 1 ce,acc = sess.run([cross_entropy,accuracy], feed_dict={ x: curr_x_train_batch, y: curr_y_train_batch }) if count%100==0: print('Batch:',count,'loss:',ce,'accuracy:',acc) except tf.errors.OutOfRangeError: print('Done training -- epoch limit reached') finally: # When done, ask the threads to stop. coord.request_stop() # Wait for threads to finish. coord.join(threads) sess.close()
阅读全文
0 0
- Tensorflow-CSV数据
- tensorflow读取数据(csv格式)
- tensorflow读取数据之CSV格式
- TensorFlow 读取CSV数据代码实现
- TensorFlow读取CSV数据的实现
- Tensorflow 读取Txt和Csv格式数据
- Tensorflow | 读取csv文件
- Tensorflow直接读取CSV文件
- tensorflow 输出权重 到csv或txt
- tensorflow将CSV文件转为TFrecords文件
- Tensorflow csv文件读写与分批训练
- 读取.csv文件数据
- python 处理csv数据
- CSV 数据导出保存
- 读写xls csv数据
- cocos2dx解析csv数据
- MySql 导入CSV数据
- 导出csv数据
- HDU6214 Smallest Minimum Cut 【最大流求最小割边】
- HTML/CSS导航菜单-圆角菜单的制作
- 编译和链接
- 好久没刷题了(阿里测试题)
- MySQL索引建立
- Tensorflow-CSV数据
- 使用TFS-如何删除TFS上项目的正确姿势
- 字符串String类型 、数组 Array类型
- UVA 12545 Bits Equalizer 机智题
- C++ list 类学习笔记
- P3390 【模板】矩阵快速幂
- 新建file时,file是否存在的问题
- Spring 的七大模块
- 1023. 组个最小数 (20)