test.py
import torch
from PIL import Image
from modelscope import AutoModelForCausalLM, AutoTokenizer
from modelscope import snapshot_download
import os
from bdtime import tt
# Cache layout: ../models holds the ModelScope snapshot; ../models/local holds
# a re-saved copy of the tokenizer and model for faster subsequent loads.
cache_dir = os.path.join("..", "models")
os.makedirs(cache_dir, exist_ok=True)
local_cache_dir = os.path.join(cache_dir, 'local')
os.makedirs(local_cache_dir, exist_ok=True)
device = "cuda"
# Alternative: load the tokenizer straight from the ModelScope cache by id.
# model_id = "ZhipuAI/glm-4v-9b"
# tokenizer = AutoTokenizer.from_pretrained(
#     model_id,
#     trust_remote_code=True,
#     cache_dir=cache_dir,
#     local_files_only=True,
# )
model_name = "ZhipuAI/glm-4v-9b"
model_dir = snapshot_download(model_name, cache_dir=cache_dir, local_files_only=True)
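# NOTE (assumption): local_files_only=True requires the snapshot to already be
# present under cache_dir; a first run would need to fetch it once, e.g.:
# model_dir = snapshot_download(model_name, cache_dir=cache_dir)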
trust_remote_code = True  # glm-4v-9b ships custom modeling code, so this must stay True
# trust_remote_code = False
local_model_dir = os.path.join(local_cache_dir, model_name)
os.makedirs(local_model_dir, exist_ok=True)
local_tokenizer_path = os.path.join(local_model_dir, 'tokenizer')
local_model_path = os.path.join(local_model_dir, 'model')
print(f'------- os.path.exists(local_model_path): {os.path.exists(local_model_path)}')

# Remember whether a local copy still needs to be written, *before* the load
# paths are repointed at the ModelScope snapshot below.
need_save = not os.path.exists(local_model_path)
if need_save:
    # No local copy yet: load from the snapshot, but keep the local target
    # paths so the loaded weights can be saved there afterwards.
    save_tokenizer_path, save_model_path = local_tokenizer_path, local_model_path
    local_tokenizer_path = model_dir
    local_model_path = model_dir
else:
    print(f'--- Loading model from the local copy! local_model_dir: {local_model_dir}')
tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path, trust_remote_code=trust_remote_code)
model = AutoModelForCausalLM.from_pretrained(
    local_model_path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=trust_remote_code,
).to(device).eval()
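# Rough memory check: 9B parameters in bf16 take about 9e9 * 2 bytes ≈ 18 GB
# of GPU memory before activations, so a 24 GB card is a sensible minimum.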
if need_save:
    print(f'------ Saving `tokenizer` and `model` to: {local_model_dir}')
    tokenizer.save_pretrained(save_tokenizer_path)
    model.save_pretrained(save_model_path)
# model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).half().cuda()
query = 'Describe this image'
img_file_path = 'images/cat.png'
image = Image.open(img_file_path).convert('RGB')
print(f'--- type(image): {type(image)}')
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "image": image, "content": query}],
    add_generation_prompt=True, tokenize=True, return_tensors="pt",
    return_dict=True,
)  # chat mode: glm-4v's template accepts a PIL image alongside the text
inputs = inputs.to(device)
gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
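# Note: with top_k=1 the sampler always keeps only the most likely token, so
# this is effectively greedy decoding despite do_sample=True.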
tt.__init__()  # reset bdtime's global timer before the benchmark
print('--- start generate')
run_times = 5
with torch.no_grad():
    from bdtime.with_timer import with_timer
    with with_timer('benchmark', tt) as wt:
        # for i in range(10):
        #     tt.sleep(0.3)
        #     if i % 5 == 0:
        #         wt.show(f"loss at step {i}: {i * 2 / 5}")
        for i in range(run_times):
            outputs = model.generate(**inputs, **gen_kwargs)
            outputs = outputs[:, inputs['input_ids'].shape[1]:]  # drop the prompt tokens
            if i == 0:
                print("*** output:", tokenizer.decode(outputs[0]))
            wt.show(f"run {i}", reset_cost=True)
print(f'--- total_cost_time: {tt.now()}, mean_cost_time: {tt.now() / run_times : .3f}')
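# Usage sketch (assumptions: a CUDA GPU large enough for glm-4v-9b in bf16,
# the snapshot already under ../models, and images/cat.png on disk):
#   python test.py
# For stricter GPU timing one could call torch.cuda.synchronize() before each
# wt.show() so that any still-queued kernels are included in the interval.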