第十三课 wide&deep模型

来源:互联网 发布:对工作表示满意的数据 编辑:程序博客网 时间:2024/05/01 20:26

这篇 paper 很简单，网上也有很多人翻译过。用 TensorFlow 自带的库实现起来其实也很简单，难点主要在于对特征工程部分的理解，请参考《第九课 tensorflow 特征工程: feature_column》。

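下面是具体的实践 demo，由四个文件组成：wide_and_deep.py（数据输入、模型构建、训练与评估）、common.py（列名与特征列定义）、hook.py（训练日志 hook）和 main.py（程序入口）。训练数据为 ./input/adult.data，测试数据为 ./input/adult.test：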

# coding:utf-8
"""wide and deep

Data input, model construction, training and evaluation for the wide, deep
and wide & deep census-income classifiers.
"""
from framework.data_input import IDataInput
from framework.inference import IInference
from framework.train import ITrain
from framework.eval import IEval
import pandas as pd
import tensorflow as tf
import common
from hook import LoggerHook


class DataInput(IDataInput):
    """Load one CSV file into an Estimator ``input_fn``.

    Args:
        csv_column_names: names assigned to the CSV fields, in file order.
        label_column_name: column holding the raw label string.
        input_file_paths: list of CSV paths; only the first one is read.
        shuffle: whether the generated ``input_fn`` shuffles examples.
        batch_size: examples per batch.
        example_per_epoch_num: despite its name, forwarded to
            ``pandas_input_fn`` as ``num_epochs`` (``None`` = repeat forever).
        parallel_thread_num: reader threads used by ``pandas_input_fn``.
        skip_rows: number of leading lines to skip. Defaults to 1 (the old
            hard-coded value, needed for ``adult.test``); pass 0 for files
            that start directly with data rows.
    """

    def __init__(self,
                 csv_column_names,
                 label_column_name,
                 input_file_paths,
                 shuffle,
                 batch_size,
                 example_per_epoch_num,
                 parallel_thread_num=16,
                 skip_rows=1):
        super(DataInput, self).__init__(input_file_paths,
                                        batch_size,
                                        example_per_epoch_num,
                                        parallel_thread_num=parallel_thread_num)
        self._csv_column_names = csv_column_names
        self._label_column_name = label_column_name
        self._shuffle = shuffle
        # Keep our own copy rather than depending on an attribute name chosen
        # by the base class (this used to read ``_example_per_echo_num``).
        self._num_epochs = example_per_epoch_num
        self._skip_rows = skip_rows

    def read_data(self):
        """Read the CSV and return a ``pandas_input_fn`` over it.

        Rows with any missing value are dropped. The label is 1 when the
        label column contains ``">50K"`` and 0 otherwise. The label column is
        removed from the features so it cannot leak into the model.
        """
        input_file_path = self._input_file_paths[0]
        # Close the file handle once pandas has consumed it.
        with tf.gfile.Open(input_file_path) as csv_file:
            df = pd.read_csv(csv_file,
                             names=self._csv_column_names,
                             skipinitialspace=True,
                             skiprows=self._skip_rows)
        # dropna returns a new frame; it does not modify df in place.
        df = df.dropna(axis=0, how='any')
        label = df[self._label_column_name].apply(lambda x: ">50K" in x).astype(int)
        features = df.drop(self._label_column_name, axis=1)
        print(df.head())
        print(label.head())
        return tf.estimator.inputs.pandas_input_fn(
            x=features,
            y=label,
            batch_size=self._batch_size,
            shuffle=self._shuffle,
            num_threads=self._parallel_thread_num,
            num_epochs=self._num_epochs
        )

    def _preprocess_data(self, record):
        pass

    def _generate_train_batch(self, train_data, label, shuffle=True):
        pass

    def _read_data_from_queue(self, file_path_queue):
        pass


class Inference(IInference):
    """Build the canned Estimator that corresponds to ``model_type``.

    Args:
        model_dir: checkpoint directory for the Estimator.
        model_type: ``'wide'``, ``'deep'`` or ``'wide_n_deep'``.
        linear_feature_columns: columns for the linear (wide) part.
        dnn_feature_columns: columns for the DNN (deep) part.
    """

    def __init__(self, model_dir, model_type, linear_feature_columns, dnn_feature_columns):
        super(Inference, self).__init__()
        self._model_dir = model_dir
        self._model_type = model_type
        self._linear_feature_columns = linear_feature_columns
        self._dnn_feature_columns = dnn_feature_columns

    def inference(self, data):
        """Return an untrained Estimator; ``data`` is ignored.

        Raises:
            RuntimeError: for an unknown model type.
        """
        if self._model_type == 'wide':
            return tf.estimator.LinearClassifier(feature_columns=self._linear_feature_columns,
                                                 model_dir=self._model_dir)
        elif self._model_type == 'deep':
            return tf.estimator.DNNClassifier(feature_columns=self._dnn_feature_columns,
                                              model_dir=self._model_dir,
                                              hidden_units=[100, 50])
        elif self._model_type == 'wide_n_deep':
            return tf.estimator.DNNLinearCombinedClassifier(model_dir=self._model_dir,
                                                            linear_feature_columns=self._linear_feature_columns,
                                                            dnn_feature_columns=self._dnn_feature_columns,
                                                            dnn_hidden_units=[100, 50])
        else:
            raise RuntimeError('no %s model type' % self._model_type)


class Train(ITrain):
    """Train one model type on ``./input/adult.data``."""

    def __init__(self, model_type, model_dir):
        super(Train, self).__init__()
        self._model_type = model_type
        self._model_dir = model_dir

    @property
    def model_dir(self):
        return self._model_dir

    @property
    def model_type(self):
        return self._model_type

    def _feature_columns(self):
        """Return ``(linear_columns, dnn_columns)`` for the model type.

        Raises:
            RuntimeError: for an unknown model type.
        """
        if self._model_type == 'wide':
            return common.base_columns + common.crossed_columns, None
        if self._model_type == 'deep':
            return None, common.deep_columns
        if self._model_type == 'wide_n_deep':
            # In the combined model the linear part only gets the crossed
            # columns, unlike the wide-only model.
            return common.crossed_columns, common.deep_columns
        raise RuntimeError('model type error: ' + self._model_type)

    def train(self):
        """Train for 2000 steps and return the trained Estimator."""
        # Resolve columns first so a bad model type fails before loading data.
        linear_columns, dnn_columns = self._feature_columns()
        data_input = DataInput(common.CSV_COLUMNS,
                               common.LABEL_COLUMN_NAME,
                               ['./input/adult.data'],
                               shuffle=True,
                               batch_size=128,
                               example_per_epoch_num=None,
                               # adult.data starts directly with data rows;
                               # skipping a line would drop a real example.
                               skip_rows=0)
        input_fn = data_input.read_data()
        inference = Inference(self._model_dir,
                              self._model_type,
                              linear_columns,
                              dnn_columns)
        model = inference.inference(None)
        model.train(input_fn=input_fn, hooks=[LoggerHook()], steps=2000)
        return model


class Eval(IEval):
    """Evaluate a trained Estimator on ``./input/adult.test``."""

    def __init__(self, model):
        super(Eval, self).__init__(None, None, None, 128)
        self._model = model

    def accuracy(self, predict_results, labels):
        pass

    def read_test_data_set(self):
        pass

    def predict(self, test_data_batch):
        pass

    def eval(self):
        """Run one pass over the test set, print and return the metrics."""
        data_input = DataInput(common.CSV_COLUMNS,
                               common.LABEL_COLUMN_NAME,
                               ['./input/adult.test'],
                               shuffle=False,
                               batch_size=128,
                               example_per_epoch_num=1,
                               # adult.test begins with one non-data line.
                               skip_rows=1)
        test_data_input_fn = data_input.read_data()
        results = self._model.evaluate(input_fn=test_data_input_fn,
                                       steps=None)
        for key in sorted(results):
            print("%s: %s" % (key, results[key]))
        return results
# coding:utf-8
"""common

Shared settings for the wide & deep demo: CSV column names, feature-column
definitions, output directories and model-type names.
"""
import tensorflow as tf

_fc = tf.feature_column

# ---- raw column names --------------------------------------------------
AGE = 'age'
WORK_CLASS = 'workclass'
EDUCATION = 'education'
EDUCATION_NUM = 'education_num'
MARITAL_STATUS = 'marital_status'
OCCUPATION = 'occupation'
RELATIONSHIP = 'relationship'
GENDER = 'gender'
CAPITAL_GAIN = 'capital_gain'
CAPITAL_LOSS = 'capital_loss'
HOURS_PER_WEEK = 'hours_per_week'
NATIVE_COUNTRY = 'native_country'
LABEL_COLUMN_NAME = "income_bracket"

# Order must match the field order inside the CSV files.
CSV_COLUMNS = [
    AGE, WORK_CLASS, "fnlwgt", EDUCATION, EDUCATION_NUM,
    MARITAL_STATUS, OCCUPATION, RELATIONSHIP, "race", GENDER,
    CAPITAL_GAIN, CAPITAL_LOSS, HOURS_PER_WEEK, NATIVE_COUNTRY,
    LABEL_COLUMN_NAME,
]

# ---- categorical columns with a known vocabulary -----------------------
gender = _fc.categorical_column_with_vocabulary_list(GENDER, ['Female', 'Male'])
education = _fc.categorical_column_with_vocabulary_list(EDUCATION, [
    "Bachelors", "HS-grad", "11th", "Masters", "9th",
    "Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th",
    "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th",
    "Preschool", "12th",
])
marital_status = _fc.categorical_column_with_vocabulary_list(MARITAL_STATUS, [
    "Married-civ-spouse", "Divorced", "Married-spouse-absent",
    "Never-married", "Separated", "Married-AF-spouse", "Widowed",
])
relationship = _fc.categorical_column_with_vocabulary_list(RELATIONSHIP, [
    "Husband", "Not-in-family", "Wife", "Own-child", "Unmarried",
    "Other-relative",
])
workclass = _fc.categorical_column_with_vocabulary_list(WORK_CLASS, [
    "Self-emp-not-inc", "Private", "State-gov", "Federal-gov",
    "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked",
])

# ---- categorical columns hashed into 1000 buckets ----------------------
occupation = _fc.categorical_column_with_hash_bucket(OCCUPATION, hash_bucket_size=1000)
native_country = _fc.categorical_column_with_hash_bucket(NATIVE_COUNTRY, hash_bucket_size=1000)

# ---- numeric columns ---------------------------------------------------
age = _fc.numeric_column(AGE)
education_num = _fc.numeric_column(EDUCATION_NUM)
capital_gain = _fc.numeric_column(CAPITAL_GAIN)
capital_loss = _fc.numeric_column(CAPITAL_LOSS)
hours_per_week = _fc.numeric_column(HOURS_PER_WEEK)

age_buckets = _fc.bucketized_column(
    age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

# ---- column groups consumed by the models ------------------------------
base_columns = [
    gender, education, marital_status, relationship, workclass, occupation,
    native_country, age_buckets,
]

crossed_columns = [
    _fc.crossed_column(keys, hash_bucket_size=1000)
    for keys in (
        [EDUCATION, OCCUPATION],
        [age_buckets, EDUCATION, OCCUPATION],
        [NATIVE_COUNTRY, OCCUPATION],
    )
]

deep_columns = (
    [_fc.indicator_column(col) for col in (workclass, education, gender, relationship)]
    + [_fc.embedding_column(col, dimension=8) for col in (native_country, occupation)]
    + [age, education_num, capital_gain, capital_loss, hours_per_week]
)

# ---- output directories and model types --------------------------------
MODEL_DIR = './output'
WIDE_MODEL_DIR = '%s/wide' % MODEL_DIR
DEEP_MODEL_DIR = '%s/deep' % MODEL_DIR
WIDE_N_DEEP_DIR = '%s/wide_deep' % MODEL_DIR

WIDE_MODEL_TYPE = 'wide'
DEEP_MODEL_TYPE = 'deep'
WIDE_N_DEEP_MODEL_TYPE = 'wide_n_deep'
# coding:utf-8
"""hook"""
import tensorflow as tf
import time
import datetime


class LoggerHook(tf.train.SessionRunHook):
    """Every ``log_frequency`` steps, print the losses and the throughput.

    Args:
        batch_size: examples per training step, used only for the
            examples/sec figure (default 128, the previously hard-coded value).
        log_frequency: print once every this many steps (default 10).

    Raises:
        ValueError: if ``log_frequency`` is smaller than 1.
    """

    def __init__(self, batch_size=128, log_frequency=10):
        super(LoggerHook, self).__init__()
        if log_frequency < 1:
            raise ValueError('log_frequency must be >= 1, got %r' % (log_frequency,))
        self._batch_size = batch_size
        self._log_frequency = log_frequency
        self._step = -1
        self._last_log_step = -1
        self._start_time = time.time()
        self._loss_tensors = []

    def begin(self):
        # Reset the counters and look up the loss tensors once, instead of
        # querying the collection on every step.
        self._step = -1
        self._last_log_step = -1
        self._start_time = time.time()
        self._loss_tensors = tf.get_collection(tf.GraphKeys.LOSSES)

    def _should_log(self):
        """True on steps that are printed."""
        return self._step % self._log_frequency == 0

    def before_run(self, run_context):
        self._step += 1
        # Fetch the losses only on logged steps; the fetched values are handed
        # to after_run as run_values.results.
        if self._should_log():
            return tf.train.SessionRunArgs(self._loss_tensors)
        return None

    def after_run(self, run_context, run_values):
        if not self._should_log():
            return
        current_time = time.time()
        duration = current_time - self._start_time
        # The first log covers one step, later ones log_frequency steps.
        steps = self._step - self._last_log_step
        self._start_time = current_time
        self._last_log_step = self._step

        for i, loss in enumerate(run_values.results):
            print(i, ':', loss)
        print('-' * 40)

        if duration > 0:
            examples_per_sec = steps * self._batch_size / duration
        else:
            examples_per_sec = float('inf')
        sec_per_batch = float(duration / steps)
        format_str = ('%s: step %d, loss = todo (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print(format_str % (datetime.datetime.now(), self._step,
                            examples_per_sec, sec_per_batch))
# coding:utf-8
"""main

Train the selected model type and then evaluate it on the test set. The
model type defaults to wide & deep; an optional first command-line argument
('wide', 'deep' or 'wide_n_deep') selects another one.
"""
from wide_and_deep import DataInput
from wide_and_deep import Train
from wide_and_deep import Eval
import common
import logging
import sys
import tensorflow as tf

# model type -> checkpoint directory for that model
_MODEL_DIRS = {
    common.WIDE_MODEL_TYPE: common.WIDE_MODEL_DIR,
    common.DEEP_MODEL_TYPE: common.DEEP_MODEL_DIR,
    common.WIDE_N_DEEP_MODEL_TYPE: common.WIDE_N_DEEP_DIR,
}


def main(model_type=common.WIDE_N_DEEP_MODEL_TYPE):
    """Train and evaluate the model of ``model_type``.

    Raises:
        RuntimeError: if ``model_type`` is not one of the known types.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    if model_type not in _MODEL_DIRS:
        raise RuntimeError("error model type")
    model = Train(model_type, _MODEL_DIRS[model_type]).train()
    Eval(model=model).eval()


if __name__ == '__main__':
    main(sys.argv[1] if len(sys.argv) > 1 else common.WIDE_N_DEEP_MODEL_TYPE)
原创粉丝点击