用DNN对Iris数据分类的代码--tensorflow--logging/monitoring/earlystopping/visualizing

来源:互联网 发布:网络推广待遇 编辑:程序博客网 时间:2024/05/18 01:32

本文在《用深度神经网络对Iris数据集进行分类的程序–tensorflow》一文代码的基础上进行修改,
使其可以记录训练日志、监控训练指标、设置 early stopping,并在 TensorBoard 中进行可视化。

请注意与原程序进行对比,看看增加了哪些代码。

from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport osimport numpy as npimport tensorflow as tftf.logging.set_verbosity(tf.logging.INFO)# Data setsIRIS_TRAINING = os.path.join(os.path.dirname(__file__), "iris_training.csv")IRIS_TEST = os.path.join(os.path.dirname(__file__), "iris_test.csv")def main(unused_argv):    # Load datasets.    training_set = tf.contrib.learn.datasets.base.load_csv_with_header(        filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32)    test_set = tf.contrib.learn.datasets.base.load_csv_with_header(        filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32)    validation_metrics = {    "accuracy":        tf.contrib.learn.MetricSpec(            metric_fn=tf.contrib.metrics.streaming_accuracy,            prediction_key=tf.contrib.learn.PredictionKey.            CLASSES),    "precision":        tf.contrib.learn.MetricSpec(            metric_fn=tf.contrib.metrics.streaming_precision,            prediction_key=tf.contrib.learn.PredictionKey.            CLASSES),    "recall":        tf.contrib.learn.MetricSpec(            metric_fn=tf.contrib.metrics.streaming_recall,            prediction_key=tf.contrib.learn.PredictionKey.            CLASSES)    }    validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(    test_set.data,    test_set.target,    every_n_steps=50,    metrics=validation_metrics,    early_stopping_metric="loss",    early_stopping_metric_minimize=True,    early_stopping_rounds=200)    # Specify that all features have real-value data    feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]    # Build 3 layer DNN with 10, 20, 10 units respectively.    
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,                                                hidden_units=[10, 20, 10],                                                n_classes=3,                                                model_dir="/tmp/iris_model",                                                config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1))    # Fit model.    classifier.fit(x=training_set.data,                   y=training_set.target,                   steps=2000,                   monitors=[validation_monitor])    # Evaluate accuracy.    accuracy_score = classifier.evaluate(x=test_set.data,                                         y=test_set.target)["accuracy"]    print('Accuracy: {0:f}'.format(accuracy_score))    # Classify two new flower samples.    new_samples = np.array(        [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)    y = list(classifier.predict(new_samples, as_iterable=True))    print('Predictions: {}'.format(str(y)))if __name__ == "__main__":  tf.app.run()

在命令行输入

$ tensorboard --logdir=/tmp/iris_model/

然后在浏览器中打开 TensorBoard 输出的地址(默认为 http://localhost:6006),即可看到 loss/accuracy/recall/precision/dnn/global_step 等指标的可视化结果。

0 0
原创粉丝点击