代码拉取完成,页面将自动刷新
# -*- coding: utf-8 -*-
# file: main.py
# author: JinTian
# time: 11/03/2017 9:53 AM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import tensorflow as tf
from poems.model import rnn_model
from poems.poems import process_poems
import numpy as np
start_token = 'B'
end_token = 'E'
model_dir = './model/'
corpus_file = './data/poems.txt'
lr = 0.0002
def to_word(predict, vocabs):
predict = predict[0]
predict /= np.sum(predict)
sample = np.random.choice(np.arange(len(predict)), p=predict)
if sample > len(vocabs):
return vocabs[-1]
else:
return vocabs[sample]
def gen_poem(begin_word):
batch_size = 1
print('## loading corpus from %s' % model_dir)
poems_vector, word_int_map, vocabularies = process_poems(corpus_file)
input_data = tf.placeholder(tf.int32, [batch_size, None])
end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=len(
vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=lr)
saver = tf.train.Saver(tf.global_variables())
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
sess.run(init_op)
checkpoint = tf.train.latest_checkpoint(model_dir)
saver.restore(sess, checkpoint)
x = np.array([list(map(word_int_map.get, start_token))])
[predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
feed_dict={input_data: x})
word = begin_word or to_word(predict, vocabularies)
poem_ = ''
i = 0
while word != end_token:
poem_ += word
i += 1
if i > 24:
break
x = np.array([[word_int_map[word]]])
[predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
feed_dict={input_data: x, end_points['initial_state']: last_state})
word = to_word(predict, vocabularies)
return poem_
def pretty_print_poem(poem_):
poem_sentences = poem_.split('。')
for s in poem_sentences:
if s != '' and len(s) > 10:
print(s + '。')
if __name__ == '__main__':
begin_char = input('## please input the first character:')
poem = gen_poem(begin_char)
pretty_print_poem(poem_=poem)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。