用DNN对Iris数据分类的代码--tensorflow--logging/monitoring/earlystopping/visualizing
来源:互联网 发布:网络推广待遇 编辑:程序博客网 时间:2024/05/18 01:32
本博客是对 用深度神经网络对Iris数据集进行分类的程序–tensorflow
里面的代码进行修改,使其可以记录训练日志,监控训练指标,设置early stopping, 并在TensorBoard中进行可视化.
注意和原程序进行对比,看看增加了哪些code
from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport osimport numpy as npimport tensorflow as tftf.logging.set_verbosity(tf.logging.INFO)# Data setsIRIS_TRAINING = os.path.join(os.path.dirname(__file__), "iris_training.csv")IRIS_TEST = os.path.join(os.path.dirname(__file__), "iris_test.csv")def main(unused_argv): # Load datasets. training_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32) test_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32) validation_metrics = { "accuracy": tf.contrib.learn.MetricSpec( metric_fn=tf.contrib.metrics.streaming_accuracy, prediction_key=tf.contrib.learn.PredictionKey. CLASSES), "precision": tf.contrib.learn.MetricSpec( metric_fn=tf.contrib.metrics.streaming_precision, prediction_key=tf.contrib.learn.PredictionKey. CLASSES), "recall": tf.contrib.learn.MetricSpec( metric_fn=tf.contrib.metrics.streaming_recall, prediction_key=tf.contrib.learn.PredictionKey. CLASSES) } validation_monitor = tf.contrib.learn.monitors.ValidationMonitor( test_set.data, test_set.target, every_n_steps=50, metrics=validation_metrics, early_stopping_metric="loss", early_stopping_metric_minimize=True, early_stopping_rounds=200) # Specify that all features have real-value data feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] # Build 3 layer DNN with 10, 20, 10 units respectively. classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3, model_dir="/tmp/iris_model", config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1)) # Fit model. classifier.fit(x=training_set.data, y=training_set.target, steps=2000, monitors=[validation_monitor]) # Evaluate accuracy. accuracy_score = classifier.evaluate(x=test_set.data, y=test_set.target)["accuracy"] print('Accuracy: {0:f}'.format(accuracy_score)) # Classify two new flower samples. new_samples = np.array( [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float) y = list(classifier.predict(new_samples, as_iterable=True)) print('Predictions: {}'.format(str(y)))if __name__ == "__main__": tf.app.run()
在命令行输入
$ tensorboard --logdir=/tmp/iris_model/
可以看到loss/accuracy/recall/precision/dnn/global_step等指标的可视化结果
0 0
- 用DNN对Iris数据分类的代码--tensorflow--logging/monitoring/earlystopping/visualizing
- 用深度神经网络对Iris数据集进行分类的程序--tensorflow
- TensorFlow 测试 IRIS 数据
- 使用不同的SVM对iris数据集进行分类并绘出结果
- [Java][机器学习]用决策树分类算法对Iris花数据集进行处理
- c#神经网络,实现对Iris数据集进行分类
- 用随机森林分类算法进行Iris 数据分类训练,是怎样的体验?
- Tensorflow:softmax处理Iris鸾尾花分类
- Spark 机器学习实践 :Iris数据集的分类
- 85、使用TFLearn实现iris数据集的分类
- RBF神经网络对iris鸢尾花数据集进行分类识别
- 利用BP神经网络分类iris数据集
- iris数据集进行KNN分类
- Scikit-Learn 实战 iris数据集分类
- 手把手教你如何用 TensorFlow 实现基于 DNN 的文本分类
- 手把手教你如何用 TensorFlow 实现基于 DNN 的文本分类
- Iris动画效果的代码
- 对DNN的理解
- sql求加权平均值
- Windows7 下Apache-Tomcat7.0的安装配置
- redis操作命令详解
- 如何查看计算机隐藏的文件夹
- 汽车钥匙秘钥接收解码
- 用DNN对Iris数据分类的代码--tensorflow--logging/monitoring/earlystopping/visualizing
- ssm+freemark集成shiro
- 1-RabbitMQ安装及简单实例
- VMware非正常关闭导致打开虚拟机时提示:未找到.vmx文件
- linux服务器安装并在windows下访问tomcat服务
- Oracle数据库表中查询最大值和第二大值
- C# 如何在Word文档中添加,替换和删除书签
- SWIFT语言之运算符
- jenkins部署jar项目、springboot项目部署