Study notes: reading CSV files and the tf.contrib.learn Quickstart

Source: Internet. Editor: 程序博客网. Time: 2024/06/05 06:02

tf.contrib.learn is TensorFlow's high-level machine learning API.

Below is an annotated walkthrough of the example code from the official TensorFlow documentation.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import pandas as pd
import tensorflow as tf

# Set logging verbosity to INFO
tf.logging.set_verbosity(tf.logging.INFO)

COLUMNS = ["crim", "zn", "indus", "nox", "rm", "age",
           "dis", "tax", "ptratio", "medv"]
FEATURES = ["crim", "zn", "indus", "nox", "rm",
            "age", "dis", "tax", "ptratio"]
LABEL = "medv"

# Load the data: training set, test set, and prediction set.
# skipinitialspace : boolean, default False.
#     Skip whitespace after the delimiter (the default, False, keeps it).
# skiprows : list-like or integer, default None.
#     Number of lines to skip at the start of the file, or a list of
#     line numbers to skip (0-indexed).
training_set = pd.read_csv("boston_train.csv", skipinitialspace=True,
                           skiprows=1, names=COLUMNS)
test_set = pd.read_csv("boston_test.csv", skipinitialspace=True,
                       skiprows=1, names=COLUMNS)
prediction_set = pd.read_csv("boston_predict.csv", skipinitialspace=True,
                             skiprows=1, names=COLUMNS)

# Create the FeatureColumns: one real-valued column per feature name
feature_cols = [tf.contrib.layers.real_valued_column(k) for k in FEATURES]

# Build the DNN: two hidden layers with 10 units each
regressor = tf.contrib.learn.DNNRegressor(feature_columns=feature_cols,
                                          hidden_units=[10, 10],
                                          model_dir="/tmp/boston_model")

# Define the input function: it takes a data set and returns the
# feature columns (a dict of tensors) and the labels
def input_fn(data_set):
    # Turn each feature column into a TensorFlow constant
    feature_cols = {k: tf.constant(data_set[k].values) for k in FEATURES}
    labels = tf.constant(data_set[LABEL].values)
    return feature_cols, labels

# ------------------------- Training the Regressor -------------------------
# Train the model for 5000 steps with regressor.fit
regressor.fit(input_fn=lambda: input_fn(training_set), steps=5000)

# -------------------------- Evaluating the Model --------------------------
# Compute the loss on the test set
ev = regressor.evaluate(input_fn=lambda: input_fn(test_set), steps=1)
loss_score = ev["loss"]
print("Loss: {0:f}".format(loss_score))

# --------------------------- Making Predictions ---------------------------
# Run the model on the prediction_set data
y = regressor.predict(input_fn=lambda: input_fn(prediction_set))
# .predict() returns an iterator; convert to a list and print predictions.
# itertools is a collection of iterator functions for efficient looping;
# islice(y, 6) yields the first 6 values.
predictions = list(itertools.islice(y, 6))
print("Predictions: {}".format(str(predictions)))
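The last step of the code relies on `itertools.islice` to take only the first few values from an iterator. The sketch below shows that behaviour in isolation. `fake_predictions` is a hypothetical stand-in for the iterator returned by `regressor.predict()`, and its values are made up; only `itertools.islice` itself comes from the original code.

```python
import itertools


def fake_predictions():
    """Hypothetical stand-in for the iterator returned by regressor.predict().

    It yields made-up values forever, to show that islice never needs
    to exhaust the underlying iterator.
    """
    value = 20.0
    while True:
        yield value
        value += 0.5


y = fake_predictions()

# Take the first 6 values only; the rest of the iterator is left untouched.
first_six = list(itertools.islice(y, 6))
print("Predictions: {}".format(first_six))

# The iterator resumes where islice stopped.
next_value = next(y)
print("Next value: {}".format(next_value))
```

Because `islice` pulls values one at a time, it is a safe way to peek at a fixed number of predictions without materialising everything the iterator could produce.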



Key points:

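tf.contrib.learn.datasets.base.load_csv_with_header: loads data in CSV format (the code above reads its CSV files with pd.read_csv instead)

tf.contrib.learn.DNNClassifier: builds a DNN model (a classifier; the code above uses the regressor, tf.contrib.learn.DNNRegressor)

classifier.fit: trains the model

classifier.evaluate: evaluates the model

classifier.predict: predicts on new samples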

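The function my_input_fn() (input_fn() in the code above): the return value feature_cols is a dictionary of key-value pairs that maps each column name to that feature's data; the return value labels is simply a tensor containing the labels.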
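The shape of those two return values can be sketched without TensorFlow. The example below is only an illustration under stated assumptions: it swaps pandas for the standard-library `csv` module, uses plain Python lists where the real function returns tensors, and works on a shortened, made-up column list and sample data. The helper `read_rows` is hypothetical and not part of any library.

```python
import csv
import io

# Shortened, hypothetical column list (the real example has 10 columns).
COLUMNS = ["crim", "rm", "medv"]
FEATURES = ["crim", "rm"]
LABEL = "medv"

# Made-up sample data. The first line plays the role of the header row
# that skiprows=1 skips in the pandas version.
RAW = "CRIM, RM, MEDV\n0.1, 6.5, 24.0\n0.2, 6.0, 21.5\n"


def read_rows(text):
    """Hypothetical helper: parse CSV text into a list of {column: value} dicts."""
    reader = csv.reader(io.StringIO(text), skipinitialspace=True)
    next(reader)  # skip the header row, like skiprows=1
    # Assign our own column names, like names=COLUMNS
    return [dict(zip(COLUMNS, map(float, row))) for row in reader]


def my_input_fn(rows):
    # Dictionary: column name -> that feature's values
    # (plain lists here, tensors in the real input function).
    feature_cols = {k: [row[k] for row in rows] for k in FEATURES}
    # Labels: a single sequence of target values.
    labels = [row[LABEL] for row in rows]
    return feature_cols, labels


rows = read_rows(RAW)
feature_cols, labels = my_input_fn(rows)
print(feature_cols)
print(labels)
```

Keeping the features in a dictionary keyed by column name is what lets each feature column defined for the model find its data by name.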
