加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
cli_demo.py 2.80 KB
一键复制 编辑 原始数据 按行查看 历史
Charent 提交于 2023-12-29 14:34 . change model load code
import platform
import os
import time
from threading import Thread
from rich.text import Text
from rich.live import Live
from model.infer import ChatBot
from config import InferConfig
# Build the inference configuration and the chat bot once, when the module loads.
infer_config = InferConfig()
chat_bot = ChatBot(infer_config=infer_config)

# Shell command that clears the terminal: `clear` everywhere except Windows.
clear_cmd = 'clear' if platform.system().lower() != 'windows' else 'cls'

# Banner printed on startup and again after every screen clear.
welcome_txt = '欢迎使用ChatBot,输入`exit`退出,输入`cls`清屏。\n'
print(welcome_txt)
def build_prompt(history: list[list[str]]) -> str:
    '''Render the welcome banner followed by every conversation turn.

    Each turn is rendered as a colored user label followed by the query, then
    a colored ChatBot label with the response on the next line. The colors
    come from ANSI escape sequences.

    Args:
        history: ``[query, response]`` pairs, oldest first.

    Returns:
        The full transcript string, ready to print after clearing the screen.
    '''
    # Collect the pieces and join once; growing a str with += inside the
    # loop is quadratic in the worst case.
    parts = [welcome_txt]
    for query, response in history:
        parts.append('\n\033[0;33;40m用户:\033[0m{}'.format(query))
        parts.append('\n\033[0;32;40mChatBot:\033[0m\n{}\n'.format(response))
    return ''.join(parts)
# Set to True by chat() to tell the spinner thread (circle_print) to stop.
STOP_CIRCLE: bool = False


def circle_print(total_time: int = 60) -> None:
    '''Draw a rotating busy indicator while a non-stream chat call runs.

    Draws frames until the module-level ``STOP_CIRCLE`` flag becomes True or
    ``total_time`` seconds have elapsed, whichever comes first. It then
    returns the cursor to the start of the line.

    Args:
        total_time: upper bound on how long the spinner runs, in seconds.
    '''
    # STOP_CIRCLE is only read here, so no ``global`` statement is needed.
    frames = ['\\', '|', '/', '—']
    frames_per_second = 4
    for i in range(total_time * frames_per_second):
        # Check the flag before drawing so no stale frame appears after the
        # caller has asked the spinner to stop.
        if STOP_CIRCLE:
            break
        print('\r{}'.format(frames[i % len(frames)]), end='', flush=True)
        time.sleep(1 / frames_per_second)
    print('\r', end='', flush=True)
def chat(stream: bool = True) -> None:
    '''Run the interactive command-line chat loop.

    Reads lines from stdin until the user types ``exit`` or input ends
    (Ctrl+D / Ctrl+C at the prompt). Typing ``cls`` clears the screen and
    the conversation history.

    Args:
        stream: if True, print the reply incrementally from
            ``chat_bot.stream_chat``. The history is redrawn after every
            turn. If False, call ``chat_bot.chat`` once and show a spinner
            while waiting. No history is kept in this mode.
    '''
    global STOP_CIRCLE
    history: list[list[str]] = []

    while True:
        print('\r\033[0;33;40m用户:\033[0m', end='', flush=True)
        try:
            input_txt = input()
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl+C at the prompt: leave cleanly instead of
            # crashing with a traceback.
            print()
            break

        # Commands are matched case-insensitively, ignoring stray spaces.
        command = input_txt.strip().lower()
        if not command:
            print('请输入问题')
            continue
        # Exit the loop.
        if command == 'exit':
            break
        # Clear the screen and forget the conversation so far.
        if command == 'cls':
            history = []
            os.system(clear_cmd)
            print(welcome_txt)
            continue

        if not stream:
            STOP_CIRCLE = False
            # daemon=True: the spinner must never keep the process alive.
            spinner = Thread(target=circle_print, daemon=True)
            spinner.start()
            try:
                outs = chat_bot.chat(input_txt)
            finally:
                # Always stop the spinner, even if the model call raises;
                # otherwise it keeps drawing for up to circle_print's timeout.
                STOP_CIRCLE = True
                spinner.join()
            print("\r\033[0;32;40mChatBot:\033[0m\n{}\n\n".format(outs), end='')
            continue

        history.append([input_txt, ''])
        chunks: list[str] = []
        streamer = chat_bot.stream_chat(input_txt)
        rich_text = Text()
        print("\r\033[0;32;40mChatBot:\033[0m\n", end='')
        with Live(rich_text, refresh_per_second=15):
            for word in streamer:
                rich_text.append(word)
                chunks.append(word)

        reply = ''.join(chunks)
        if not reply:
            # Fallback message shown when the model produced no text at all.
            reply = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"
        # The turn just appended is always the last entry.
        history[-1][1] = reply

        # Redraw the whole transcript so the reply is shown in its final form.
        os.system(clear_cmd)
        print(build_prompt(history), flush=True)
if __name__ == '__main__':
    # Run as a script: start the interactive loop with incremental output.
    use_streaming = True
    chat(stream=use_streaming)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化