代码拉取完成,页面将自动刷新
#!/usr/bin/env python
# coding: utf-8
# Created on Mon Oct. 24 15:24:18 2022
# @author: Lu Jian
# Email:janelu@live.cn;
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class PositionalEmbedding(nn.Layer):
    """Sinusoidal position embedding in the Transformer-XL style.

    A position ``p`` is encoded as the sines of ``p * f_i`` followed by the
    cosines of ``p * f_i``, where ``f_i = 1 / 10000 ** (2i / demb)``.

    Args:
        demb (int): Embedding width. One sine and one cosine column are
            produced per frequency, i.e. ``demb`` columns when ``demb`` is even.
    """

    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb
        # One frequency per even index 0, 2, 4, ... < demb, stored with a
        # leading axis of size 1 so it broadcasts against [L, 1] positions.
        even_idx = paddle.arange(0.0, demb, 2.0)
        freqs = 1 / (10000 ** (even_idx / demb))
        self.register_buffer('inv_freq', paddle.unsqueeze(freqs, axis=0))

    def forward(self, pos_seq, bsz=None):
        """Embed the positions in ``pos_seq``.

        Args:
            pos_seq: Position tensor, assumed 1-D of length ``L`` — TODO confirm.
            bsz: Accepted for call compatibility; not used.

        Returns:
            Tensor ``[L, 2 * inv_freq.shape[1]]``: the sine half, then the
            cosine half.
        """
        angles = paddle.unsqueeze(pos_seq, axis=1) * self.inv_freq
        sin_part = paddle.sin(angles)
        cos_part = paddle.cos(angles)
        return paddle.concat([sin_part, cos_part], axis=-1)
class MultiHeadAttention(nn.Layer):
    """Batch-first multi-head attention plus helpers for relative attention.

    Subclasses reuse the q/k/v/out projections, ``scale``, ``dropatt`` and
    ``_rel_shift``. ``_rel_shift`` works on attention scores laid out as
    ``[batch, n_head, qlen, klen]``.

    Args:
        embed_dim (int): Model width; must be divisible by ``num_heads``.
        num_heads (int): Number of attention heads.
        dropatt (float): Dropout probability on attention weights.
        tgt_len, ext_len, mem_len, pre_lnorm: Accepted for interface
            compatibility; not used by this class.
        weight_attr, bias_attr: Forwarded to every ``nn.Linear``.
        **kwargs: Ignored.
    """

    def __init__(self, embed_dim, num_heads, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 weight_attr=None, bias_attr=None, **kwargs):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.n_head = num_heads
        self.dropatt = nn.Dropout(dropatt, mode="upscale_in_train")
        self.d_head = embed_dim // num_heads
        assert self.d_head * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.q_proj = nn.Linear(embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
        self.k_proj = nn.Linear(embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
        self.v_proj = nn.Linear(embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
        self.out_proj = nn.Linear(embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
        self.scale = self.d_head ** -0.5

    def compute_kv(self, key, value):
        """Project ``key``/``value`` and split them into heads.

        Assuming 3-D inputs ``[batch, klen, embed_dim]``:

        Returns:
            k: ``[batch, n_head, d_head, klen]`` (pre-transposed for ``q @ k``).
            v: ``[batch, n_head, klen, d_head]``.
        """
        # Head count/width live in n_head / d_head (set in __init__).
        k = paddle.reshape(self.k_proj(key), [0, 0, self.n_head, self.d_head])
        k = paddle.transpose(k, [0, 2, 3, 1])
        v = paddle.reshape(self.v_proj(value), [0, 0, self.n_head, self.d_head])
        v = paddle.transpose(v, [0, 2, 1, 3])
        return k, v

    def attention(self, q, k, v, attn_mask=0):
        """Scaled dot-product attention followed by the output projection.

        ``attn_mask`` is added to the scaled scores (additive mask).

        Returns:
            ``(out, weights)``: projected output ``[batch, qlen, embed_dim]``
            and the (dropped-out) attention weights.
        """
        product = paddle.matmul(q, k) * self.scale + attn_mask
        weights = self.dropatt(F.softmax(product, axis=-1))
        out = paddle.transpose(paddle.matmul(weights, v), [0, 2, 1, 3])
        out = paddle.reshape(out, [0, 0, -1])
        return self.out_proj(out), weights

    def forward(self, query, key, value, attn_mask=0, mems=None):
        """Standard attention; ``mems`` (if given) is prepended to key/value on axis 1."""
        q = paddle.reshape(self.q_proj(query), [0, 0, self.n_head, self.d_head])
        q = paddle.transpose(q, [0, 2, 1, 3])
        if mems is not None:
            key = paddle.concat([mems, key], axis=1)
            value = paddle.concat([mems, value], axis=1)
        k, v = self.compute_kv(key, value)
        out, _ = self.attention(q, k, v, attn_mask)
        return out

    def _parallelogram_mask(self, h, w, left=False):
        """Build an ``[h, w]`` 0/1 parallelogram-shaped mask.

        The top-left ``m x m`` square (``m = min(h, w)``) is made upper
        triangular and the bottom-right one lower triangular. The result is
        flipped along axis 0 unless ``left`` is True.
        """
        mask = paddle.ones((h, w))
        m = min(h, w)
        mask[:m, :m] = paddle.triu(mask[:m, :m])
        mask[-m:, -m:] = paddle.tril(mask[-m:, -m:])
        if left:
            return mask
        return mask.flip(0)

    def _shift(self, x, qlen, klen, mask, left=False):
        """Gather ``klen`` entries of ``x`` per query row as selected by ``mask``.

        ``x`` is padded with ``qlen - 1`` zero steps along axis 1 (in front
        when ``left``, behind otherwise). It is then broadcast to ``qlen``
        copies on axis 0 and masked.

        NOTE(review): not called within this module. The expected shapes are
        assumptions — TODO confirm: ``x`` 4-D with size-1 axis 0, and ``mask``
        ``[qlen, x.shape[1] + qlen - 1]`` with ``klen`` non-zero entries per row.
        """
        if left:
            mask = mask.flip(1)
        if qlen > 1:
            zero_pad = paddle.zeros((x.shape[0], qlen - 1, x.shape[2], x.shape[3]), dtype=x.dtype)
            pieces = [zero_pad, x] if left else [x, zero_pad]
            x_padded = paddle.concat(pieces, axis=1)
        else:
            # Nothing to pad: paddle.concat needs inputs of equal rank, so an
            # empty 1-D pad cannot be concatenated with the 4-D x.
            x_padded = x
        x_padded = x_padded.expand([qlen, -1, -1, -1])
        select = mask.astype('bool')[:, :, None, None].expand(x_padded.shape)
        out = paddle.masked_select(x_padded, select)
        return out.reshape([qlen, klen, x.shape[2], x.shape[3]])

    def _rel_shift(self, x, zero_triu=False):
        """Convert relative-position scores to absolute key positions.

        ``x`` is ``[batch, n_head, qlen, klen]``, where column ``j`` scores
        relative-embedding index ``j``. A zero column is prepended on the key
        axis. The tensor is then reinterpreted as ``[batch, n_head, klen + 1,
        qlen]`` and its first row dropped. This shifts row ``i`` left by
        ``qlen - 1 - i`` positions (Transformer-XL's efficient relative shift,
        applied to the batch-first score layout).

        Args:
            x: Score tensor ``[batch, n_head, qlen, klen]``.
            zero_triu: If True, zero every entry above diagonal
                ``klen - qlen`` of each ``[qlen, klen]`` slice.

        Returns:
            Tensor with the same shape as ``x``.
        """
        bsz, n_head, qlen, klen = x.shape
        zero_pad = paddle.zeros([bsz, n_head, qlen, 1], dtype=x.dtype)
        x_padded = paddle.concat([zero_pad, x], axis=-1)
        x_padded = x_padded.reshape([bsz, n_head, klen + 1, qlen])
        x = x_padded[:, :, 1:].reshape([bsz, n_head, qlen, klen])
        if zero_triu:
            ones = paddle.ones([qlen, klen], dtype=x.dtype)
            x = x * paddle.tril(ones, klen - qlen)[None, None, :, :]
        return x
class RelLearnableMultiHeadAttn(MultiHeadAttention):
    """Relative attention with learnable per-offset embeddings and biases.

    Inputs are batch-first. Scores are computed in the layout
    ``[batch, n_head, qlen, klen]``.
    """

    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

    def _split_heads(self, x, seq_len):
        # [batch, seq, embed_dim] -> [batch, n_head, seq, d_head]. Reshaping
        # straight to [batch, n_head, seq, d_head] would mix positions/heads.
        x = x.reshape([x.shape[0], seq_len, self.n_head, self.d_head])
        return x.transpose([0, 2, 1, 3])

    @staticmethod
    def _apply_mask(attn_score, attn_mask):
        # Positions whose mask entry is non-zero get -inf before softmax.
        # A 2-D mask is [batch, klen]; a 3-D mask is [batch, qlen, klen].
        # Masks of any other rank are ignored.
        if attn_mask is None:
            return attn_score
        mask = attn_mask.astype("bool")
        if not mask.any():
            return attn_score
        if mask.ndim == 2:
            mask = mask[:, None, None, :]
        elif mask.ndim == 3:
            mask = mask[:, None, :, :]
        else:
            return attn_score
        mask = mask.expand(attn_score.shape)
        return paddle.where(mask, paddle.full_like(attn_score, float('-inf')), attn_score)

    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
        """Compute relative self-attention for the segment ``w``.

        Args:
            w: ``[batch, qlen, embed_dim]`` input states.
            r_emb: ``[rlen, n_head, d_head]`` relative embeddings (term B).
            r_w_bias: ``[n_head, d_head]`` content bias (term C).
            r_bias: ``[rlen, n_head]`` relative position bias (term D).
            attn_mask: Optional ``[batch, klen]`` or ``[batch, qlen, klen]``
                mask; non-zero entries are excluded from attention.
            mems: Optional ``[batch, mlen, embed_dim]`` cached states that are
                prepended to ``w`` for keys/values.

        Returns:
            ``[batch, qlen, embed_dim]`` attention output.
        """
        bsz, qlen = w.shape[0], w.shape[1]
        if mems is not None:
            cat = paddle.concat([mems, w], axis=1)
            w_head_q = self.q_proj(cat[:, -qlen:])
            w_head_k = self.k_proj(cat)
            w_head_v = self.v_proj(cat)
        else:
            w_head_q = self.q_proj(w)
            w_head_k = self.k_proj(w)
            w_head_v = self.v_proj(w)
        klen = w_head_k.shape[1]

        w_head_q = self._split_heads(w_head_q, qlen)
        w_head_k = self._split_heads(w_head_k, klen)
        w_head_v = self._split_heads(w_head_v, klen)

        # Make the relative tables exactly klen rows long: left-pad by
        # repeating row 0, or keep the last klen rows.
        if klen > r_emb.shape[0]:
            r_emb_pad = r_emb[0:1].expand([klen - r_emb.shape[0], -1, -1])
            r_emb = paddle.concat([r_emb_pad, r_emb], axis=0)
            r_bias_pad = r_bias[0:1].expand([klen - r_bias.shape[0], -1])
            r_bias = paddle.concat([r_bias_pad, r_bias], axis=0)
        else:
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        #### compute attention score: [batch, n_head, qlen, klen]
        rw_head_q = w_head_q + r_w_bias.unsqueeze(1)
        AC = paddle.einsum('bnid,bnjd->bnij', rw_head_q, w_head_k)
        B_ = paddle.einsum('bnid,jnd->bnij', w_head_q, r_emb)
        D_ = r_bias.T[None, :, None]  # [1, n_head, 1, klen]
        BD = self._rel_shift(B_ + D_)
        attn_score = (AC + BD) * self.scale

        #### compute attention probability
        attn_score = self._apply_mask(attn_score, attn_mask)
        attn_prob = self.dropatt(F.softmax(attn_score, axis=-1))

        #### compute attention vector: [batch, qlen, n_head, d_head] -> merged
        attn_vec = paddle.einsum('bnij,bnjd->bind', attn_prob, w_head_v)
        attn_vec = attn_vec.reshape([bsz, qlen, self.n_head * self.d_head])
        return self.out_proj(attn_vec)
class RelPartialLearnableMultiHeadAttn(MultiHeadAttention):
    """Relative attention driven by position encodings ``r``.

    ``r`` is projected by ``r_net`` (no bias). Two learnable biases are added
    to the queries: ``r_w_bias`` for the content term and ``r_r_bias`` for
    the position term. Inputs are batch-first. Scores are computed in the
    layout ``[batch, n_head, qlen, klen]``.
    """

    def __init__(self, *args, **kwargs):
        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
        self.r_net = nn.Linear(self.embed_dim, self.embed_dim, bias_attr=False)

    def _split_heads(self, x, seq_len):
        # [batch, seq, embed_dim] -> [batch, n_head, seq, d_head]. Reshaping
        # straight to [batch, n_head, seq, d_head] would mix positions/heads.
        x = x.reshape([x.shape[0], seq_len, self.n_head, self.d_head])
        return x.transpose([0, 2, 1, 3])

    @staticmethod
    def _apply_mask(attn_score, attn_mask):
        # Positions whose mask entry is non-zero get -inf before softmax.
        # A 2-D mask is [batch, klen]; a 3-D mask is [batch, qlen, klen].
        # Masks of any other rank are ignored.
        if attn_mask is None:
            return attn_score
        mask = attn_mask.astype("bool")
        if not mask.any():
            return attn_score
        if mask.ndim == 2:
            mask = mask[:, None, None, :]
        elif mask.ndim == 3:
            mask = mask[:, None, :, :]
        else:
            return attn_score
        mask = mask.expand(attn_score.shape)
        return paddle.where(mask, paddle.full_like(attn_score, float('-inf')), attn_score)

    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
        """Compute relative self-attention for the segment ``w``.

        Args:
            w: ``[batch, qlen, embed_dim]`` input states.
            r: ``[rlen, embed_dim]`` position encodings. ``rlen`` must equal
                ``klen`` (memory + segment length) so the score terms align.
            r_w_bias: ``[n_head, d_head]`` content bias.
            r_r_bias: ``[n_head, d_head]`` position bias.
            attn_mask: Optional ``[batch, klen]`` or ``[batch, qlen, klen]``
                mask; non-zero entries are excluded from attention.
            mems: Optional ``[batch, mlen, embed_dim]`` cached states that are
                prepended to ``w`` for keys/values.

        Returns:
            ``[batch, qlen, embed_dim]`` attention output.
        """
        bsz, qlen, rlen = w.shape[0], w.shape[1], r.shape[0]
        if mems is not None:
            cat = paddle.concat([mems, w], axis=1)
            w_head_q = self.q_proj(cat[:, -qlen:])
            w_head_k = self.k_proj(cat)
            w_head_v = self.v_proj(cat)
        else:
            w_head_q = self.q_proj(w)
            w_head_k = self.k_proj(w)
            w_head_v = self.v_proj(w)
        r_head_k = self.r_net(r)
        klen = w_head_k.shape[1]

        w_head_q = self._split_heads(w_head_q, qlen)
        w_head_k = self._split_heads(w_head_k, klen)
        w_head_v = self._split_heads(w_head_v, klen)
        r_head_k = r_head_k.reshape([rlen, self.n_head, self.d_head])

        #### compute attention score: [batch, n_head, qlen, klen]
        rw_head_q = w_head_q + r_w_bias.unsqueeze(1)
        AC = paddle.einsum('bnid,bnjd->bnij', rw_head_q, w_head_k)
        rr_head_q = w_head_q + r_r_bias.unsqueeze(1)
        BD = paddle.einsum('bnid,jnd->bnij', rr_head_q, r_head_k)
        BD = self._rel_shift(BD)
        attn_score = (AC + BD) * self.scale

        #### compute attention probability
        attn_score = self._apply_mask(attn_score, attn_mask)
        attn_prob = self.dropatt(F.softmax(attn_score, axis=-1))

        #### compute attention vector: [batch, qlen, n_head, d_head] -> merged
        attn_vec = paddle.einsum('bnij,bnjd->bind', attn_prob, w_head_v)
        attn_vec = attn_vec.reshape([bsz, qlen, self.n_head * self.d_head])
        return self.out_proj(attn_vec)
class RelLearnableDecoderLayer(nn.Layer):
    """Self-attention + feed-forward block built on ``RelLearnableMultiHeadAttn``.

    Args:
        d_model (int): Model width.
        n_head (int): Number of attention heads.
        dim_feedforward (int): Hidden width of the feed-forward sub-layer.
        dropout (float): Dropout after each sub-layer and after the activation.
        dropatt (float): Dropout on attention probabilities.
        activation (str): Name of an activation layer class in ``paddle.nn``
            (e.g. ``'GELU'``); instantiated with no arguments.
        normalize_before (bool): Pre-LayerNorm if True, post-LayerNorm otherwise.
        **kwargs: Forwarded to the attention layer.
    """

    def __init__(self, d_model, n_head, dim_feedforward, dropout=0.1, dropatt=0.1,
                 activation='GELU', normalize_before=False, **kwargs):
        # Capture the constructor arguments before anything else is bound.
        self._config = locals()
        self._config.pop("__class__", None)
        super(RelLearnableDecoderLayer, self).__init__()
        self.normalize_before = normalize_before
        self.self_attn = RelLearnableMultiHeadAttn(d_model, n_head, dropatt=dropatt, **kwargs)
        self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
        self.norm1 = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        # Resolve the activation class by name rather than eval-ing a
        # config-supplied string.
        self.activation = getattr(nn, activation)()
        self.dropact = nn.Dropout(dropout, mode="upscale_in_train")
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train")
        self.norm2 = nn.LayerNorm(d_model)
        self._config.pop("self")

    def _feed_forward(self, x):
        # linear1 -> activation -> dropout -> linear2
        return self.linear2(self.dropact(self.activation(self.linear1(x))))

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
        """Run attention and feed-forward sub-layers with residual connections.

        ``r_emb``, ``r_w_bias``, ``r_bias`` and ``mems`` are passed straight to
        the attention layer; ``dec_attn_mask`` is passed as its ``attn_mask``.
        ``mems`` is not normalized in either mode.
        """
        if not self.normalize_before:
            # Post-LN: LN(x + Attn(x)), then LN(h + FFN(h)).
            residual = dec_inp
            tgt = self.self_attn(dec_inp, r_emb, r_w_bias, r_bias,
                                 attn_mask=dec_attn_mask, mems=mems)
            tgt = self.norm1(residual + self.dropout1(tgt))
            residual = tgt
            tgt = self._feed_forward(tgt)
            return self.norm2(residual + self.dropout2(tgt))
        # Pre-LN: x + Attn(LN(x)), then h + FFN(LN(h)).
        residual = dec_inp
        tgt = self.norm1(dec_inp)
        tgt = self.self_attn(tgt, r_emb, r_w_bias, r_bias,
                             attn_mask=dec_attn_mask, mems=mems)
        tgt = residual + self.dropout1(tgt)
        residual = tgt
        tgt = self._feed_forward(self.norm2(tgt))
        return residual + self.dropout2(tgt)
class RelPartialLearnableDecoderLayer(nn.Layer):
    """Self-attention + feed-forward block built on ``RelPartialLearnableMultiHeadAttn``.

    Args:
        d_model (int): Model width.
        n_head (int): Number of attention heads.
        dim_feedforward (int): Hidden width of the feed-forward sub-layer.
        dropout (float): Dropout after each sub-layer and after the activation.
        dropatt (float): Dropout on attention probabilities.
        activation (str): Name of an activation layer class in ``paddle.nn``
            (e.g. ``'GELU'``); instantiated with no arguments.
        normalize_before (bool): Pre-LayerNorm if True, post-LayerNorm otherwise.
        **kwargs: Forwarded to the attention layer.
    """

    def __init__(self, d_model, n_head, dim_feedforward, dropout=0.1, dropatt=0.1,
                 activation='GELU', normalize_before=False, **kwargs):
        # Capture the constructor arguments before anything else is bound.
        self._config = locals()
        self._config.pop("__class__", None)
        super(RelPartialLearnableDecoderLayer, self).__init__()
        self.normalize_before = normalize_before
        self.self_attn = RelPartialLearnableMultiHeadAttn(d_model, n_head, dropatt=dropatt, **kwargs)
        self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
        self.norm1 = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        # Resolve the activation class by name rather than eval-ing a
        # config-supplied string.
        self.activation = getattr(nn, activation)()
        self.dropact = nn.Dropout(dropout, mode="upscale_in_train")
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train")
        self.norm2 = nn.LayerNorm(d_model)
        self._config.pop("self")

    def _feed_forward(self, x):
        # linear1 -> activation -> dropout -> linear2
        return self.linear2(self.dropact(self.activation(self.linear1(x))))

    def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
        """Run attention and feed-forward sub-layers with residual connections.

        ``r``, ``r_w_bias``, ``r_r_bias`` and ``mems`` are passed straight to
        the attention layer; ``dec_attn_mask`` is passed as its ``attn_mask``.
        ``mems`` is not normalized in either mode.
        """
        if not self.normalize_before:
            # Post-LN: LN(x + Attn(x)), then LN(h + FFN(h)).
            residual = dec_inp
            tgt = self.self_attn(dec_inp, r, r_w_bias, r_r_bias,
                                 attn_mask=dec_attn_mask, mems=mems)
            tgt = self.norm1(residual + self.dropout1(tgt))
            residual = tgt
            tgt = self._feed_forward(tgt)
            return self.norm2(residual + self.dropout2(tgt))
        # Pre-LN: x + Attn(LN(x)), then h + FFN(LN(h)).
        residual = dec_inp
        tgt = self.norm1(dec_inp)
        tgt = self.self_attn(tgt, r, r_w_bias, r_r_bias,
                             attn_mask=dec_attn_mask, mems=mems)
        tgt = residual + self.dropout1(tgt)
        residual = tgt
        tgt = self._feed_forward(self.norm2(tgt))
        return residual + self.dropout2(tgt)
class WordEmbedding(nn.Layer):
    """Token embedding lookup followed by LayerNorm.

    Args:
        n_token (int): Vocabulary size.
        d_embed (int): Embedding width.
        pad_id (int): Padding index passed to ``nn.Embedding``.
        dropout (float): Accepted but not used by this layer.
        **kwargs: Ignored.
    """

    def __init__(self, n_token, d_embed, pad_id=0, dropout=0.1, **kwargs):
        super(WordEmbedding, self).__init__()
        self.n_token, self.d_embed = n_token, d_embed
        # Kept as an attribute only; forward() does not apply this scale.
        self.emb_scale = d_embed ** 0.5
        self.word_embeddings = nn.Embedding(n_token, d_embed, padding_idx=pad_id)
        self.layer_norm = nn.LayerNorm(d_embed)

    def forward(self, inp):
        """Return ``LayerNorm(Embedding(inp))``."""
        looked_up = self.word_embeddings(inp)
        return self.layer_norm(looked_up)
class ERNIE_XL(nn.Layer):
    """Segment-recurrent encoder with Transformer-XL style memory.

    ``forward`` splits the input into consecutive segments of ``mem_len``
    tokens. The hidden states of each segment are cached (with gradients
    stopped) and attended to by the next segment.

    Args:
        n_token (int): Vocabulary size.
        n_layer (int): Number of layers.
        n_head (int): Attention heads per layer.
        d_model (int): Model width (also the embedding width).
        d_inner (int): Feed-forward hidden width.
        attn_type (int): 0 = relative partial-learnable attention (default);
            1 = relative learnable attention; 2/3 = absolute-position variants.
            Types 2/3 use a ``DecoderLayer`` class that is not defined in this
            module — TODO confirm where it comes from.
        dropout (float): Dropout on embeddings, positions and the final output.
        dropatt (float): Dropout on attention probabilities.
        normalize_before (bool): Pre-LayerNorm layers if True.
        tgt_len, ext_len, mem_len (int): Target/extended/memory lengths.
            ``mem_len`` is also the segment size used by ``forward``.
        **kwargs: Forwarded to ``WordEmbedding`` (e.g. ``pad_id``).
    """

    def __init__(self, n_token, n_layer, n_head, d_model, d_inner, attn_type=0,
                 dropout=0.1, dropatt=0.1, normalize_before=False,
                 tgt_len=0, ext_len=0, mem_len=0, **kwargs):
        super(ERNIE_XL, self).__init__()
        self.n_token = n_token
        d_embed = d_model
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_model // n_head
        self.word_emb = WordEmbedding(n_token, d_embed, dropout=dropout, **kwargs)
        self.drop = nn.Dropout(dropout)
        self.n_layer = n_layer
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len
        self.attn_type = attn_type
        self.layers = nn.LayerList()
        if attn_type == 0:  # the default attention
            for _ in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        d_model, n_head, d_inner, dropout, dropatt=dropatt,
                        normalize_before=normalize_before,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len))
        elif attn_type == 1:  # learnable embeddings
            for _ in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        d_model, n_head, d_inner, dropout, dropatt=dropatt,
                        normalize_before=normalize_before,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len))
        elif attn_type in [2, 3]:  # absolute embeddings
            # NOTE(review): DecoderLayer is not defined or imported in this
            # module; these attention types fail with NameError until it is.
            for _ in range(n_layer):
                self.layers.append(
                    DecoderLayer(
                        d_model, n_head, d_inner, dropout,
                        dropatt=dropatt, normalize_before=normalize_before))
        self._create_params()

    def _create_params(self):
        """Create the position modules/parameters required by ``attn_type``."""
        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
            self.r_w_bias = paddle.create_parameter((self.n_head, self.d_head), "float32")
            self.r_r_bias = paddle.create_parameter((self.n_head, self.d_head), "float32")
        elif self.attn_type == 1:  # learnable
            # paddle.create_parameter requires an explicit dtype.
            self.r_emb = paddle.create_parameter(
                (self.n_layer, self.max_klen, self.n_head, self.d_head), "float32")
            self.r_w_bias = paddle.create_parameter(
                (self.n_layer, self.n_head, self.d_head), "float32")
            self.r_bias = paddle.create_parameter(
                (self.n_layer, self.max_klen, self.n_head), "float32")
        elif self.attn_type == 2:  # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3:  # absolute deeper SA
            self.r_emb = paddle.create_parameter(
                (self.n_layer, self.max_klen, self.n_head, self.d_head), "float32")

    def reset_length(self, tgt_len, ext_len, mem_len):
        """Update the target/extended/memory lengths (``max_klen`` is unchanged)."""
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def init_mems(self):
        """Return ``n_layer + 1`` empty placeholder memories, or None if ``mem_len <= 0``."""
        if self.mem_len > 0:
            return [paddle.empty([0, 0, 0], "float32") for _ in range(self.n_layer + 1)]
        return None

    def _update_mems(self, hids, mems, qlen, mlen):
        """Return the memories to carry into the next segment.

        Of the ``mlen + qlen`` cached steps, the last ``ext_len`` of the
        current ``qlen`` tokens are reserved as extended context. The steps
        ``[end_idx - mem_len, end_idx)`` are kept, with
        ``end_idx = mlen + qlen - ext_len``.

        Args:
            hids: Per-layer hidden states of the current segment
                (embedding output plus one entry per layer).
            mems: Memories used for the current segment, or None.
            qlen: Length of the current segment.
            mlen: Length of the current memories (0 for placeholders).

        Returns:
            List of detached memory tensors, or None when ``mems`` is None.
        """
        if mems is None:
            return None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'
        end_idx = mlen + max(0, qlen - self.ext_len)
        beg_idx = max(0, end_idx - self.mem_len)
        new_mems = []
        for mem, hid in zip(mems, hids):
            # Placeholder memories (mlen == 0) hold no states and cannot be
            # concatenated with hid, so start the cache from hid alone.
            cat = paddle.concat([mem, hid], axis=1) if mlen > 0 else hid
            # Stop gradients through the cache (Transformer-XL recurrence).
            new_mems.append(cat[:, beg_idx:end_idx].detach())
        return new_mems

    def _forward(self, dec_inp, mems=None):
        """Encode one segment.

        Args:
            dec_inp: ``[batch, qlen]`` token ids.
            mems: Memories from the previous segment, placeholders from
                ``init_mems``, or None.

        Returns:
            ``(core_out, new_mems)``: the final hidden states of the segment
            and the memories for the next one.
        """
        bsz, qlen = dec_inp.shape
        word_emb = self.word_emb(dec_inp)
        mlen = mems[0].shape[1] if mems is not None else 0
        klen = mlen + qlen
        # Empty placeholder memories are not passed to the layers.
        layer_mems = mems if mlen > 0 else None
        # No attention mask: queries attend to every memory/segment position.
        dec_attn_mask = None

        hids = []
        if self.attn_type == 0:  # default
            pos_seq = paddle.arange(klen - 1, -1, -1.0, dtype=word_emb.dtype)
            pos_emb = self.pos_emb(pos_seq)
            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if layer_mems is None else layer_mems[i]
                core_out = layer(core_out, pos_emb, self.r_w_bias, self.r_r_bias,
                                 dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 1:  # learnable
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                r_emb, r_bias = self.r_emb[i], self.r_bias[i]
                mems_i = None if layer_mems is None else layer_mems[i]
                core_out = layer(core_out, r_emb, self.r_w_bias[i], r_bias,
                                 dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 2:  # absolute
            pos_seq = paddle.arange(klen - 1, -1, -1.0, dtype=word_emb.dtype)
            pos_emb = self.pos_emb(pos_seq)
            core_out = self.drop(word_emb + pos_emb[-qlen:])
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if layer_mems is None else layer_mems[i]
                if mems_i is not None and i == 0:
                    # Out-of-place add so the cached memory is not modified.
                    mems_i = mems_i + pos_emb[:mlen]
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if layer_mems is None else layer_mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.shape[0]
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand([mlen - cur_size, -1, -1])
                        cur_emb = paddle.concat([cur_emb_pad, cur_emb], axis=0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    # Batch-first tensors: broadcast the embeddings over axis 0.
                    mems_i = mems_i + cur_emb.reshape([1, mlen, -1])
                # Out-of-place so the tensor already stored in hids is untouched.
                core_out = core_out + self.r_emb[i][-qlen:].reshape([1, qlen, -1])
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)

        core_out = self.drop(core_out)
        new_mems = self._update_mems(hids, mems, qlen=qlen, mlen=mlen)
        return core_out, new_mems

    def forward(self, inp, mems=None):
        """Encode ``inp`` segment by segment, carrying memory between segments.

        Args:
            inp: ``[batch, seq_len]`` token ids.
            mems: Optional initial memories; ``init_mems()`` is used when empty.

        Returns:
            Hidden states of the last processed segment.
        """
        # Memories are created here (not by the caller) and threaded through
        # each segment's _forward call.
        if not mems:
            mems = self.init_mems()
        seg_len = self.mem_len
        total = inp.shape[1]
        if seg_len <= 0 or total < seg_len:
            # No memory configured, or the input is shorter than one segment:
            # encode everything in a single pass.
            hidden, _ = self._forward(inp, mems=mems)
            return hidden
        # NOTE(review): the final `total % seg_len` tokens (after the last
        # full segment) are not encoded; confirm this is intended.
        hidden = None
        for start in range(0, total - seg_len + 1, seg_len):
            hidden, mems = self._forward(inp[:, start:start + seg_len], mems=mems)
        return hidden
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。