Feeding your own data to seq2seq at prediction time (continued from the previous post)

def decode():
  # Rewritten decode() from the TensorFlow seq2seq tutorial's translate.py.
  # It relies on that file's imports (tf, np, sys, logging, data_utils) and
  # globals (_buckets, create_model); only the vocabulary loading is replaced
  # so that our own chinese_word2id.txt / english_word2id.txt files are used.
  with tf.Session() as sess:
    # Create model and load parameters.
    model = create_model(sess, True)
    model.batch_size = 1  # We decode one sentence at a time.

    # Load vocabularies. The tutorial's original loading code is kept here
    # for reference:
    # en_vocab_path = os.path.join(FLAGS.data_dir,
    #                              "vocab%d.from" % FLAGS.from_vocab_size)
    # fr_vocab_path = os.path.join(FLAGS.data_dir,
    #                              "vocab%d.to" % FLAGS.to_vocab_size)
    # en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
    # _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)

    # Source vocabulary: Chinese word -> id.
    chinese_word2id = {}
    vocab_file = open('chinese_word2id.txt', 'r')
    while True:
      line = vocab_file.readline()
      if not line:  # readline() returns '' at EOF, never None
        break
      words = line.split(' ')
      chinese_word2id[words[0]] = int(words[1].strip('\n'))
    vocab_file.close()

    # Target vocabulary: id -> English word (reverse mapping for the output).
    english_id2word = {}
    vocab_file = open('english_word2id.txt', 'r')
    while True:
      line = vocab_file.readline()
      if not line:
        break
      words = line.split(' ')
      english_id2word[int(words[1].strip('\n'))] = words[0]
    vocab_file.close()

    # Decode from standard input.
    sys.stdout.write("> ")
    sys.stdout.flush()
    sentence = sys.stdin.readline()
    while sentence:
      # Get token-ids for the input sentence.
      token_ids = data_utils.sentence_to_token_ids(
          tf.compat.as_bytes(sentence), chinese_word2id)
      # Which bucket does it belong to?
      bucket_id = len(_buckets) - 1
      for i, bucket in enumerate(_buckets):
        if bucket[0] >= len(token_ids):
          bucket_id = i
          break
      else:
        logging.warning("Sentence truncated: %s", sentence)

      # Get a 1-element batch to feed the sentence to the model.
      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
          {bucket_id: [(token_ids, [])]}, bucket_id)
      # Get output logits for the sentence.
      _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                       target_weights, bucket_id, True)
      # This is a greedy decoder - outputs are just argmaxes of output_logits.
      outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
      # If there is an EOS symbol in outputs, cut them at that point.
      if data_utils.EOS_ID in outputs:
        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
      # Print out the English sentence corresponding to outputs.
      print(" ".join([tf.compat.as_str(english_id2word[output]) for output in outputs]))
      print("> ", end="")
      sys.stdout.flush()
      sentence = sys.stdin.readline()
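
The two vocabulary files read above are plain text with one "word id" pair per line (their contents are listed below). As a point of reference, here is a minimal sketch, not part of the original post, of how such files could be written out from the word-to-id dictionaries built during data preparation; the function name and the dictionary contents are only illustrative.

# Illustrative sketch: dump a word -> id dictionary to the "word id"
# per-line format that the rewritten decode() reads.
def write_vocab(path, word2id):
    with open(path, 'w') as f:
        for word, idx in word2id.items():
            f.write('%s %d\n' % (word, idx))

# Hypothetical usage, with the tutorial's special symbols occupying ids 0-3:
# chinese_word2id = {'_PAD': 0, '_GO': 1, '_EOS': 2, '_UNK': 3, '小时': 8, ...}
# write_vocab('chinese_word2id.txt', chinese_word2id)
# write_vocab('english_word2id.txt', english_word2id)  # same format, English side

Note that ids 0-3 are reserved for the special symbols _PAD, _GO, _EOS and _UNK, which is also what the listings below show.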

What the data looks like: chinese_word2id.txt

小时 8
毕业了 16
命运 57
纽约 24
身边 69
时机 75
70岁 39
在 40
本来 44
不要 48
后面 21
然后 66
等待 74
非常 78
准时 35
比 4
_GO 1
早 26
25岁 9
嘲笑 12
放轻松 13
却 19
川普 25
奥巴马 29
安排 47
CEO 31
_PAD 0
三个 43
_UNK 3
看似 61
嫉妒 7
可能 63
也 67
奔跑 80
已经 72
是 79
没有 82
所以 59
活到 5
你的 6
领先 34
另一个人 11
有人 14
落后 46
正在 36
好的 20
结婚 56
50岁 23
为 33
工作 30
变慢 32
每个人 41
正确的 42
或 45
找到 51
当上 52
这 53
五年 58
退休 62
行动 64
关于 70
让 71
开始 73
55岁 18
时区 77
才 86
_EOS 2
前面 10
的 65
自己的 15
90岁 17
加州 22
世界上 38
以 28
他们 37
生命 83
速度 27
你 54
依然 55
走 84
单身 49
他们的 60
去世 68
等了 76
然而 81
但 50
22岁 85
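
With this table, decode() turns a whitespace-segmented input sentence into ids via data_utils.sentence_to_token_ids. Below is a standalone sketch of that lookup, added for illustration under the assumption that the input is already segmented with spaces (as the training data is); unknown words fall back to _UNK (id 3).

# Illustrative sketch of the source-side lookup done inside decode().
UNK_ID = 3

def sentence_to_ids(sentence, word2id):
    return [word2id.get(w, UNK_ID) for w in sentence.strip().split(' ')]

# With the vocabulary listed above:
# sentence_to_ids('纽约 比 加州 早 三个 小时', chinese_word2id)
# -> [24, 4, 22, 26, 43, 8]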

What the data looks like: english_word2id.txt

years 5
go 6
still 7
before 8
25 9
envy 10
Someone 11
based 12
securing 55
3 14
to 15
York 16
might 17
them 18
someone 19
around 20
very 21
Don’t 22
While 23
not 24
world 25
ON 26
Trump 27
RELAX 28
55 29
married 30
50 31
LATE 32
You’re 33
works 34
_GO 1
set 4
Destiny 36
Zone 37
some 38
are 39
New 40
Life 41
for 42
CEO 44
waiting 43
does 46
got 47
be 48
_UNK 3
hours 49
job 50
slow 51
_PAD 0
graduated 52
on 53
about 54
22 13
ahead 56
People 57
Absolutely 58
starts 85
running 105
became 61
TIME 62
or 63
waited 64
own 65
retires 66
another 67
your 68
But 69
their 70
much 71
California 72
Obama 73
you. 74
lived 75
behind 45
but 76
else 77
So 78
Time 79
90 80
this 81
up 82
EARLY 83
while 84
of 59
and 86
_EOS 2
is 87
it 88
single 89
good 90
right 91
at 92
in 93
seem 94
You 95
make 96
TIME, 97
TIME. 98
RACE 99
5 100
They 101
you 102
mock 103
act 104
moment 60
70 106
died 107
a 108
ZONE 109
age 110
everyone 35
the 111
yours 112

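On the output side, decode() cuts the greedy id sequence at _EOS (id 2) and maps the remaining ids back to words through english_id2word. Here is a small sketch of that final step; the id sequence in the usage comment is made up purely to show the mechanics, not an actual model output.

# Illustrative sketch of the target-side lookup done at the end of decode().
EOS_ID = 2

def ids_to_sentence(outputs, id2word):
    if EOS_ID in outputs:
        outputs = outputs[:outputs.index(EOS_ID)]
    return ' '.join(id2word[i] for i in outputs)

# With the vocabulary listed above (made-up id sequence):
# ids_to_sentence([40, 16, 87, 56, 2, 0, 0], english_id2word)
# -> 'New York is ahead'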