代码拉取完成,页面将自动刷新
from langchain import PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import CTransformers
from langchain.chains import RetrievalQA
import chainlit as cl
# Directory holding the FAISS index that qa_bot() loads.
DB_FAISS_PATH = "vectorstores/db_faiss"
# Prompt for the "stuff" QA chain; {context} and {question} are the two input
# variables declared in set_custom_prompt(). The former stray "+" line (which
# was sent verbatim to the LLM) and spelling mistakes have been removed.
custom_prompt_template = """Use the following pieces of information to answer the user's question. If you don't know the answer, please just say that you don't know the answer, don't try to make up an answer.
Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""
def set_custom_prompt():
    """Return the QA prompt used by the retrieval chain.

    Wraps the module-level ``custom_prompt_template`` in a ``PromptTemplate``
    whose input variables are ``context`` and ``question``.
    """
    placeholders = ["context", "question"]
    return PromptTemplate(template=custom_prompt_template, input_variables=placeholders)
def load_llm(*, max_new_tokens=512, temperature=0.5):
    """Load the local Llama-2 chat model through LangChain's ``CTransformers``.

    Generation settings are passed in the ``config`` dict, which is how
    ``CTransformers`` forwards options to the underlying ctransformers model.
    ``max_new_tokens`` / ``temperature`` are not fields of ``CTransformers``
    itself, so supplying them as top-level keyword arguments (as this function
    previously did) did not configure generation.

    Args:
        max_new_tokens: Maximum number of tokens to generate (default 512).
        temperature: Sampling temperature (default 0.5).

    Returns:
        The configured ``CTransformers`` LLM instance.
    """
    generation_config = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
    }
    return CTransformers(
        model="llama-2-7b-chat.ggmlv3.q8_0.bin",
        model_type="llama",
        config=generation_config,
    )
def retrieval_qa_chain(llm, prompt, db):
    """Assemble a "stuff"-type ``RetrievalQA`` chain over the vector store ``db``.

    The retriever fetches the 2 most similar chunks, ``prompt`` is used to
    combine them with the question, and the retrieved source documents are
    returned alongside the answer.
    """
    retriever = db.as_retriever(search_kwargs={"k": 2})
    combine_options = {"prompt": prompt}
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs=combine_options,
    )
def qa_bot():
    """Build the complete retrieval-QA chain from on-disk resources.

    Creates CPU-bound MiniLM sentence embeddings, opens the FAISS index at
    ``DB_FAISS_PATH`` with them, then loads the LLM and prompt and wires the
    pieces together via ``retrieval_qa_chain``.
    """
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )
    # NOTE(review): LangChain's FAISS.load_local reads a pickled docstore from
    # this directory; only point DB_FAISS_PATH at an index you built yourself.
    vector_store = FAISS.load_local(DB_FAISS_PATH, embedder)
    # Arguments evaluate left to right: LLM first, then the prompt.
    return retrieval_qa_chain(load_llm(), set_custom_prompt(), vector_store)
def final_result(query):
    """Answer ``query`` with a freshly built QA chain and return its raw output.

    NOTE(review): every call rebuilds the chain via ``qa_bot()`` (embeddings,
    FAISS index and LLM), so this is expensive per query.
    """
    chain = qa_bot()
    return chain({"query": query})
#Chainlit
@cl.on_chat_start
async def start():
    """Chainlit chat-start hook: build the QA chain and store it in the session.

    The "Starting the bot...." message is sent before the chain is built so
    the user sees it while the embeddings, FAISS index and LLM load. The
    build itself is synchronous and slow, so it runs in a worker thread via
    ``cl.make_async`` to keep the event loop free (and let the message
    actually reach the client). Once built, the message is replaced with the
    welcome text and the chain is stored under the ``"chain"`` session key.
    """
    msg = cl.Message(content="Starting the bot....")
    await msg.send()
    chain = await cl.make_async(qa_bot)()
    msg.content = "Hello, welcome to the PillboxGPT. Write your query below."
    await msg.update()
    cl.user_session.set("chain", chain)
@cl.on_message
async def main(message: cl.Message):
    """Answer a user message with the session's QA chain and list its sources.

    Tokens are streamed to the UI by the LangChain callback handler while the
    chain runs. If the handler streamed into its own message, that message is
    updated with the full answer plus sources. Otherwise a new message is
    sent. Previously the answer was always sent as a second message, so it
    appeared twice whenever streaming happened.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # The prompt asks for "Helpful answer:" rather than "FINAL ANSWER", so
    # mark the answer as already reached and stream from the first token.
    cb.answer_reached = True
    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["result"]
    sources = res["source_documents"]
    if sources:
        answer += "\nSources:" + str(sources)
    else:
        answer += "\nNo sources found"

    # getattr keeps this working on Chainlit versions whose handler lacks
    # these attributes: fall back to sending a fresh message.
    streamed_msg = getattr(cb, "final_stream", None)
    if getattr(cb, "has_streamed_final_answer", False) and streamed_msg is not None:
        streamed_msg.content = answer
        await streamed_msg.update()
    else:
        await cl.Message(content=answer).send()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。