代码拉取完成,页面将自动刷新
###############################################
############ 版本:demo-0.1 ###############
##功能:比较pytorch与mindspore网络模型间的精度##
###############################################
主要依赖
torch 1.11.0
mindspore-ascend 1.6.1
npu环境
参数说明
class compare()
--ptmodule # pytorch模块实例
--msmodule # minspore模块实例
--module_type # 取值"net"或"loss",当输入模块含有权重时,取"net",当模块没有权重如loss模块时,取“loss”
--input_shape # 模块输入的shape,工具会根据该shape随机生成模块输入,当输入个数为1时,取[[n1,n2,n3,...]], 当个数为2时,取[[n1,n2,n3,...][y1,y2,y3,...]]
--init_mode # 取值为"ones"或"random","ones"表示会把模块中的所有权重初始化为1,"random"表示会把模块中的所有权重初始化为任意一个随机数
--input_num # 取值1或2,表示输出个数,需要与input_shape同步
--print_result # bool值,True打印两个模块的输出,False则不打印
使用流程
1. 将所需对比的两个模块的py文件放入module路径
2. 在run_compare.py文件中将上述module导入
3. 将两个模块实例化为PtNet(mindspore)与MsNet(pytorch)
4. 调整compare中的参数
5. 运行run_compare.py
输出说明
1.similarity表示两模块输出结果数值上的相似程度,以pytorch输出作为基线,similarity=1表示完全相同
2.cosine similarity表示两模块输出的余弦相似度,cosine similarity=1表示完全相似
3.variance表示方差,variance=0表示完全相同
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。