基于 ModelArts的古诗词自动生成

举报
步行植物 发表于 2021/03/17 20:34:04 2021/03/17
【摘要】 在中国文化传统中,诗有着极为独特而崇高的地位。诗歌开拓了人类的精神世界,给人们带来了无限的美感。本文将介绍如何使用一站式AI开发平台,自动生成属于你的藏头诗。 环境准备 基于华为云一站式AI开发平台ModelArtsModelArts: https://www.huaweicloud.cn/product/modelarts.htmlAI开发平台ModelArts是面向开发者的一站式AI开...

在中国文化传统中,诗有着极为独特而崇高的地位。诗歌开拓了人类的精神世界,给人们带来了无限的美感。本文将介绍如何使用一站式AI开发平台,自动生成属于你的藏头诗。

环境准备

基于华为云一站式AI开发平台ModelArts


ModelArts: https://www.huaweicloud.cn/product/modelarts.html
AI开发平台ModelArts是面向开发者的一站式AI开发平台,为机器学习与深度学习提供海量数据预处理及半自动化标注、大规模分布式Training、自动化模型生成,及端-边-云模型按需部署能力,帮助用户快速创建和部署模型,管理全周期AI工作流。

对象存储服务OBS


OBS: https://www.huaweicloud.cn/product/obs.html
对象存储服务(Object Storage Service,OBS)提供海量、安全、高可靠、低成本的数据存储能力,可供用户存储任意类型和大小的数据。适合企业备份/归档、视频点播、视频监控等多种数据存储场景。

模型和素材准备

文件已经上传至obs共享桶,在notebook中使用代码可以直接读取。
源文件地址:https://github.com/jinfagang/tensorflow_poems

实际操作

首先,在ModelArts中创建开发环境:在“开发环境”选项下选择notebook。

创建一个notebook,打开JupyterLab,选择tensorflow环境,开始体验。

使用华为云提供的接口,使用代码将公有桶poems中的文件拷贝到本地work路径下:

import moxing as mox
import os
obspath = 'obs://poems/poems/'  #目标文件夹
localpath = os.path.join(os.environ['HOME'],'work/test/')  #本地文件夹
mox.file.copy_parallel(obspath ,localpath) #批量拷贝obs://poems


安装指定版本numpy:

!pip install --upgrade pip
!pip install numpy==1.16.0

导入包,定义train函数:

import tensorflow as tf
from poems.model import rnn_model
from poems.poems import process_poems, generate_batch
tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size.')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate.')
tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path.')
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
tf.app.flags.DEFINE_integer('epochs', 50, 'train how many epochs.')
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('f', '', 'kernel')

def run_training():
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)
    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)
    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
        vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate)
    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
        start_epoch = 0
        checkpoint = ("./model/poems-42")
        if checkpoint:
            saver.restore(sess, "./model/poems-42")
            print("## restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        try:
            n_chunk = len(poems_vector) // FLAGS.batch_size
            for epoch in range(start_epoch, FLAGS.epochs):
                n = 0
                for batch in range(n_chunk):
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
                    n += 1
                    if batch%50==0:
                        print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                if epoch % 6 == 0:
                    saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
            print('## Last epoch were saved, next time will start from epoch {}.'.format(epoch))

开始训练(也可以跳过训练,直接调用模型42进行预测):

def main():
    run_training()
if __name__ == '__main__':
    main()

导入预测相关包并加载checkpoints:

import numpy as np
start_token = 'B'
end_token = 'E'
model_dir = './model/'
corpus_file = './data/poems.txt'
lr = 0.0002

def to_word(predict, vocabs):
predict = predict[0]       
predict /= np.sum(predict)
sample = np.random.choice(np.arange(len(predict)), p=predict)
if sample > len(vocabs):
    return vocabs[-1]
else:
    return vocabs[sample]
def gen_poem(begin_word):
tf.reset_default_graph()
batch_size = 1
print('## loading corpus from %s' % model_dir)
poems_vector, word_int_map, vocabularies = process_poems(corpus_file)
input_data = tf.placeholder(tf.int32, [batch_size, None])
end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=len(
    vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=lr)#,reuse=True
saver = tf.train.Saver(tf.global_variables())
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
    sess.run(init_op)
    saver.restore(sess, "./model/poems-48")
    x = np.array([list(map(word_int_map.get, start_token))])
    [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                     feed_dict={input_data: x})
    word = begin_word or to_word(predict, vocabularies)
    poem_ = ''
    i = 0
    while word != end_token:
        poem_ += word
        i += 1
        if i > 24:
            break
        x = np.array([[word_int_map[word]]])
        [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                         feed_dict={input_data: x, end_points['initial_state']: last_state})
        word = to_word(predict, vocabularies)
    return poem_
def pretty_print_poem(poem_):
poem_sentences = poem_.split('。')
for s in poem_sentences:
    if s != '' and len(s) > 10:
        print(s + '。')

调用模型生成诗歌

poem = gen_poem('人')
pretty_print_poem(poem_=poem)

至此,本次实现先告一段落,关于多个字的藏头诗生成还没进行探索,欢迎在评论区分享指导。
另外,有兴趣的小伙伴欢迎加入MDG中国矿业大学站,QQ群:781169338,共建 ModelArts 生态!

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。