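# NOTE: web-page residue from where this file was copied ("Code pull complete; the page will refresh automatically") — not part of the script.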
import argparse
import onnx
import numpy as np
def parse_args(argv=None):
    """Parse the command-line options of this ONNX inspection script.

    Args:
        argv: argument list to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``, which is the original
            behavior.

    Returns:
        argparse.Namespace with ``model`` (path of the ONNX file, defaulting
        to ``'output.onnx'``) and ``show_weights`` (bool, set by ``-w``).
    """
    parser = argparse.ArgumentParser(description='print onnx info')
    # argparse ignores `default` on a required positional argument.
    # nargs='?' makes the positional optional, so 'output.onnx' is actually
    # used when no path is given (before, it was dead configuration).
    parser.add_argument('model', type=str, nargs='?', default='output.onnx',
                        help='onnx file')
    parser.add_argument('-w', '--show-weights', action='store_true',
                        help='only show weight shapes')
    return parser.parse_args(argv)
def tensor_dims(value_info):
    """Return the shape of a ValueInfoProto's tensor type as a list.

    Each entry is the dimension's integer ``dim_value``. For a symbolic
    (dynamic) dimension it is the ``dim_param`` name instead, because reading
    ``dim_value`` alone would misleadingly report such a dimension as ``0``.
    A dimension with neither field set is still reported as ``0``.

    Args:
        value_info: a graph input, output or ``value_info`` entry whose type
            is a tensor type.

    Returns:
        list with one ``int`` or ``str`` per recorded dimension. The list is
        empty when no dimensions are recorded.
    """
    return [dim.dim_param or dim.dim_value
            for dim in value_info.type.tensor_type.shape.dim]


def main():
    """Print a human-readable dump of the ONNX model named on the command line.

    With ``-w/--show-weights``, print only the name and dims of initializers
    whose name ends in ``"weight"``. Otherwise print, in order:

    - every initializer (name, dims, dtype, first 10 values);
    - every node, preceded by its output shapes when ``graph.value_info``
      records them;
    - the graph inputs and outputs with their shapes.
    """
    # Imported explicitly: a bare `import onnx` is not guaranteed to expose
    # the numpy_helper submodule as an attribute on every onnx version.
    from onnx import numpy_helper

    args = parse_args()
    model = onnx.load(args.model)
    graph = model.graph

    # Shapes of intermediate tensors, keyed by tensor name. This is empty when
    # the model carries no value_info entries.
    tensor_shapes = {info.name: tensor_dims(info) for info in graph.value_info}

    if args.show_weights:
        for init in graph.initializer:
            if init.name.endswith("weight"):
                print("{}: {}".format(init.name, init.dims))
        return

    for init in graph.initializer:
        for name in ["name", "dims"]:
            print("{}: {}".format(name, getattr(init, name)))
        # to_array decodes both raw_data and the typed *_data fields
        # (float_data, int64_data, ...). np.frombuffer(raw_data) printed an
        # empty array for the latter and raised for string tensors. It also
        # avoids the deprecated onnx.mapping table.
        data = numpy_helper.to_array(init)
        print("dtype: ", data.dtype)
        # ravel() avoids copying the whole tensor just to show 10 values.
        print("data: ", data.ravel()[:10])
        print("---------------------------")

    for node in graph.node:
        for out in node.output:
            shape = tensor_shapes.get(out)
            if shape:
                print("output shape: ", shape)
        print(node)
        print("---------------------------")

    print("inputs:")
    for graph_input in graph.input:
        print(graph_input.name, tensor_dims(graph_input))
    print("---------------------------")
    print("\noutputs:")
    for graph_output in graph.output:
        print(graph_output.name, tensor_dims(graph_output))


if __name__ == '__main__':
    main()
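# NOTE: the hosting page withheld content at this point ("This content may be inappropriate to display and is not shown; you may review and edit it,
# or appeal if you believe it contains no prohibited material"). The original file may continue beyond this line — recover it from the upstream repository.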