用深度神经网络对Iris数据集进行分类的程序--tensorflow
来源:互联网 发布:java接收post请求数据 编辑:程序博客网 时间:2024/06/01 09:50
先确保你已经安装了tensorflow…
# 引入必要的modulefrom __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport osimport urllibimport numpy as npimport tensorflow as tf# Data setsIRIS_TRAINING = "iris_training.csv"IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"IRIS_TEST = "iris_test.csv"IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"def main(): # If the training and test sets aren't stored locally, download them. if not os.path.exists(IRIS_TRAINING): raw = urllib.urlopen(IRIS_TRAINING_URL).read() with open(IRIS_TRAINING, "w") as f: f.write(raw) if not os.path.exists(IRIS_TEST): raw = urllib.urlopen(IRIS_TEST_URL).read() with open(IRIS_TEST, "w") as f: f.write(raw) # Load datasets. training_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32) test_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32) # Specify that all features have real-value data feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] # Build 3 layer DNN with 10, 20, 10 units respectively. classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,hidden_units=[10, 20, 10],n_classes=3,model_dir="/tmp/iris_model") # Define the training inputs def get_train_inputs(): x = tf.constant(training_set.data) y = tf.constant(training_set.target) return x, y # Fit model. classifier.fit(input_fn=get_train_inputs, steps=2000) # Define the test inputs def get_test_inputs(): x = tf.constant(test_set.data) y = tf.constant(test_set.target) return x, y # Evaluate accuracy. accuracy_score = classifier.evaluate(input_fn=get_test_inputs,steps=1)["accuracy"] print("\nTest Accuracy: {0:f}\n".format(accuracy_score)) # Classify two new flower samples. def new_samples(): return np.array( [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=np.float32) predictions = list(classifier.predict(input_fn=new_samples)) print( "New Samples, Class Predictions: {}\n" .format(predictions))if __name__ == "__main__": main()
运行结果:
Test Accuracy: 0.966667New Samples, Class Predictions: [1, 1]
1 0
- 用深度神经网络对Iris数据集进行分类的程序--tensorflow
- c#神经网络,实现对Iris数据集进行分类
- RBF神经网络对iris鸢尾花数据集进行分类识别
- 用DNN对Iris数据分类的代码--tensorflow--logging/monitoring/earlystopping/visualizing
- 用深度神经网络对boston housing data进行回归预测的程序--tensorflow
- 利用BP神经网络分类iris数据集
- 使用不同的SVM对iris数据集进行分类并绘出结果
- [Java][机器学习]用决策树分类算法对Iris花数据集进行处理
- 【python 神经网络】BP神经网络python实现-iris数据集分类
- iris数据集进行KNN分类
- 用随机森林分类算法进行Iris 数据分类训练,是怎样的体验?
- 【深度学习】BP算法分类iris数据集
- (原创2008.07.21)对iris数据进行聚类分析的程序(模式识别)
- Java 实现 BP 神经网络完成 Iris 数据分类
- Spark 机器学习实践 :Iris数据集的分类
- 85、使用TFLearn实现iris数据集的分类
- 使用IRIS数据集训练第一个深度神经网络
- TensorFlow 测试 IRIS 数据
- 数组与指针的理解
- jasper报表工具的使用
- 基于全注解的SpringMVC+Spring4.2+hibernate4.3框架搭建
- mongodb php 增删改查
- Node Express listen和http createServer区别
- 用深度神经网络对Iris数据集进行分类的程序--tensorflow
- 相关指针的理解
- windows7 x64 环境下的 opencv 3.2.0 在qt5.8.0(msvc 2015)上使用的配置
- 如何卸载rpm包
- Java基础———个简单的Java框架
- 如何删除在使用jQuery变量的选项标签?
- Prism中使用MEF(依赖注入)案例
- 改善脑回路1——斐波那契数列
- 关于java maven 项目debug运行时,项目报sourse not found问题