Lesson 13: the Wide & Deep model
The paper itself is straightforward, and plenty of translations of it are available online. With TensorFlow's built-in estimator library, the implementation is simple; the genuinely tricky part is understanding the feature engineering. See Lesson 9, tensorflow feature engineering: feature_column.
Below is the hands-on demo. It is split across four modules: wide_and_deep.py (data input, model construction, training, evaluation), common.py (feature columns and shared constants), hook.py (a logging SessionRunHook), and main.py (the entry point).
# coding:utf-8"""wide and deep"""from framework.data_input import IDataInputfrom framework.inference import IInferencefrom framework.train import ITrainfrom framework.eval import IEvalimport pandas as pdimport tensorflow as tfimport commonfrom hook import LoggerHookclass DataInput(IDataInput): def __init__(self, csv_column_names, label_column_name, input_file_paths, shuffle, batch_size, example_per_epoch_num, parallel_thread_num=16): super(DataInput, self).__init__(input_file_paths, batch_size, example_per_epoch_num, parallel_thread_num=parallel_thread_num) self._csv_column_names = csv_column_names self._label_column_name = label_column_name self._shuffle = shuffle def read_data(self): input_file_path = self._input_file_paths[0] df = pd.read_csv(tf.gfile.Open(input_file_path), names=self._csv_column_names, skipinitialspace=True, skiprows=1) df.dropna(axis=0, how='any') label = df[self._label_column_name].apply(lambda x: ">50K" in x).astype(int) print(df.head()) print(label.head()) return tf.estimator.inputs.pandas_input_fn( x=df, y=label, batch_size=self._batch_size, shuffle=self._shuffle, num_threads=self._parallel_thread_num, num_epochs=self._example_per_echo_num ) def _preprocess_data(self, record): pass def _generate_train_batch(self, train_data, label, shuffle=True): pass def _read_data_from_queue(self, file_path_queue): passclass Inference(IInference): def __init__(self, model_dir, model_type, linear_feature_columns, dnn_feature_columns): super(Inference, self).__init__() self._model_dir = model_dir self._model_type = model_type self._linear_feature_columns = linear_feature_columns self._dnn_featrue_columns = dnn_feature_columns def inference(self, data): if self._model_type == 'wide': return tf.estimator.LinearClassifier(feature_columns=self._linear_feature_columns, model_dir=self._model_dir) elif self._model_type == 'deep': return tf.estimator.DNNClassifier(feature_columns=self._dnn_featrue_columns, model_dir=self._model_dir, hidden_units=[100, 50]) elif self._model_type == 'wide_n_deep': return tf.estimator.DNNLinearCombinedClassifier(model_dir=self._model_dir, linear_feature_columns=self._linear_feature_columns, dnn_feature_columns=self._dnn_featrue_columns, dnn_hidden_units=[100, 50]) else: raise RuntimeError('no %s model type' % self._model_type)class Train(ITrain): def __init__(self, model_type, model_dir): super(Train, self).__init__() self._model_type = model_type self._model_dir = model_dir @property def model_dir(self): return self._model_dir @property def model_type(self): return self._model_type def train(self): data_input = DataInput(common.CSV_COLUMNS, common.LABEL_COLUMN_NAME, ['./input/adult.data'], shuffle=True, batch_size=128, example_per_epoch_num=None) input_fn = data_input.read_data() if self._model_type == 'wide': inference = Inference(self._model_dir, self._model_type, common.base_columns + common.crossed_columns, None) elif self._model_type == 'deep': inference = Inference(self._model_dir, self._model_type, None, common.deep_columns) elif self._model_type == 'wide_n_deep': inference = Inference(self._model_dir, self._model_type, common.crossed_columns, common.deep_columns) else: raise RuntimeError('model type error: ' + self._model_type) model = inference.inference(None) logger_hook = LoggerHook() model.train(input_fn=input_fn, hooks=[logger_hook], steps=2000) return modelclass Eval(IEval): def __init__(self, model): super(Eval, self).__init__(None, None, None, 128) self._model = model def accuracy(self, predict_results, labels): pass def 
read_test_data_set(self): pass def predict(self, test_data_batch): pass def eval(self): data_input = DataInput(common.CSV_COLUMNS, common.LABEL_COLUMN_NAME, ['./input/adult.test'], shuffle=False, batch_size=128, example_per_epoch_num=1) test_data_input_fn = data_input.read_data() results = self._model.evaluate(input_fn=test_data_input_fn, steps=None) for key in sorted(results): print("%s: %s" % (key, results[key]))
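The demo expects the UCI Adult (Census Income) data set under ./input/; that is where the income_bracket column and the ">50K" label come from. A minimal download sketch, assuming the standard UCI repository URLs (verify them before relying on this):

# Sketch: fetch the UCI Adult data set into ./input/.
# The URLs below are assumptions based on the UCI repository layout.
import os
import urllib.request

BASE = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/'
os.makedirs('./input', exist_ok=True)
for name in ('adult.data', 'adult.test'):
    urllib.request.urlretrieve(BASE + name, os.path.join('./input', name))

Next, common.py defines the feature columns and shared constants: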
# coding:utf-8"""common"""import tensorflow as tfGENDER = 'gender'EDUCATION = 'education'MARITAL_STATUS = 'marital_status'RELATIONSHIP = 'relationship'WORK_CLASS = 'workclass'OCCUPATION = 'occupation'NATIVE_COUNTRY = 'native_country'AGE = 'age'EDUCATION_NUM = 'education_num'CAPITAL_GAIN = 'capital_gain'CAPITAL_LOSS = 'capital_loss'HOURS_PER_WEEK = 'hours_per_week'CSV_COLUMNS = [ AGE, WORK_CLASS, "fnlwgt", EDUCATION, EDUCATION_NUM, MARITAL_STATUS, OCCUPATION, RELATIONSHIP, "race", GENDER, CAPITAL_GAIN, CAPITAL_LOSS, HOURS_PER_WEEK, NATIVE_COUNTRY, "income_bracket"]gender = tf.feature_column.categorical_column_with_vocabulary_list(GENDER, ['Female', 'Male'])education = tf.feature_column.categorical_column_with_vocabulary_list(EDUCATION, ["Bachelors", "HS-grad", "11th", "Masters", "9th", "Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th", "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th", "Preschool", "12th" ] )marital_status = tf.feature_column.categorical_column_with_vocabulary_list( MARITAL_STATUS, [ "Married-civ-spouse", "Divorced", "Married-spouse-absent", "Never-married", "Separated", "Married-AF-spouse", "Widowed" ])relationship = tf.feature_column.categorical_column_with_vocabulary_list( RELATIONSHIP, [ "Husband", "Not-in-family", "Wife", "Own-child", "Unmarried", "Other-relative" ])workclass = tf.feature_column.categorical_column_with_vocabulary_list( WORK_CLASS, [ "Self-emp-not-inc", "Private", "State-gov", "Federal-gov", "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked" ])occupation = tf.feature_column.categorical_column_with_hash_bucket( OCCUPATION, hash_bucket_size=1000)native_country = tf.feature_column.categorical_column_with_hash_bucket(NATIVE_COUNTRY, hash_bucket_size=1000)age = tf.feature_column.numeric_column(AGE)education_num = tf.feature_column.numeric_column(EDUCATION_NUM)capital_gain = tf.feature_column.numeric_column(CAPITAL_GAIN)capital_loss = tf.feature_column.numeric_column(CAPITAL_LOSS)hours_per_week = tf.feature_column.numeric_column(HOURS_PER_WEEK)age_buckets = tf.feature_column.bucketized_column(age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])base_columns = [ gender, education, marital_status, relationship, workclass, occupation, native_country, age_buckets,]crossed_columns = [ tf.feature_column.crossed_column([EDUCATION, OCCUPATION], hash_bucket_size=1000), tf.feature_column.crossed_column( [age_buckets, EDUCATION, OCCUPATION], hash_bucket_size=1000), tf.feature_column.crossed_column( [NATIVE_COUNTRY, OCCUPATION], hash_bucket_size=1000)]deep_columns = [ tf.feature_column.indicator_column(workclass), tf.feature_column.indicator_column(education), tf.feature_column.indicator_column(gender), tf.feature_column.indicator_column(relationship), tf.feature_column.embedding_column(native_country, dimension=8), tf.feature_column.embedding_column(occupation, dimension=8), age, education_num, capital_gain, capital_loss, hours_per_week]LABEL_COLUMN_NAME = "income_bracket"MODEL_DIR = './output'WIDE_MODEL_DIR = MODEL_DIR + '/wide'DEEP_MODEL_DIR = MODEL_DIR + '/deep'WIDE_N_DEEP_DIR = MODEL_DIR + '/wide_deep'WIDE_MODEL_TYPE = 'wide'DEEP_MODEL_TYPE = 'deep'WIDE_N_DEEP_MODEL_TYPE = 'wide_n_deep'
# coding:utf-8"""hook"""import tensorflow as tfimport timeimport datetimeclass LoggerHook(tf.train.SessionRunHook): def __init__(self): super(LoggerHook, self).__init__() self._step = -1 self._start_time = time.time() self._log_frequency = 10 def begin(self): self._step = -1 self._start_time = time.time() self._log_frequency = 10 def before_run(self, run_context): self._step += 1 # loss会作为参数一起被运行 会在after_run运行结束后 将run_values 也就是这里的loss值传回 loss_value = tf.get_collection(tf.GraphKeys.LOSSES) return tf.train.SessionRunArgs(loss_value) def after_run(self, run_context, run_values): if self._step % self._log_frequency == 0: current_time = time.time() duration = current_time - self._start_time self._start_time = current_time loss_value = run_values.results i = 0 for l in loss_value: print(i, ':', l) i += 1 print('-' * 40) examples_per_sec = self._log_frequency * 128 / duration sec_per_batch = float(duration / self._log_frequency) format_str = ('%s: step %d, loss = todo (%.1f examples/sec; %.3f ' 'sec/batch)') print(format_str % (datetime.datetime.now(), self._step, examples_per_sec, sec_per_batch))
# coding:utf-8"""main"""from wide_and_deep import DataInputfrom wide_and_deep import Trainfrom wide_and_deep import Evalimport commonimport loggingimport tensorflow as tfif __name__ == '__main__': tf.logging.set_verbosity(tf.logging.INFO) # model_type = common.WIDE_MODEL_TYPE # model_type = common.DEEP_MODEL_TYPE model_type = common.WIDE_N_DEEP_MODEL_TYPE train = None if model_type == common.WIDE_MODEL_TYPE: train = Train(common.WIDE_MODEL_TYPE, common.WIDE_MODEL_DIR) elif model_type == common.DEEP_MODEL_TYPE: train = Train(common.DEEP_MODEL_TYPE, common.DEEP_MODEL_DIR) elif model_type == common.WIDE_N_DEEP_MODEL_TYPE: train = Train(common.WIDE_N_DEEP_MODEL_TYPE, common.WIDE_N_DEEP_DIR) else: raise RuntimeError("error model type") if train is not None: model = train.train() wd_eval = Eval(model=model) wd_eval.eval()