加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
utils.py 1.32 KB
一键复制 编辑 原始数据 按行查看 历史
Archermmt 提交于 2022-01-13 20:54 . add basic parts
import tvm
import torch
import numpy as np
def cast_array(array):
if isinstance(array,tvm.runtime.ndarray.NDArray):
array=array.asnumpy()
elif isinstance(array,torch.Tensor):
array=array.detach().cpu().numpy()
assert isinstance(array,np.ndarray),"Only accept array as numpy.ndarray, get "+str(type(array))
return array
def array_des(array):
type_des=array.__class__.__name__
array=cast_array(array)
return "<{}>[{};{}] max {:g}, min {:g}, sum {:g}".format(
type_des,','.join([str(s) for s in array.shape]),array.dtype.name,
array.max(),array.min(),array.sum())
def array_compare(arrayA,arrayB,nameA="A",nameB="B",error=0.05):
arrayA=cast_array(arrayA)
arrayB=cast_array(arrayB)
if arrayA.dtype!=arrayB.dtype:
print("dtype mismatch between {} and {}".format(arrayA.dtype,arrayB.dtype))
if arrayA.shape!=arrayB.shape:
print("dtype mismatch between {} and {}".format(arrayA.dtype,arrayB.dtype))
diff=(arrayA-arrayB)/(abs(arrayA)+0.0001)
msg="max : {:g}, min :{:g}, sum : {:g}".format(diff.max(),diff.min(),diff.sum())
if abs(diff).max()>error:
print("[FAIL] "+msg)
print("{} : {}".format(nameA,array_des(arrayA)))
print("{} : {}".format(nameB,array_des(arrayB)))
return False
print("[PASS] "+msg)
return True
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化