加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
inference.py 4.37 KB
一键复制 编辑 原始数据 按行查看 历史
林泽毅 提交于 2024-09-10 08:05 . patch: sync api
import os
import cv2
import argparse
import numpy as np
from hivision.error import FaceError
from hivision.utils import hex_to_rgb, resize_image_to_kb, add_background
from hivision import IDCreator
from hivision.creator.layout_calculator import (
generate_layout_photo,
generate_layout_image,
)
from hivision.creator.choose_handler import choose_handler
INFERENCE_TYPE = [
"idphoto",
"human_matting",
"add_background",
"generate_layout_photos",
]
MATTING_MODEL = [
"hivision_modnet",
"modnet_photographic_portrait_matting",
"mnn_hivision_modnet",
"rmbg-1.4",
"birefnet-v1-lite",
]
FACE_DETECT_MODEL = [
"mtcnn",
"face_plusplus",
"retinaface-resnet50",
]
RENDER = [0, 1, 2]
parser = argparse.ArgumentParser(description="HivisionIDPhotos 证件照制作推理程序。")
parser.add_argument(
"-t",
"--type",
help="请求 API 的种类",
choices=INFERENCE_TYPE,
default="idphoto",
)
parser.add_argument("-i", "--input_image_dir", help="输入图像路径", required=True)
parser.add_argument("-o", "--output_image_dir", help="保存图像路径", required=True)
parser.add_argument("--height", help="证件照尺寸-高", default=413)
parser.add_argument("--width", help="证件照尺寸-宽", default=295)
parser.add_argument("-c", "--color", help="证件照背景色", default="638cce")
parser.add_argument(
"-k", "--kb", help="输出照片的 KB 值,仅对换底和制作排版照生效", default=None
)
parser.add_argument(
"--matting_model",
help="抠图模型权重",
default="hivision_modnet",
choices=MATTING_MODEL,
)
parser.add_argument(
"-r",
"--render",
type=int,
help="底色合成的模式,有 0:纯色、1:上下渐变、2:中心渐变 可选",
choices=RENDER,
default=0,
)
parser.add_argument(
"--face_detect_model",
help="人脸检测模型",
default="mtcnn",
choices=FACE_DETECT_MODEL,
)
args = parser.parse_args()
# ------------------- 选择抠图与人脸检测模型 -------------------
creator = IDCreator()
choose_handler(creator, args.matting_model, args.face_detect_model)
root_dir = os.path.dirname(os.path.abspath(__file__))
input_image = cv2.imread(args.input_image_dir, cv2.IMREAD_UNCHANGED)
# 如果模式是生成证件照
if args.type == "idphoto":
# 将字符串转为元组
size = (int(args.height), int(args.width))
try:
result = creator(input_image, size=size)
except FaceError:
print("人脸数量不等于 1,请上传单张人脸的图像。")
else:
# 保存标准照
cv2.imwrite(args.output_image_dir, result.standard)
# 保存高清照
file_name, file_extension = os.path.splitext(args.output_image_dir)
new_file_name = file_name + "_hd" + file_extension
cv2.imwrite(new_file_name, result.hd)
# 如果模式是人像抠图
elif args.type == "human_matting":
result = creator(input_image, change_bg_only=True)
cv2.imwrite(args.output_image_dir, result.hd)
# 如果模式是添加背景
elif args.type == "add_background":
render_choice = ["pure_color", "updown_gradient", "center_gradient"]
# 将字符串转为元组
color = hex_to_rgb(args.color)
# 将元祖的 0 和 2 号数字交换
color = (color[2], color[1], color[0])
result_image = add_background(
input_image, bgr=color, mode=render_choice[args.render]
)
result_image = result_image.astype(np.uint8)
if args.kb:
result_image = cv2.cvtColor(result_image, cv2.COLOR_RGB2BGR)
result_image = resize_image_to_kb(
result_image, args.output_image_dir, int(args.kb)
)
else:
cv2.imwrite(args.output_image_dir, result_image)
# 如果模式是生成排版照
elif args.type == "generate_layout_photos":
size = (int(args.height), int(args.width))
typography_arr, typography_rotate = generate_layout_photo(
input_height=size[0], input_width=size[1]
)
result_layout_image = generate_layout_image(
input_image,
typography_arr,
typography_rotate,
height=size[0],
width=size[1],
)
if args.kb:
result_layout_image = cv2.cvtColor(result_layout_image, cv2.COLOR_RGB2BGR)
result_layout_image = resize_image_to_kb(
result_layout_image, args.output_image_dir, int(args.kb)
)
else:
cv2.imwrite(args.output_image_dir, result_layout_image)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化