代码拉取完成,页面将自动刷新
import math
import torch
from torch.optim import Optimizer
# Spherical Optimizer Class
# Uses the first two dimensions as batch information
# Optimizes over the surface of a sphere using the initial radius throughout
#
# Example Usage:
# opt = SphericalOptimizer(torch.optim.SGD, [x], lr=0.01)
class SphericalOptimizer(Optimizer):
def __init__(self, optimizer, params, **kwargs):
self.opt = optimizer(params, **kwargs)
self.params = params
with torch.no_grad():
self.radii = {param: (param.pow(2).sum(tuple(range(2,param.ndim)),keepdim=True)+1e-9).sqrt() for param in params}
@torch.no_grad()
def step(self, closure=None):
loss = self.opt.step(closure)
for param in self.params:
param.data.div_((param.pow(2).sum(tuple(range(2,param.ndim)),keepdim=True)+1e-9).sqrt())
param.mul_(self.radii[param])
return loss
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。