加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
readme 1.53 KB
一键复制 编辑 原始数据 按行查看 历史
liuqiyuan 提交于 2022-04-28 12:43 . add
###############################################
############ 版本:demo-0.1 ###############
##功能:比较pytorch与mindspore网络模型间的精度##
###############################################
主要依赖
torch 1.11.0
mindspore-ascend 1.6.1
npu环境
参数说明
class compare()
--ptmodule # pytorch模块实例
--msmodule # minspore模块实例
--module_type # 取值"net"或"loss",当输入模块含有权重时,取"net",当模块没有权重如loss模块时,取“loss”
--input_shape # 模块输入的shape,工具会根据该shape随机生成模块输入,当输入个数为1时,取[[n1,n2,n3,...]], 当个数为2时,取[[n1,n2,n3,...][y1,y2,y3,...]]
--init_mode # 取值为"ones"或"random","ones"表示会把模块中的所有权重初始化为1,"random"表示会把模块中的所有权重初始化为任意一个随机数
--input_num # 取值1或2,表示输出个数,需要与input_shape同步
--print_result # bool值,True打印两个模块的输出,False则不打印
使用流程
1. 将所需对比的两个模块的py文件放入module路径
2. 在run_compare.py文件中将上述module导入
3. 将两个模块实例化为PtNet(mindspore)与MsNet(pytorch)
4. 调整compare中的参数
5. 运行run_compare.py
输出说明
1.similarity表示两模块输出结果数值上的相似程度,以pytorch输出作为基线,similarity=1表示完全相同
2.cosine similarity表示两模块输出的余弦相似度,cosine similarity=1表示完全相似
3.variance表示方差,variance=0表示完全相同
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化