Machine Learning 6: TensorFlow

Source: Internet | Published: Ubuntu desktop | Editor: 程序博客网 | Time: 2024/05/16 09:21

This time we use TensorFlow's DNNClassifier to sort iris flowers into their three species.

Python code:

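# coding=utf-8
# Requires TensorFlow 1.x: tf.contrib was removed in TensorFlow 2.0
from sklearn import metrics, model_selection
import tensorflow as tf
from tensorflow.contrib import learn

# Load the iris dataset
iris = learn.datasets.load_dataset('iris')
# Hold out half of the 150 samples for testing
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    iris.data, iris.target, test_size=.5, random_state=42)

# A single real-valued feature column holding all 4 measurements
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
print(feature_columns)

# Fully connected network: hidden layers of 10, 20 and 10 units, 3 output classes
clf = learn.DNNClassifier(feature_columns=feature_columns,
                          hidden_units=[10, 20, 10],
                          n_classes=3,
                          model_dir="/tmp/iris_model")
clf.fit(x=X_train, y=y_train, steps=2000)
# predict() may return a generator, so materialize it as a list
predictions = list(clf.predict(x=X_test))

# Score on the test set
print(clf.evaluate(x=X_test, y=y_test)["accuracy"])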
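Because `tf.contrib` no longer exists in TensorFlow 2, the script above will not run on a current install. As a rough stand-in, here is a minimal sketch of the same experiment using scikit-learn's `MLPClassifier` instead of TensorFlow. It keeps the same 50/50 split, seed, and 10-20-10 hidden layers, and it uses the `metrics` module that the script imports but never calls. This is a swapped-in technique, not the author's method: the feature scaling (`StandardScaler`), the default Adam optimizer, and `max_iter=2000` are assumptions, so the accuracy will not match the TensorFlow run exactly.

```python
from sklearn import datasets, metrics, model_selection
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Load the 150-sample iris dataset bundled with scikit-learn
iris = datasets.load_iris()

# Same 50/50 split and seed as the TensorFlow script
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    iris.data, iris.target, test_size=.5, random_state=42)

# Hidden layers of 10, 20 and 10 units, mirroring hidden_units=[10, 20, 10].
# Assumption: standardize the features first and let MLPClassifier use its
# default Adam optimizer for up to 2000 epochs.
clf = make_pipeline(
    StandardScaler(),
    MLPClassifier(hidden_layer_sizes=(10, 20, 10), max_iter=2000, random_state=42),
)
clf.fit(X_train, y_train)

# Predict the class of each held-out flower and score against the true labels
predictions = clf.predict(X_test)
accuracy = metrics.accuracy_score(y_test, predictions)
print(accuracy)
```

Wrapping the scaler and the network in one pipeline ensures the scaling statistics are learned from the training half only and then reused on the test half.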

Accuracy on the test set (approximate): 0.96
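The `accuracy` value returned by `evaluate` is the fraction of test samples whose predicted class equals the true class. With 150 iris samples and `test_size=.5`, the test set holds 75 samples, so an accuracy of exactly 0.96 would mean 72 correct predictions and 3 misclassifications. A minimal pure-Python sketch of that calculation follows; the `accuracy` helper and the label lists are hypothetical, for illustration only.

```python
def accuracy(y_true, y_pred):
    """Return the fraction of samples whose predicted label matches the true label."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and the same length")
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


# 150 iris samples with test_size=.5 leave 75 test samples
n_test = 150 // 2
# Hypothetical labels: 72 correct predictions followed by 3 wrong ones
y_true = [0] * n_test
y_pred = [0] * 72 + [1] * 3
print(accuracy(y_true, y_pred))  # 0.96
```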
