加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
model.py 2.52 KB
一键复制 编辑 原始数据 按行查看 历史
Akshayraj Nadar 提交于 2023-12-30 00:43 . first commit
from langchain import PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import CTransformers
from langchain.chains import RetrievalQA
import chainlit as cl
# Location of the FAISS index persisted by the ingestion step.
DB_FAISS_PATH = "vectorstores/db_faiss"

# Prompt fed to the retrieval QA chain. Fixed typos that were sent verbatim
# to the LLM ("peices" -> "pieces", "dont" -> "don't") and removed a stray
# "+" line that had leaked into the template body.
custom_prompt_template = """Use the following pieces of information to answer the user's question. If you don't know the answer, please just say that you don't know the answer, don't try to make up an answer.
Context:{context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""
def set_custom_prompt():
    """Build the PromptTemplate used by the retrieval QA chain.

    The template expects exactly two variables: ``context`` and ``question``.
    """
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=['context', 'question'],
    )
def load_llm():
    """Load the quantized Llama-2 7B chat model through CTransformers."""
    config = dict(
        model="llama-2-7b-chat.ggmlv3.q8_0.bin",
        model_type="llama",
        max_new_tokens=512,   # cap on generated tokens per answer
        temperature=0.5,      # moderate sampling randomness
    )
    return CTransformers(**config)
def retrieval_qa_chain(llm, prompt, db):
    """Assemble a 'stuff' RetrievalQA chain over the given vector store.

    Retrieves the top-2 documents per query and returns the source
    documents alongside the generated answer.
    """
    retriever = db.as_retriever(search_kwargs={'k': 2})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
def qa_bot():
    """Wire embeddings, the FAISS store, the LLM and the prompt into one QA chain."""
    embeddings = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
    )
    # NOTE(review): newer langchain releases require
    # allow_dangerous_deserialization=True here — confirm installed version.
    vector_store = FAISS.load_local(DB_FAISS_PATH, embeddings)
    return retrieval_qa_chain(load_llm(), set_custom_prompt(), vector_store)
def final_result(query):
    """Run a single query through a freshly built QA bot.

    Returns the chain's raw response dict (answer plus source documents).
    """
    bot = qa_bot()
    return bot({'query': query})
#Chainlit
@cl.on_chat_start
async def start():
    """Chainlit startup hook: build the QA chain and greet the user."""
    chain = qa_bot()
    welcome = cl.Message(content='Starting the bot....')
    await welcome.send()
    welcome.content = "Hello, welcome to the PillboxGPT. Write your query below."
    await welcome.update()
    # Stash the chain in the session so each incoming message reuses it.
    cl.user_session.set('chain', chain)
@cl.on_message
async def main(message: cl.Message):
    """Chainlit message hook: answer the user's query with the stored chain.

    Streams the final answer through the async LangChain callback handler
    and appends the retrieved source documents to the reply.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True
    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["result"]
    sources = res["source_documents"]
    # Fix: the literal below had a useless f-prefix (no placeholders, ruff F541).
    if sources:
        answer += "\nSources:" + str(sources)
    else:
        answer += "\nNo sources found"
    await cl.Message(content=answer).send()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化