加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
show_onnx.py 1.68 KB
一键复制 编辑 原始数据 按行查看 历史
xiaowei 提交于 2022-10-05 22:57 . add recognition
import argparse
import onnx
import numpy as np
def parse_args(argv=None):
    """Parse the command-line options of the ONNX inspection script.

    Args:
        argv: List of argument strings to parse, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            which is what the script does when run from the shell.

    Returns:
        An ``argparse.Namespace`` with ``model`` (path of the ONNX file,
        ``'output.onnx'`` when omitted) and ``show_weights`` (bool).
    """
    parser = argparse.ArgumentParser(description='print onnx info')
    # nargs='?' is what makes the positional optional; without it argparse
    # treats the argument as required and the default below is never used.
    parser.add_argument('model', type=str, nargs='?', default='output.onnx',
                        help='onnx file')
    parser.add_argument('-w', '--show-weights', action='store_true',
                        help='only show weight shapes')
    return parser.parse_args(argv)
SEPARATOR = "---------------------------"


def _tensor_shape(value_info):
    """Return the shape recorded on an ONNX ``ValueInfoProto`` as a list.

    Each dimension is reported as its symbolic name (``dim_param``) when one
    is set, otherwise as its integer ``dim_value``; a dimension with neither
    set therefore shows up as ``0``. Non-tensor types yield ``[]``.
    """
    return [dim.dim_param or dim.dim_value
            for dim in value_info.type.tensor_type.shape.dim]


def _collect_tensor_shapes(graph):
    """Map tensor name -> shape for every tensor the graph declares a type for.

    Graph inputs and outputs are included alongside ``value_info`` because
    ``value_info`` does not normally list them; this lets the nodes that
    produce graph outputs report their output shape too.
    """
    shapes = {}
    # value_info is applied last so its entries win on a name collision.
    for group in (graph.input, graph.output, graph.value_info):
        for value_info in group:
            shapes[value_info.name] = _tensor_shape(value_info)
    return shapes


def _print_initializers(graph):
    """Print name, dims, dtype and the first 10 values of every initializer."""
    # Local import keeps this helper self-contained; numpy_helper ships with onnx.
    from onnx import numpy_helper

    for init in graph.initializer:
        print("name: {}".format(init.name))
        print("dims: {}".format(init.dims))
        # to_array decodes both raw_data and the typed *_data fields; reading
        # raw_data alone printed an empty array for tensors stored in
        # float_data/int64_data/..., and the onnx.mapping dtype table it
        # needed is deprecated.
        values = numpy_helper.to_array(init)
        print("dtype: ", values.dtype)
        print("data: ", values.ravel()[:10])
        print(SEPARATOR)


def _print_nodes(graph, tensor_shapes):
    """Print every graph node, preceded by the known shapes of its outputs."""
    for node in graph.node:
        for out in node.output:
            shape = tensor_shapes.get(out)
            if shape:  # skip unknown tensors and empty (scalar/unknown) shapes
                print("output shape: ", shape)
        print(node)
        print(SEPARATOR)


def main():
    """Load the ONNX model named on the command line and dump its contents.

    With ``-w/--show-weights`` only the dims of initializers whose name ends
    in ``"weight"`` are printed. Otherwise the output is, in order: all
    initializers, all nodes, then the graph inputs and outputs.
    """
    args = parse_args()
    graph = onnx.load(args.model).graph

    if args.show_weights:
        for init in graph.initializer:
            if init.name.endswith("weight"):
                print("{}: {}".format(init.name, init.dims))
        return  # replaces the site-dependent builtin exit(0); exit status stays 0

    _print_initializers(graph)
    _print_nodes(graph, _collect_tensor_shapes(graph))

    print("inputs:")
    for value_info in graph.input:
        print(value_info.name, _tensor_shape(value_info))
    print(SEPARATOR)
    print("\noutputs:")
    for value_info in graph.output:
        print(value_info.name, _tensor_shape(value_info))


if __name__ == '__main__':
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化