tensor2tensor


Using tensor2tensor to train the Attention Is All You Need (Transformer) model

Hyperparameter definition:

create_experiment calls train_utils.create_hparams, which calls problem_hparams.problem_hparams and then transformer; transformer calls common_hparams to obtain the base model hyperparameters and adds its own parameters on top.
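To make the layering concrete, here is a minimal sketch of how a base hyperparameter set from common_hparams can be extended inside transformer. This is illustrative, not the actual T2T source: the field names and values are assumptions, and it uses TF 1.x's tf.contrib.training.HParams.

import tensorflow as tf

def common_hparams():
  # Base hyperparameters shared by all models (values are illustrative).
  return tf.contrib.training.HParams(
      batch_size=4096,
      hidden_size=512,
      num_hidden_layers=6,
      learning_rate=0.1)

def transformer_base():
  # Start from the common set, then add or override Transformer-specific fields.
  hparams = common_hparams()
  hparams.add_hparam("num_heads", 8)
  hparams.add_hparam("filter_size", 2048)
  hparams.learning_rate = 0.2  # override a base value
  return hparams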

Model invocation:

trainer_utils._cond_on_index calls fn(cur_idx), where fn is trainer_utils.model_fn.nth_model; nth_model calls t2t_model._with_timing.fn_with_timing, which invokes the wrapped fn, and that fn calls transformer.model_fn_body to obtain the logits and loss.
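The index dispatch can be pictured with the following simplified sketch (a reconstruction, not the exact T2T source): a chain of tf.cond ops picks the model function whose index matches the runtime problem index.

import tensorflow as tf

def _cond_on_index(fn, index_tensor, cur_idx, max_idx):
  # Build nested tf.cond branches: branch cur_idx runs fn(cur_idx) when
  # index_tensor equals cur_idx, otherwise fall through to the next index.
  if cur_idx == max_idx:
    return fn(cur_idx)
  return tf.cond(
      tf.equal(index_tensor, cur_idx),
      lambda: fn(cur_idx),
      lambda: _cond_on_index(fn, index_tensor, cur_idx + 1, max_idx))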

Training data generation:

The entry script is tensor2tensor/bin/t2t-datagen.py.

A data generation function is defined for each task and dataset; all of them are registered in the dictionary _SUPPORTED_PROBLEM_GENERATORS.

For example, to train the en-fr Attention Is All You Need model, the corresponding data-processing functions are defined as:

"wmt_enfr_tokens_32k": (
    lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
    lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)),
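The first lambda builds the training-set generator and the second the dev-set generator. t2t-datagen looks the pair up by problem name and writes each generator's output to record files, roughly like the sketch below (the generate_files arguments are an assumption and varied across T2T versions):

training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS["wmt_enfr_tokens_32k"]
generator_utils.generate_files(
    training_gen(), "wmt_enfr_tokens_32k-train", FLAGS.data_dir)
generator_utils.generate_files(
    dev_gen(), "wmt_enfr_tokens_32k-dev", FLAGS.data_dir)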

enfr_wordpiece_token_generator is defined in wmt.py; its code is as follows:
def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size):
  """Instance of token generator for the WMT en->fr task."""
  symbolizer_vocab = generator_utils.get_or_generate_vocab(
      tmp_dir, "tokens.vocab.%d" % vocab_size, vocab_size)
  datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
  tag = "train" if train else "dev"
  data_path = _compile_data(tmp_dir, datasets, "wmt_enfr_tok_%s" % tag)
  return token_generator(data_path + ".lang1", data_path + ".lang2",
                         symbolizer_vocab, 1)
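_compile_data concatenates the raw datasets into two aligned files, data_path + ".lang1" (source side) and data_path + ".lang2" (target side), where line i of each file holds the i-th sentence pair. A toy illustration of that layout (hypothetical paths and sentences):

# Write a one-pair toy corpus in the same aligned layout.
with open("/tmp/demo.lang1", "w") as f1, open("/tmp/demo.lang2", "w") as f2:
  f1.write("the cat sat on the mat\n")
  f2.write("le chat est assis sur le tapis\n")

# Read the two files in lockstep, as token_generator does below.
with open("/tmp/demo.lang1") as f1, open("/tmp/demo.lang2") as f2:
  for src, tgt in zip(f1, f2):
    print(src.strip(), "->", tgt.strip())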

First, generator_utils.get_or_generate_vocab builds (or loads) the wordpiece vocabulary; _ENFR_TRAIN_DATASETS lists the input data sources; token_generator then reads the training data and converts it to integer token ids:

def token_generator(source_path, target_path, token_vocab, eos=None):
  """Reads aligned source/target files and yields dicts of token-id lists."""
  if eos is None:
    eos = token_vocab['</S>']  # default to the vocabulary's end-of-sentence id
  eos_list = [eos]
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      while source and target:
        # Convert each sentence to token ids and append EOS.
        source_ints = word_num(source.strip('\n'), token_vocab) + eos_list
        target_ints = word_num(target.strip('\n'), token_vocab) + eos_list
        slen = max(len(source_ints), len(target_ints))
        if 5 <= slen <= 20:  # keep only pairs whose longer side is in [5, 20]
          yield {"inputs": source_ints, "targets": target_ints}
        source, target = source_file.readline(), target_file.readline()

Note: the original snippet assigned eos = token_vocab['</S>'] unconditionally before checking if eos is None, making the check dead code; the version above lets the eos parameter act as a real override.
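word_num is a helper used above but not shown in this post. Here is a minimal sketch of what it presumably does, assuming token_vocab maps token strings to integer ids (the helper body and unk_id are assumptions), plus an example of what the generator yields:

def word_num(sentence, token_vocab, unk_id=2):
  # Hypothetical helper: map whitespace-separated tokens to their ids,
  # falling back to unk_id for out-of-vocabulary tokens.
  return [token_vocab.get(w, unk_id) for w in sentence.split()]

# One yielded item looks like this (ids are illustrative); both lists end
# with the EOS id:
# {"inputs": [17, 4, 92, 30, 1], "targets": [33, 8, 54, 6, 40, 1]}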