代码拉取完成,页面将自动刷新
import platform
import os
import time
from threading import Thread
from rich.text import Text
from rich.live import Live
from model.infer import ChatBot
from config import InferConfig
infer_config = InferConfig()
chat_bot = ChatBot(infer_config=infer_config)
clear_cmd = 'cls' if platform.system().lower() == 'windows' else 'clear'
welcome_txt = '欢迎使用ChatBot,输入`exit`退出,输入`cls`清屏。\n'
print(welcome_txt)
def build_prompt(history: list[list[str]]) -> str:
prompt = welcome_txt
for query, response in history:
prompt += '\n\033[0;33;40m用户:\033[0m{}'.format(query)
prompt += '\n\033[0;32;40mChatBot:\033[0m\n{}\n'.format(response)
return prompt
STOP_CIRCLE: bool=False
def circle_print(total_time: int=60) -> None:
global STOP_CIRCLE
'''非stream chat打印忙碌状态
'''
list_circle = ["\\", "|", "/", "—"]
for i in range(total_time * 4):
time.sleep(0.25)
print("\r{}".format(list_circle[i % 4]), end="", flush=True)
if STOP_CIRCLE: break
print("\r", end='', flush=True)
def chat(stream: bool=True) -> None:
global STOP_CIRCLE
history = []
turn_count = 0
while True:
print('\r\033[0;33;40m用户:\033[0m', end='', flush=True)
input_txt = input()
if len(input_txt) == 0:
print('请输入问题')
continue
# 退出
if input_txt.lower() == 'exit':
break
# 清屏
if input_txt.lower() == 'cls':
history = []
turn_count = 0
os.system(clear_cmd)
print(welcome_txt)
continue
if not stream:
STOP_CIRCLE = False
thread = Thread(target=circle_print)
thread.start()
outs = chat_bot.chat(input_txt)
STOP_CIRCLE = True
thread.join()
print("\r\033[0;32;40mChatBot:\033[0m\n{}\n\n".format(outs), end='')
continue
history.append([input_txt, ''])
stream_txt = []
streamer = chat_bot.stream_chat(input_txt)
rich_text = Text()
print("\r\033[0;32;40mChatBot:\033[0m\n", end='')
with Live(rich_text, refresh_per_second=15) as live:
for i, word in enumerate(streamer):
rich_text.append(word)
stream_txt.append(word)
stream_txt = ''.join(stream_txt)
if len(stream_txt) == 0:
stream_txt = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"
history[turn_count][1] = stream_txt
os.system(clear_cmd)
print(build_prompt(history), flush=True)
turn_count += 1
if __name__ == '__main__':
chat(stream=True)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。