optimizer.py
Ishan Misra committed on 2021-09-15 05:43 · Initial commit
# Copyright (c) Facebook, Inc. and its affiliates.
import torch


def build_optimizer(args, model):
    # Split trainable parameters into two groups so weight decay can be
    # skipped for biases and other 1-D tensors (e.g. LayerNorm parameters).
    params_with_decay = []
    params_without_decay = []
    for name, param in model.named_parameters():
        # Skip frozen parameters entirely.
        if not param.requires_grad:
            continue
        if args.filter_biases_wd and (len(param.shape) == 1 or name.endswith("bias")):
            params_without_decay.append(param)
        else:
            params_with_decay.append(param)

    if args.filter_biases_wd:
        # Two param groups: biases/1-D tensors get zero decay,
        # everything else gets the configured weight decay.
        param_groups = [
            {"params": params_without_decay, "weight_decay": 0.0},
            {"params": params_with_decay, "weight_decay": args.weight_decay},
        ]
    else:
        # Single group: all trainable parameters share the same decay.
        param_groups = [
            {"params": params_with_decay, "weight_decay": args.weight_decay},
        ]

    optimizer = torch.optim.AdamW(param_groups, lr=args.base_lr)
    return optimizer
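
For reference, a minimal usage sketch: `args` only needs the three attributes the function reads (`filter_biases_wd`, `weight_decay`, `base_lr`). The `SimpleNamespace`, the toy model, and the hyperparameter values below are hypothetical placeholders for illustration, not taken from the repository (which builds `args` via argparse).

# Hypothetical usage sketch; values and model are placeholders.
from types import SimpleNamespace
import torch.nn as nn

args = SimpleNamespace(filter_biases_wd=True, weight_decay=0.1, base_lr=5e-4)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
optimizer = build_optimizer(args, model)

# With filter_biases_wd=True, biases and 1-D tensors land in the
# zero-decay group; 2-D weight matrices keep the configured decay.
for group in optimizer.param_groups:
    print(group["weight_decay"], sum(p.numel() for p in group["params"]))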