tensorflow随笔1

来源:互联网 发布:暗黑钻油井升级数据 编辑:程序博客网 时间:2024/05/18 11:24

昨天想用深度神经网络跑一些数据,踩到了一些坑,记录下来。

from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport tensorflow as tfimport numpy as np# 设定数据集的位置IRIS_TRAINING = "iris_training.csv"IRIS_TEST = "iris_test.csv"# 使用Tensorflow内置的方法进行数据加载training_set =tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TRAINING,target_dtype=np.int,features_dtype=np.float32)test_set=tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TEST,target_dtype=np.int,features_dtype=np.float32)# 每行数据4个特征,都是real-value的feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]# 构建一个DNN分类器,3层,其中每个隐含层的节点数量分别为10,20,10,目标的分类3个,并且指定了保存位置classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,                                            hidden_units=[10, 20, 10],                                            n_classes=3,                                            model_dir="/tmp/iris_model")# 指定数据,以及训练的步数classifier.fit(x=training_set.data,               y=training_set.target,               steps=2000)# 模型评估accuracy_score = classifier.evaluate(x=test_set.data,                                     y=test_set.target)["accuracy"]print('Accuracy: {0:f}'.format(accuracy_score))# 直接创建数据来进行预测new_samples = np.array(    [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)# y = classifier.predict(new_samples)y = list(classifier.predict(new_samples, as_iterable=True))print('Predictions: {}'.format(str(y)))

这是网上找的TensorFlow很入门的代码,对iris数据集分类,比较这个数据集很出名,类似于Luna的照片,好吧,可以用机器跑一跑,正确率达到0.966667,觉得很奇怪,用一般的选用BPF的Svm或者Lr跑,正确率在0.99左右,感觉这个隐层设置有点问题。和别人跑的结果不一样,这一点应该好好研究,同时与别人交流的时候,这个高级的设置方法,不用为好,

第二个坑

tensoflow版本的问题,

tf.contrib.learn.datasets.base.load_csv_with_header
load csv的方法在新的版本里变成了上面所示,同时可以指定csv文件是否包含header。

原创粉丝点击