代码拉取完成,页面将自动刷新
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import layers
from tensorflow.python.keras import initializers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
class SpectralNormalization(layers.Wrapper):
    """Wrap a layer and rescale its ``kernel`` by an estimate of its spectral norm.

    On every training-mode call, one power-iteration step refines the stored
    vector ``u``. The step also yields ``sigma``, an estimate of the kernel's
    largest singular value. The wrapped layer's kernel is then overwritten in
    place with ``kernel / sigma`` before the wrapped layer runs. When not
    training, the kernel is used as-is.

    Args:
      layer: The layer to wrap. After being built it must expose a ``kernel``
        attribute.
      **kwargs: Forwarded unchanged to ``layers.Wrapper``.
    """

    def __init__(self, layer, **kwargs):
        super(SpectralNormalization, self).__init__(layer, **kwargs)

    def build(self, input_shape):
        """Build the wrapped layer (if needed) and create the vector ``u``.

        Args:
          input_shape: Shape forwarded to the wrapped layer's ``build``.

        Raises:
          ValueError: If the wrapped layer has no ``kernel`` attribute.
        """
        if not self.layer.built:
            self.layer.build(input_shape)

        if not hasattr(self.layer, 'kernel'):
            raise ValueError(
                '`SpectralNormalization` must wrap a layer that'
                ' contains a `kernel` for weights')

        # NOTE: `self.w` is the wrapped layer's kernel variable itself, not a
        # copy, so `_compute_weights` rescales that kernel in place.
        self.w = self.layer.kernel
        self.w_shape = self.w.shape.as_list()

        # Row vector with one entry per kernel output column. It is
        # non-trainable and persists across calls, so each call only needs a
        # single power-iteration step. It is created in the kernel's dtype
        # (previously hard-coded float32) so the matmuls in
        # `_compute_weights` do not mix dtypes for non-float32 kernels. For
        # float32 kernels this is identical to before.
        self.u = self.add_weight(
            shape=(1, self.w_shape[-1]),
            initializer=initializers.TruncatedNormal(stddev=0.02),
            name='sn_u',
            trainable=False,
            dtype=self.w.dtype.base_dtype)

        super(SpectralNormalization, self).build()

    @def_function.function
    def call(self, inputs, training=None):
        """Apply the wrapped layer, normalizing its kernel first when training.

        Args:
          inputs: Passed unchanged to the wrapped layer.
          training: Training-mode flag. If None, ``K.learning_phase()`` is
            used instead.

        Returns:
          The wrapped layer's output.
        """
        if training is None:
            training = K.learning_phase()
        # An explicit `== True` comparison is kept rather than plain
        # truthiness. `training` may be a Python bool/int or (presumably)
        # a tensor; `==` preserves the original comparison semantics for
        # both under AutoGraph.
        if training == True:  # noqa: E712
            self._compute_weights()
        output = self.layer(inputs)
        return output

    def _compute_weights(self):
        """Run one power-iteration step and rescale the kernel in place.

        The kernel is viewed as a 2-D matrix ``W`` with ``w_shape[-1]``
        columns. Starting from the stored ``u``, the step computes:

        * ``v = normalize(u @ W^T)``
        * ``u = normalize(v @ W)``
        * ``sigma = v @ W @ u^T``, a 1x1 estimate of the largest singular
          value.

        The new ``u`` is stored, and the kernel is overwritten with
        ``kernel / sigma``.
        """
        w_reshaped = array_ops.reshape(self.w, [-1, self.w_shape[-1]])
        # Lower bound on the vector norms, guarding against division by zero.
        eps = 1e-12

        _u = array_ops.identity(self.u)
        _v = math_ops.matmul(_u, array_ops.transpose(w_reshaped))
        _v = _v / math_ops.maximum(math_ops.reduce_sum(_v**2)**0.5, eps)
        _u = math_ops.matmul(_v, w_reshaped)
        _u = _u / math_ops.maximum(math_ops.reduce_sum(_u**2)**0.5, eps)
        self.u.assign(_u)

        sigma = math_ops.matmul(
            math_ops.matmul(_v, w_reshaped), array_ops.transpose(_u))
        # `sigma` has shape (1, 1) and broadcasts over the full kernel. Since
        # `self.w` is this same kernel variable, the next call starts from
        # the already-rescaled kernel plus any optimizer updates.
        self.layer.kernel.assign(self.w / sigma)

    def compute_output_shape(self, input_shape):
        """Return the wrapped layer's output shape for ``input_shape``."""
        return tensor_shape.TensorShape(
            self.layer.compute_output_shape(input_shape).as_list())
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。