Skip to content
gqlxj1987's Blog
Go back

Improve Spell Correct

Edit page

原repo

原先的代码几个点说明:

def remove_conflicting_examples(data):
	correct_words, incorrect_words = np.array(data)[:,0], np.array(data)[:,1]
	correct_vocab, incorrect_vocab = list(set(correct_words)), list(set(incorrect_words))
   
	try:
		i=0
		while(i<len(data)):
			
			if data[i][0] in incorrect_vocab or data[i][1] in correct_vocab:
				del data[i]
				i-=1
			i+=1
	except:
		pass
	return data

去掉一些相互冲突的example,梳理相关的数据

def left_pad(list_):
    max_seq_len = 20
    ans = np.zeros((max_seq_len), dtype=int)
    ans[:len(list_)] = np.array(list_[:max_seq_len] )
    return ans

取最大的长度20,其他的进行padding补全

def string2indexes(word):
    return left_pad([ord(char) - 96 for char in list(word)])

根据ascii码序表来进行转换,这里的96是因为a的ascii值为97

xseq_len = train[0].shape[-1]
yseq_len = train[0].shape[-1]
xvocab_size = 26+1
yvocab_size = xvocab_size
embed_size = 300

lstm_cell_size = embed_size
lstm_layers = 3
embed_size = embed_size 

epoch_n = 5
lr = 0.01  

global_iteration=0

这里的一些参数,很明显了,xseq_len和yseq_len基本都是max_seq_len, xvocab_size就是26个字母+1这样子

with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE) as scope:
#     scope.reuse_variables()
    decode_outputs, _ = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(enc_ip,dec_ip,                                                           										lstm,xvocab_size,yvocab_size,
                                                           embedding_size=embed_size,feed_previous=False)
    scope.reuse_variables() #sharing parameter b/w train and test decoders
    decode_outputs_test, _ = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(enc_ip, dec_ip,
                                                                              lstm, xvocab_size,
                                                                              yvocab_size,embed_size,
                                                                              feed_previous =True)

这里,要注意的是关于variable_scope部分中,lstm的resuse使用标志

原始代码里predict出来的,

X:  [16 19 25  3 15 12 15  7  9 19 20  0  0  0  0  0  0  0  0  0]
Prediction:  [16 16 16 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18 18]
Y:  [16 19 25  3  8 15 12 15  7  9 19 20  0  0  0  0  0  0  0  0]

prediction出来的内容,基本很大程度上对不上。。。

改进一


Edit page
Share this post on:

Previous Post
Neural Machine Translation
Next Post
Who will steal Android from Google