博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
引入Redis|tensorflow实现 聊天AI--PigPig养成记(3)
阅读量:6331 次
发布时间:2019-06-22

本文共 4681 字,大约阅读时间需要 15 分钟。

引入Redis

在集成Netty之后,为了提高效率,我打算将消息存储在Redis缓存系统中,本节将介绍Redis在项目中的引入,以及前端界面的开发。

引入Redis后,。

想要直接得到训练了13000步的聊天机器人可以直接下载中

在这里插入图片描述
这三个文件,以及词汇表文件
在这里插入图片描述
然后直接运行连接中的py脚本进行测试即可。
在这里插入图片描述

最终实现效果如下:

在这里插入图片描述

在Netty中引入Redis

import java.io.BufferedReader;import java.io.BufferedWriter;import java.io.File;import java.io.FileNotFoundException;import java.io.FileReader;import java.io.FileWriter;import java.io.IOException;import java.time.LocalDateTime;import io.netty.channel.ChannelHandlerContext;import io.netty.channel.SimpleChannelInboundHandler;import io.netty.channel.group.ChannelGroup;import io.netty.channel.group.DefaultChannelGroup;import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;import io.netty.util.concurrent.GlobalEventExecutor;import redis.clients.jedis.Jedis;public class ChatHandler     extends SimpleChannelInboundHandler
{ private static ChannelGroup clients= new DefaultChannelGroup(GlobalEventExecutor.INSTANCE); @Override protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception { System.out.println("channelRead0..."); //连接redis Jedis jedis=new Jedis("localhost"); System.out.println("连接成功..."); System.out.println("服务正在运行:"+jedis.ping()); //得到用户输入的消息,需要写入文件/缓存中,让AI进行读取 String content=msg.text(); if(content==null||content=="") { System.out.println("content 为null"); return ; } System.out.println("接收到的消息:"+content); //写入缓存中 jedis.set("user_say", content+":user"); Thread.sleep(1000); //读取AI返回的内容 String AIsay=null; while(AIsay=="no"||AIsay==null) { //从缓存中读取AI回复的内容 AIsay=jedis.get("ai_say"); String [] arr=AIsay.split(":"); AIsay=arr[0]; } //读取后马上向缓存中写入 jedis.set("ai_say", "no"); //没有说,或者还没说 if(AIsay==null||AIsay=="") { System.out.println("AIsay==null||AIsay==\"\""); return; } System.out.println("AI说:"+AIsay); clients.writeAndFlush( new TextWebSocketFrame( "AI_PigPig在"+LocalDateTime.now() +"说:"+AIsay)); } @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { System.out.println("add..."); clients.add(ctx.channel()); } @Override public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { System.out.println("客户端断开,channel对应的长id为:" +ctx.channel().id().asLongText()); System.out.println("客户端断开,channel对应的短id为:" +ctx.channel().id().asShortText()); } }

在Python中引入Redis

with tf.Session() as sess:#打开作为一次会话    # 恢复前一次训练    ckpt = tf.train.get_checkpoint_state('.')#从检查点文件中返回一个状态(ckpt)    #如果ckpt存在,输出模型路径    if ckpt != None:        print(ckpt.model_checkpoint_path)        model.saver.restore(sess, ckpt.model_checkpoint_path)#储存模型参数    else:        print("没找到模型")    r.set('user_say','no')    #测试该模型的能力    while True:        line='no'        #从缓存中进行读取        while line=='no':            line=r.get('user_say').decode()            #print(line)        list1=line.split(':')        if len(list1)==1:            input_string='no'        else:            input_string=list1[0]            r.set('user_say','no')                                      # 退出        if input_string == 'quit':           exit()        if input_string != 'no':            input_string_vec = []#输入字符串向量化            for words in input_string.strip():                input_string_vec.append(vocab_en.get(words, UNK_ID))#get()函数:如果words在词表中,返回索引号;否则,返回UNK_ID                bucket_id = min([b for b in range(len(buckets)) if buckets[b][0] > len(input_string_vec)])#保留最小的大于输入的bucket的id                encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(input_string_vec, [])]}, bucket_id)                #get_batch(A,B):两个参数,A为大小为len(buckets)的元组,返回了指定bucket_id的encoder_inputs,decoder_inputs,target_weights                _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)                #得到其输出                outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]#求得最大的预测范围列表                if EOS_ID in outputs:#如果EOS_ID在输出内部,则输出列表为[,,,,:End]                    outputs = outputs[:outputs.index(EOS_ID)]                             response = "".join([tf.compat.as_str(vocab_de[output]) for output in outputs])#转为解码词汇分别添加到回复中                print('AI-PigPig > ' + response)#输出回复                #向缓存中进行写入                r.set('ai_say',response+':AI')

下一节将讲述通信规则的制定,以规范应用程序。

转载地址:http://iuboa.baihongyu.com/

你可能感兴趣的文章
log4j Test
查看>>
HDU 1255 覆盖的面积(矩形面积交)
查看>>
Combinations
查看>>
SQL数据库无法附加,提示 MDF" 已压缩,但未驻留在只读数据库或文件组中。必须将此文件解压缩。...
查看>>
第二十一章流 3用cin输入
查看>>
在workflow中,无法为实例 ID“...”传递接口类型“...”上的事件“...” 问题的解决方法。...
查看>>
获取SQL数据库中的数据库名、所有表名、所有字段名、列描述
查看>>
Orchard 视频资料
查看>>
简述:预处理、编译、汇编、链接
查看>>
调试网页PAIP HTML的调试与分析工具
查看>>
路径工程OpenCV依赖文件路径自动添加方法
查看>>
玩转SSRS第七篇---报表订阅
查看>>
WinCE API
查看>>
POJ 3280 Cheapest Palindrome(DP 回文变形)
查看>>
oracle修改内存使用和性能调节,SGA
查看>>
SQL语言基础
查看>>
对事件处理的错误使用
查看>>
最大熵模型(二)朗格朗日函数
查看>>
深入了解setInterval方法
查看>>
html img Src base64 图片显示
查看>>