加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
openai_api_server.py 12.26 KB
一键复制 编辑 原始数据 按行查看 历史
李玉宝 提交于 2024-08-29 22:48 . 暂时去掉训练
import os
import time
# from asyncio.log import logger
import re
import uvicorn
import gc
import json
import torch
import random
import string
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
from fastapi import FastAPI, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from transformers import AutoTokenizer
from sse_starlette.sse import EventSourceResponse
from train import train
# Model location; override it with the MODEL_PATH environment variable.
MODEL_PATH = os.getenv('MODEL_PATH', 'hf-models/glm-4-9b-chat')
# Context window handed to vLLM as max_model_len when the engine is built.
MAX_MODEL_LENGTH = 8_192
# Push sse-starlette's keep-alive ping interval far out so pings are rare
# (sse-starlette appears to count this in seconds — confirm for the installed version).
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Nothing is set up on startup. On shutdown, if CUDA is available, cached
    allocator blocks are returned and CUDA IPC handles are collected.
    """
    yield
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
# The ASGI application; `lifespan` performs the CUDA cleanup on shutdown.
app = FastAPI(lifespan=lifespan)

# Fully open CORS so browser-based clients on any origin can call the API.
# NOTE(review): FastAPI's CORS docs say wildcard origins/methods/headers should
# not be combined with allow_credentials=True — restrict the origins before
# exposing this server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
def generate_id(prefix: str, k=29) -> str:
    """Return *prefix* followed by *k* random ASCII letters/digits.

    Used for OpenAI-style identifiers such as ``chatcmpl-…`` and ``fp_…``.
    These are not security tokens, so ``random`` (not ``secrets``) is used.
    """
    alphabet = string.ascii_letters + string.digits
    return prefix + "".join(random.choices(alphabet, k=k))
class ModelCard(BaseModel):
    """One entry of the ``GET /v1/models`` listing, shaped like an OpenAI model object."""

    id: str = ""
    object: str = "model"
    # Unix timestamp (seconds) taken when the card is instantiated.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    # OpenAI-compatibility fields; nothing in this source populates them.
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None
class ModelList(BaseModel):
    """Response body for ``GET /v1/models`` (an OpenAI ``list`` object)."""

    object: str = "list"
    # The default used to be the bare string list ["glm-4"], which does not
    # match List[ModelCard]: pydantic does not validate defaults, so ModelList()
    # produced an invalid payload that failed response validation. Build a real
    # ModelCard per instance instead.
    data: List[ModelCard] = Field(
        default_factory=lambda: [ModelCard(id="glm-4")])
class FunctionCall(BaseModel):
    """Function name and arguments inside a tool call."""

    name: Optional[str] = None
    # Presumably a JSON-encoded argument string, as in the OpenAI API — confirm.
    arguments: Optional[str] = None
class ChoiceDeltaToolCallFunction(BaseModel):
    """Legacy ``function_call`` payload on a message; same fields as FunctionCall."""

    name: Optional[str] = None
    arguments: Optional[str] = None
class UsageInfo(BaseModel):
    """Token accounting attached to a completion response."""

    prompt_tokens: int = 0
    total_tokens: int = 0
    # NOTE(review): Optional while the other counters are plain int; a None here
    # would break code that adds usage fields together — consider `int = 0`.
    completion_tokens: Optional[int] = 0
class ChatCompletionMessageToolCall(BaseModel):
    """A tool call attached to a message (an OpenAI ``tool_calls`` entry)."""

    # Presumably the position of this call within tool_calls — confirm.
    index: Optional[int] = 0
    id: Optional[str] = None
    function: FunctionCall
    type: Optional[Literal["function"]] = 'function'
class ChatMessage(BaseModel):
    """A single chat turn, used in requests and in non-streaming responses."""

    # About a "function" role:
    # Clients on older OpenAI API versions may send role "function". To support
    # them, add "function" to the role Literal below and add matching logic in
    # process_messages that converts that role to "observation".
    role: Literal["user", "assistant", "system", "tool"]
    content: Optional[str] = None
    function_call: Optional[ChoiceDeltaToolCallFunction] = None
    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
class DeltaMessage(BaseModel):
    """Incremental message fragment carried by one streaming chunk.

    Every field is optional so a chunk can carry only what changed.
    """

    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None
    function_call: Optional[ChoiceDeltaToolCallFunction] = None
    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
class ChatCompletionResponseChoice(BaseModel):
    """One choice of a non-streaming completion: the full message plus why it ended."""

    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "tool_calls"]
class ChatCompletionResponseStreamChoice(BaseModel):
    """One choice of a streaming chunk: the delta since the previous chunk."""

    delta: DeltaMessage
    # No default: under pydantic v2 this makes the field required, so callers
    # pass it explicitly (None while generation is still running).
    finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
    index: int
class ChatCompletionResponse(BaseModel):
    """Chat completion response, also used for each streaming chunk.

    ``object`` distinguishes the two: "chat.completion" vs "chat.completion.chunk".
    """

    model: str
    # Fresh random "chatcmpl-" id unless supplied (the streaming path passes one
    # id so every chunk of a response shares it).
    id: Optional[str] = Field(
        default_factory=lambda: generate_id('chatcmpl-', 29))
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice,
                        ChatCompletionResponseStreamChoice]]
    # Unix timestamp (seconds) at instantiation unless supplied.
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    system_fingerprint: Optional[str] = Field(
        default_factory=lambda: generate_id('fp_', 9))
    # Token counts; the streaming path leaves this unset.
    usage: Optional[UsageInfo] = None
class ChatCompletionRequest(BaseModel):
    """Body of ``POST /v1/chat/completions`` (a subset of the OpenAI schema)."""

    # Echoed back in the response; the served model is fixed at startup by MODEL_PATH.
    model: str
    messages: List[ChatMessage]
    # Sampling controls. Each is Optional, so a client may send an explicit
    # null, which replaces the default declared here.
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    # None means "use the server-side default" chosen by the chat endpoint.
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    # Accepted for OpenAI compatibility; process_messages decides how they are used.
    tools: Optional[Union[dict, List[dict]]] = None
    tool_choice: Optional[Union[str, dict]] = None
    repetition_penalty: Optional[float] = 1.1
def _param_or_default(params, key, default, cast):
    """Return ``cast(params[key])``, using *default* when the key is absent or None."""
    value = params.get(key)
    return cast(default if value is None else value)


# NOTE(review): as a decorator, torch.inference_mode() appears to cover only the
# call that creates this async generator, not its iteration — confirm whether it
# has any effect here.
@torch.inference_mode()
async def generate_stream_glm4(params):
    """Stream cumulative generations for one chat request from the vLLM engine.

    Args:
        params: dict built by the chat endpoint. Reads ``messages`` (ChatMessage
            objects), ``tools``, ``tool_choice``, ``temperature``, ``top_p``,
            ``repetition_penalty`` and ``max_tokens``. Sampling values that are
            absent or explicitly None fall back to defaults instead of crashing
            in ``float(None)`` / ``int(None)``.

    Yields:
        dict with ``text`` (all text generated so far), ``usage`` (prompt,
        completion and total token counts) and ``finish_reason`` (as reported
        by vLLM for the first output).

    Uses the module-level ``tokenizer`` and ``engine`` created under ``__main__``.
    """
    import uuid  # function-scope import keeps this block self-contained

    messages = process_messages(
        params["messages"],
        tools=params.get("tools"),
        tool_choice=params.get("tool_choice"),
    )
    prompt = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False)

    sampling_params = SamplingParams(**{
        "n": 1,
        "best_of": 1,
        "presence_penalty": 1.0,
        "frequency_penalty": 0.0,
        "temperature": _param_or_default(params, "temperature", 1.0, float),
        "top_p": _param_or_default(params, "top_p", 1.0, float),
        "top_k": -1,
        "repetition_penalty": _param_or_default(
            params, "repetition_penalty", 1.0, float),
        "use_beam_search": False,
        "length_penalty": 1,
        "early_stopping": False,
        # GLM-4 end-of-turn token ids.
        "stop_token_ids": [151329, 151336, 151338],
        "ignore_eos": False,
        "max_tokens": _param_or_default(params, "max_tokens", 8192, int),
        "logprobs": None,
        "prompt_logprobs": None,
        "skip_special_tokens": True,
    })

    # The engine requires a unique id per in-flight request; the previous
    # f"{time.time()}" can repeat for concurrent requests (especially on
    # platforms with a coarse clock).
    request_id = f"chatcmpl-{uuid.uuid4().hex}"

    try:
        async for output in engine.generate(prompt, sampling_params, request_id):
            completion = output.outputs[0]
            output_len = len(completion.token_ids)
            input_len = len(output.prompt_token_ids)
            yield {
                "text": completion.text,
                "usage": {
                    "prompt_tokens": input_len,
                    "completion_tokens": output_len,
                    "total_tokens": output_len + input_len,
                },
                "finish_reason": completion.finish_reason,
            }
    finally:
        # Run the cleanup even when the consumer stops early (e.g. the client
        # disconnects mid-stream and the generator is closed).
        gc.collect()
        torch.cuda.empty_cache()
def process_messages(messages, tools=None, tool_choice="none"):
    """Convert chat messages into the list of dicts fed to the GLM-4 chat template.

    Args:
        messages: iterable of ChatMessage-like objects; ``.role`` and
            ``.content`` are read.
        tools: tool definitions from the request. In this version they only
            decide whether the system prompt is moved to the front; they are
            not forwarded to the template.
        tool_choice: ``"none"`` is treated like having no tools.

    Returns:
        A list of ``{"role", "content"}`` dicts in conversation order.
        Assistant turns also carry ``"metadata": ""`` and have their content
        stripped (None becomes ""). Without tools, the first system message is
        moved to the front and appears exactly once.
    """
    processed_messages = []
    for m in messages:
        role, content = m.role, m.content
        if role == "assistant":
            # Keep each assistant turn as ONE message. Splitting on "\n" (as
            # before) turned every line of a multi-line reply into its own
            # assistant turn, and its metadata branch could never run.
            processed_messages.append(
                {
                    "role": role,
                    "metadata": "",
                    "content": (content or "").strip(),
                }
            )
        else:
            processed_messages.append({"role": role, "content": content})

    if not tools or tool_choice == "none":
        # Move (not copy) the first system prompt to the front: the loop above
        # already appended it, so re-inserting it duplicated the system prompt.
        for index, entry in enumerate(processed_messages):
            if entry["role"] == "system":
                if index:
                    processed_messages.insert(0, processed_messages.pop(index))
                break
    return processed_messages
@app.get("/health")
async def health() -> Response:
    """Liveness probe: answer 200 with an empty body whenever the process is up."""
    ok_response = Response(status_code=200)
    return ok_response
@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """List the single model this server exposes, under the id "glm-4"."""
    return ModelList(data=[ModelCard(id="glm-4")])
async def _prepend_chunk(first, rest):
    """Yield *first*, then every item of the async iterator *rest*."""
    yield first
    async for item in rest:
        yield item


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint (streaming and non-streaming).

    Streaming requests get an SSE response of ``chat.completion.chunk`` JSON
    payloads. Non-streaming requests run generation to completion and return
    one ``chat.completion`` object.

    Raises:
        HTTPException(400): there are no messages, or the last one is from the assistant.
        HTTPException(500): the engine produced no output.
    """
    if not request.messages or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        tools=request.tools,
        tool_choice=request.tool_choice,
    )
    # logger.debug(f"==== request ====\n{gen_params}")

    if request.stream:
        stream = predict_stream(request.model, gen_params)
        # Pull the first chunk before returning, so template/engine errors
        # surface as a normal HTTP error rather than a broken event stream.
        try:
            first_chunk = await anext(stream)
        except StopAsyncIteration:
            raise HTTPException(
                status_code=500, detail="Model produced no output") from None
        # The primed chunk (the initial role="assistant" delta) must still be
        # sent; previously it was consumed here and silently dropped.
        return EventSourceResponse(
            _prepend_chunk(first_chunk, stream), media_type="text/event-stream")

    # Non-streaming: each yield carries the cumulative text, so keep the last.
    response = None
    async for response in generate_stream_glm4(gen_params):
        pass
    if response is None:
        raise HTTPException(status_code=500, detail="Model produced no output")

    # strip() also removes the leading "\n" the model may emit.
    text = response["text"].strip()
    # Report truncation instead of always claiming a natural stop.
    finish_reason = "length" if response.get("finish_reason") == "length" else "stop"

    message = ChatMessage(
        role="assistant",
        content=text,
        function_call=None,
        tool_calls=None,
    )
    print(f"==== message ====\n{message}")
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
    )
    return ChatCompletionResponse(
        model=request.model,
        choices=[choice_data],
        object="chat.completion",
        usage=UsageInfo.model_validate(response["usage"]),
    )
async def predict_stream(model_id, gen_params):
    """Turn engine output into OpenAI ``chat.completion.chunk`` SSE payloads.

    Yields JSON strings in this order:
      1. a chunk announcing role="assistant" with empty content and a null finish_reason;
      2. one chunk per engine step, carrying only the newly generated text;
      3. the literal ``"[DONE]"`` sentinel that OpenAI-style clients use to
         detect the end of the stream.

    All chunks share one response id, creation time and system fingerprint.
    """
    response_id = generate_id('chatcmpl-', 29)
    system_fingerprint = generate_id('fp_', 9)
    created_time = int(time.time())

    def make_chunk(content, finish_reason):
        """Serialize one chunk; unset fields (e.g. usage) are omitted."""
        choice = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(
                content=content, role="assistant", function_call=None),
            finish_reason=finish_reason,
        )
        chunk = ChatCompletionResponse(
            model=model_id,
            id=response_id,
            choices=[choice],
            created=created_time,
            system_fingerprint=system_fingerprint,
            object="chat.completion.chunk",
        )
        return chunk.model_dump_json(exclude_unset=True)

    previous_text = ""
    sent_role_chunk = False
    async for new_response in generate_stream_glm4(gen_params):
        # The engine reports cumulative text; emit only the new suffix.
        full_text = new_response["text"]
        delta_text = full_text[len(previous_text):]
        previous_text = full_text

        # The schema only allows stop/length/tool_calls, so any other engine
        # reason (e.g. an aborted request) is reported as "stop" instead of
        # failing validation mid-stream.
        raw_reason = new_response.get("finish_reason")
        if raw_reason is None:
            finish_reason = None
        elif raw_reason == "length":
            finish_reason = "length"
        else:
            finish_reason = "stop"

        if not sent_role_chunk:
            yield make_chunk("", None)
            sent_role_chunk = True
        yield make_chunk(delta_text, finish_reason)

    yield "[DONE]"
if __name__ == "__main__":
    # Both objects stay module-level: generate_stream_glm4 uses them per request.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

    engine_args = AsyncEngineArgs(
        model=MODEL_PATH,
        tokenizer=MODEL_PATH,
        trust_remote_code=True,
        dtype="float16",
        # With several GPUs, set this to the number of GPUs to use.
        tensor_parallel_size=1,
        # Fraction of GPU memory vLLM may claim; size it to your card. E.g. to
        # use only 24 GB of an 80 GB GPU, set 24/80 = 0.3.
        gpu_memory_utilization=0.9,
        max_model_len=MAX_MODEL_LENGTH,
        enforce_eager=True,
        worker_use_ray=False,
        engine_use_ray=False,
        disable_log_requests=True,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    # Jupyter already runs an event loop, so uvicorn.run() fails there with
    # "asyncio.run() cannot be called from a running event loop". Use instead:
    #   config = uvicorn.Config(app, host='0.0.0.0', port=8000, workers=1)
    #   server = uvicorn.Server(config)
    #   await server.serve()

    # Normal script execution.
    print("Starting server...")
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化