加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
embedding.py 879 Bytes
一键复制 编辑 原始数据 按行查看 历史
Ansel Chen 提交于 2019-09-16 22:45 . init
import torch
import torch.nn as nn
class Embedding(nn.Module):
def __init__(self, dataset, parameter):
super(Embedding, self).__init__()
self.device = parameter['device']
self.ent2id = dataset['ent2id']
self.es = parameter['embed_dim']
num_ent = len(self.ent2id)
self.embedding = nn.Embedding(num_ent, self.es)
if parameter['data_form'] == 'Pre-Train':
self.ent2emb = dataset['ent2emb']
self.embedding.weight.data.copy_(torch.from_numpy(self.ent2emb))
elif parameter['data_form'] in ['In-Train', 'Discard']:
nn.init.xavier_uniform_(self.embedding.weight)
def forward(self, triples):
idx = [[[self.ent2id[t[0]], self.ent2id[t[2]]] for t in batch] for batch in triples]
idx = torch.LongTensor(idx).to(self.device)
return self.embedding(idx)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化