加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
SNConv2D.py 2.48 KB
一键复制 编辑 原始数据 按行查看 历史
yangLiu 提交于 2022-10-11 14:58 . 初始化
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import layers
from tensorflow.python.keras import initializers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
class SpectralNormalization(layers.Wrapper):
    """Keras wrapper that spectrally normalizes the wrapped layer's ``kernel``.

    On every call made in training mode, one power-iteration step refines the
    non-trainable vector ``sn_u``, estimates the kernel's largest singular
    value ``sigma``, and overwrites the kernel variable in place with
    ``kernel / sigma``. In inference mode the kernel is left untouched.

    Args:
        layer: A Keras layer that has a ``kernel`` attribute. The kernel is
            flattened to ``(-1, kernel.shape[-1])``, so this presumably expects
            kernels whose last axis is the output dimension (Conv2D/Dense
            style) -- confirm for other layer types.
        **kwargs: Forwarded unchanged to ``layers.Wrapper``.
    """

    def __init__(self, layer, **kwargs):
        super(SpectralNormalization, self).__init__(layer, **kwargs)

    def build(self, input_shape):
        """Build the wrapped layer (if needed) and create the vector ``sn_u``.

        Args:
            input_shape: Shape used to build the wrapped layer if it is not
                built yet.

        Raises:
            ValueError: If the wrapped layer has no ``kernel`` attribute.
        """
        if not self.layer.built:
            self.layer.build(input_shape)

        if not hasattr(self.layer, 'kernel'):
            raise ValueError(
                '`SpectralNormalization` must wrap a layer that'
                ' contains a `kernel` for weights')

        # Keep a handle to the kernel *variable* (not a copy): the
        # normalization in `_compute_weights` assigns to it in place.
        self.w = self.layer.kernel
        self.w_shape = self.w.shape.as_list()

        # Use the kernel's dtype rather than a hard-coded float32 so the
        # matmuls against the kernel do not fail on a dtype mismatch for
        # non-float32 kernels. float32 kernels behave exactly as before.
        self.u = self.add_weight(
            shape=tuple([1, self.w_shape[-1]]),
            initializer=initializers.TruncatedNormal(stddev=0.02),
            name='sn_u',
            trainable=False,
            dtype=dtypes.as_dtype(self.w.dtype).base_dtype)

        super(SpectralNormalization, self).build()

    # NOTE: intentionally not decorated with @def_function.function.
    # Keras already compiles `call` into its training/inference graphs.
    # A per-method tf.function would also freeze the `training=None` ->
    # K.learning_phase() fallback into the first trace and never re-read it.
    def call(self, inputs, training=None):
        """Run the wrapped layer, first renormalizing its kernel when training.

        Args:
            inputs: Forwarded unchanged to the wrapped layer.
            training: Training flag (Python bool, 0/1 int, or scalar tensor).
                ``None`` falls back to ``K.learning_phase()``.

        Returns:
            The wrapped layer's output for ``inputs``.
        """
        if training is None:
            training = K.learning_phase()
        # Use truthiness rather than `== True`. It works for Python bools and
        # the integer learning phase. It also avoids comparing an integer
        # tensor against a Python bool.
        if training:
            self._compute_weights()
        return self.layer(inputs)

    def _compute_weights(self):
        """Do one power-iteration step and rescale the kernel by ``1 / sigma``.

        Side effects: updates ``self.u`` and overwrites the kernel variable.
        """
        # View the kernel as a 2-D matrix of shape (-1, out_dim).
        w_reshaped = array_ops.reshape(self.w, [-1, self.w_shape[-1]])
        eps = 1e-12  # keeps the L2 normalizations away from division by zero
        _u = array_ops.identity(self.u)
        _v = math_ops.matmul(_u, array_ops.transpose(w_reshaped))
        _v = _v / math_ops.maximum(math_ops.reduce_sum(_v**2)**0.5, eps)
        _u = math_ops.matmul(_v, w_reshaped)
        _u = _u / math_ops.maximum(math_ops.reduce_sum(_u**2)**0.5, eps)
        # Cast back to the variables' stored dtypes before assigning.
        self.u.assign(math_ops.cast(_u, self.u.dtype))
        # sigma = v W u^T: the (1, 1) estimate of the largest singular value.
        sigma = math_ops.matmul(
            math_ops.matmul(_v, w_reshaped), array_ops.transpose(_u))
        # self.w is the same variable as self.layer.kernel (set in build).
        self.w.assign(math_ops.cast(self.w / sigma, self.w.dtype))

    def compute_output_shape(self, input_shape):
        """Delegate output-shape inference to the wrapped layer."""
        return tensor_shape.TensorShape(
            self.layer.compute_output_shape(input_shape).as_list())
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化