DevilKing's blog

冷灯看剑,剑上几分功名?炉香无需计苍生,纵一穿烟逝,万丈云埋,孤阳还照古陵

0%

Improve Spell Correct

原repo

原先的代码几个点说明:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def remove_conflicting_examples(data):
correct_words, incorrect_words = np.array(data)[:,0], np.array(data)[:,1]
correct_vocab, incorrect_vocab = list(set(correct_words)), list(set(incorrect_words))

try:
i=0
while(i<len(data)):

if data[i][0] in incorrect_vocab or data[i][1] in correct_vocab:
del data[i]
i-=1
i+=1
except:
pass
return data

去掉一些相互冲突的example,梳理相关的数据

1
2
3
4
5
def left_pad(list_):
max_seq_len = 20
ans = np.zeros((max_seq_len), dtype=int)
ans[:len(list_)] = np.array(list_[:max_seq_len] )
return ans

取最大的长度20,其他的进行padding补全

1
2
def string2indexes(word):
return left_pad([ord(char) - 96 for char in list(word)])

根据ascii码序表来进行转换,这里的96是因为a的ascii值为97

1
2
3
4
5
6
7
8
9
10
11
12
13
14
xseq_len = train[0].shape[-1]
yseq_len = train[0].shape[-1]
xvocab_size = 26+1
yvocab_size = xvocab_size
embed_size = 300

lstm_cell_size = embed_size
lstm_layers = 3
embed_size = embed_size

epoch_n = 5
lr = 0.01

global_iteration=0

这里的一些参数,很明显了,xseq_len和yseq_len基本都是max_seq_len, xvocab_size就是26个字母+1这样子

1
2
3
4
5
6
7
8
9
with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE) as scope:
# scope.reuse_variables()
decode_outputs, _ = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(enc_ip,dec_ip, lstm,xvocab_size,yvocab_size,
embedding_size=embed_size,feed_previous=False)
scope.reuse_variables() #sharing parameter b/w train and test decoders
decode_outputs_test, _ = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(enc_ip, dec_ip,
lstm, xvocab_size,
yvocab_size,embed_size,
feed_previous =True)

这里,要注意的是关于variable_scope部分中,lstm的resuse使用标志

原始代码里predict出来的,

1
2
3
X:  [16 19 25  3 15 12 15  7  9 19 20  0  0  0  0  0  0  0  0  0]
Prediction: [16 16 16 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18]
Y: [16 19 25 3 8 15 12 15 7 9 19 20 0 0 0 0 0 0 0 0]

prediction出来的内容,基本很大程度上对不上。。。

改进一

  • embedding_rnn_seq2seq改成embedding_attention_seq2seq
  • 采用random的方式,对train的数据集进行训练
  • 从int32变成float32?
  • 关于prediction的展示部分,如何显示出来?