参照 char-rnn-tensorflow,使用RNN的字符模型,学习并生成古诗。
tensorflow
python train.py在使用GPU的情况下,两个小时内即可完成训练
python sample.py
rnn神经网络会生成一首全新的古诗。例如: ”帝以诚求备,堪留百勇杯。教官日与失,共恨五毛宣。鸡唇春疏叶,空衣滴舞衣。丑夫归晚里,此地几何人。”
python sample.py --prime <这里输入指定汉字>
rnn神经网络会利用输入的汉字生成一首藏头诗。例如: python sample.py --prime 如花似月
会得到 “如尔残回号,花枝误晚声。似君星度上,月满二秋寒。
from __future__ import print_function
import numpy as np
import tensorflow as tf
import argparse
import time
import os
from six.moves import cPickle
from utils import TextLoader
from model import Model
from six import text_type
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--save_dir', type=str, default='save',
help='model directory to store checkpointed models')
parser.add_argument('--prime', type=str, default='',
help=u'输入指定文字生成藏头诗')
parser.add_argument('--sample', type=int, default=1,
help='0 to use max at each timestep, 1 to sample at each timestep')
args = parser.parse_args()
sample(args)
def sample(args):
with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
saved_args = cPickle.load(f)
with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
chars, vocab = cPickle.load(f)
model = Model(saved_args, True)
with tf.Session() as sess:
tf.initialize_all_variables().run()
saver = tf.train.Saver(tf.all_variables())
ckpt = tf.train.get_checkpoint_state(args.save_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
print(model.sample(sess, chars, vocab, args.prime.decode('utf-8',errors='ignore'),args.sample))
if __name__ == '__main__':
main()