加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
utils.py 7.55 KB
一键复制 编辑 原始数据 按行查看 历史
FiveKernelsMooncake 提交于 2021-09-29 16:37 . init
from __future__ import absolute_import, division, print_function
import sys
import random
import pickle
import logging
import logging.handlers
import numpy as np
import scipy.sparse as sp
from sklearn.feature_extraction.text import TfidfTransformer
import torch
import os
import warnings
warnings.filterwarnings('ignore')
# Disabled dataset entry kept for reference:
#'littleele': '/mnt/ssd/zjyang/KAPR/OnlyProduct/data/littleele',,

# Root folders, relative to the working directory: raw data and generated
# artifacts.
data_root = '../kbs_code/data'
tmp_root = '../kbs_code/tmp'

# Label file suffixes. They are appended verbatim to a dataset directory,
# which is why each one starts with a slash.
# COMP & also_bought bought_together
comp_train = '/comp_p_p_train.pkl'
comp_test = '/comp_p_p_test.pkl'
# SUB & also_viewed
sub_train = '/sub_p_p_train.pkl'
sub_test = '/sub_p_p_test.pkl'

# Dataset name -> directory holding its input files.
DATASET_DIR = {name: os.path.join(data_root, name) for name in ('ele',)}

# Dataset name -> directory receiving model results / intermediate files.
TMP_DIR = {name: os.path.join(tmp_root, name) for name in ('ele',)}

# Dataset name -> label files, in the fixed order
# (sub_train, sub_test, comp_train, comp_test).
LABELS = {
    name: tuple(
        folder + suffix
        for suffix in (sub_train, sub_test, comp_train, comp_test)
    )
    for name, folder in DATASET_DIR.items()
}
# Entity types: the names used as node types (keys and tail values) in
# KG_RELATION.
PRODUCT='product'
WORD='word'
BRAND='brand'
CATEGORY='category'
USER='user'
# Relation names: the edge labels used in KG_RELATION and in the
# PATH_PATTERN_* tables.
DESCRIBED_AS = 'described_as'
PRODUCED_BY = 'produced_by'
BELONG_TO = 'belong_to'
# Disabled relation names (commented out in the original):
#BELONGS_TO = 'belongs_to'
#ALSO_BOUGHT='also_bought'
#ALSO_VIEWED='also_viewed'
COMP = 'comp'
SUB = 'sub'
SELF_LOOP = 'self_loop' # only for kg env
PURCHASE='purchase'
# The *_RELATED names below are referenced only by commented-out entries of
# KG_RELATION in this part of the file.
WORD_RELATED='word_related'
BRAND_RELATED='brand_related'
CATEGORY_RELATED='category_related'
# Build the graph.
# KG_RELATION[head_entity][relation] is the tail entity type reached from
# `head_entity` through `relation` (see get_entity_tail). Relations between two
# different entity types are listed in both directions, e.g.
# PRODUCT -described_as-> WORD and WORD -described_as-> PRODUCT.
KG_RELATION = {
    WORD: {
        DESCRIBED_AS: PRODUCT,
        # Disabled word-to-word relation:
        #WORD_RELATED:WORD,
    },
    PRODUCT: {
        DESCRIBED_AS: WORD,
        PRODUCED_BY: BRAND,
        BELONG_TO: CATEGORY,
        # COMP and SUB link a product to another product.
        COMP: PRODUCT,
        SUB: PRODUCT,
        PURCHASE: USER,
    },
    BRAND: {
        PRODUCED_BY: PRODUCT,
        # Disabled brand-to-brand relation:
        #BRAND_RELATED:BRAND,
    },
    CATEGORY: {
        BELONG_TO: PRODUCT,
        # Disabled category-to-category relation:
        #CATEGORY_RELATED:CATEGORY,
    },
    USER: {
        PURCHASE: PRODUCT,
    }
}
# Keys are relation-path tuples rendered as strings (exactly the text that
# str(tuple) produces); each value is a list of relation names.
# NOTE(review): presumably the relations still allowed as the next action
# after that path -- confirm against the code that consumes this table.
ACTION_PRUNED = {
    # paths of length 1
    "('self_loop',)": ['produced_by', 'belong_to', 'comp', 'sub', 'purchase', 'described_as'],
    # paths of length 2
    "('self_loop', 'sub')": ['comp', 'sub'],
    "('self_loop', 'comp')": ['comp', 'sub'],
    "('self_loop', 'produced_by')": ['produced_by'],
    "('self_loop', 'belong_to')": ['belong_to'],
    "('self_loop', 'described_as')": ['described_as'],
    "('self_loop', 'purchase')": ['purchase'],
    # paths of length 3: none defined
}
# Path patterns keyed by a numeric id. Each pattern is a tuple of
# (relation, entity type) steps; the first step is always (None, PRODUCT),
# i.e. the start product with no incoming relation. The leading digit of the
# id equals the number of steps in the pattern.
# NOTE(review): presumably the patterns accepted when predicting the SUB
# relation -- confirm against the caller.
PATH_PATTERN_SUB={
    #length=2
    21:((None,PRODUCT),(COMP,PRODUCT)),
    22:((None,PRODUCT),(SUB,PRODUCT)),
    #length=3
    31:((None,PRODUCT),(DESCRIBED_AS,WORD),(DESCRIBED_AS,PRODUCT)),
    32:((None,PRODUCT),(PRODUCED_BY,BRAND),(PRODUCED_BY,PRODUCT)),
    33:((None,PRODUCT),(BELONG_TO,CATEGORY),(BELONG_TO,PRODUCT)),
    34:((None,PRODUCT),(SUB,PRODUCT),(SUB,PRODUCT)),
    35:((None,PRODUCT),(COMP,PRODUCT),(COMP,PRODUCT)),
    36:((None,PRODUCT),(COMP,PRODUCT),(SUB,PRODUCT)),
    37:((None,PRODUCT),(PURCHASE,USER),(PURCHASE,PRODUCT)),
    #length=4, product-to-product relations only (SUB / COMP)
    416:((None,PRODUCT),(SUB,PRODUCT),(SUB,PRODUCT),(SUB,PRODUCT)),
    417:((None,PRODUCT),(COMP,PRODUCT),(COMP,PRODUCT),(COMP,PRODUCT)),
    418:((None,PRODUCT),(SUB,PRODUCT),(COMP,PRODUCT),(COMP,PRODUCT)),
    419:((None,PRODUCT),(COMP,PRODUCT),(COMP,PRODUCT),(SUB,PRODUCT)),
    420:((None,PRODUCT),(SUB,PRODUCT),(SUB,PRODUCT),(COMP,PRODUCT)),
    421:((None,PRODUCT),(SUB,PRODUCT),(COMP,PRODUCT),(SUB,PRODUCT)),
}
# Path patterns keyed by a numeric id. Each pattern is a tuple of
# (relation, entity type) steps; the first step is always (None, PRODUCT),
# i.e. the start product with no incoming relation. The leading digit of the
# id equals the number of steps in the pattern.
# NOTE(review): presumably the patterns accepted when predicting the COMP
# relation -- confirm against the caller.
PATH_PATTERN_COMP={
    #length=2
    21:((None,PRODUCT),(COMP,PRODUCT)),
    22:((None,PRODUCT),(SUB,PRODUCT)),
    #length=3
    31:((None,PRODUCT),(DESCRIBED_AS,WORD),(DESCRIBED_AS,PRODUCT)),
    32:((None,PRODUCT),(PRODUCED_BY,BRAND),(PRODUCED_BY,PRODUCT)),
    33:((None,PRODUCT),(BELONG_TO,CATEGORY),(BELONG_TO,PRODUCT)),
    34:((None,PRODUCT),(SUB,PRODUCT),(COMP,PRODUCT)),
    35:((None,PRODUCT),(COMP,PRODUCT),(COMP,PRODUCT)),
    36:((None,PRODUCT),(COMP,PRODUCT),(SUB,PRODUCT)),
    37:((None,PRODUCT),(PURCHASE,USER),(PURCHASE,PRODUCT)),
    #length=4, product-to-product relations only (SUB / COMP)
    416:((None,PRODUCT),(SUB,PRODUCT),(COMP,PRODUCT),(COMP,PRODUCT)),
    417:((None,PRODUCT),(SUB,PRODUCT),(COMP,PRODUCT),(SUB,PRODUCT)),
    418:((None,PRODUCT),(SUB,PRODUCT),(SUB,PRODUCT),(COMP,PRODUCT)),
    419:((None,PRODUCT),(COMP,PRODUCT),(SUB,PRODUCT),(SUB,PRODUCT)),
    420:((None,PRODUCT),(COMP,PRODUCT),(SUB,PRODUCT),(COMP,PRODUCT)),
    421:((None,PRODUCT),(COMP,PRODUCT),(COMP,PRODUCT),(SUB,PRODUCT)),
    422:((None,PRODUCT),(COMP,PRODUCT),(COMP,PRODUCT),(COMP,PRODUCT)),
}
def get_entities():
    """Return the entity types (the keys of KG_RELATION) as a new list."""
    return [entity for entity in KG_RELATION]
def get_relations(entity_head):
    """Return the relation names declared for `entity_head` in KG_RELATION.

    Raises:
        KeyError: if `entity_head` is not a key of KG_RELATION.
    """
    outgoing = KG_RELATION[entity_head]
    return list(outgoing)
def get_entity_tail(entity_head, relation):
    """Return the tail entity type reached from `entity_head` via `relation`.

    Raises:
        KeyError: if the head entity or the relation is not in KG_RELATION.
    """
    tails_by_relation = KG_RELATION[entity_head]
    return tails_by_relation[relation]
def save_dataset(dataset,dataset_obj,type):
    """Pickle `dataset_obj` into the tmp directory of `dataset`.

    Args:
        dataset: key of TMP_DIR selecting the output directory.
        dataset_obj: object to serialize with pickle.
        type: 'whole' writes 'dataset_whole.pkl'; any other value writes
            'dataset.pkl' (all non-'whole' types share that one file).
    """
    file_name = '/dataset_whole.pkl' if type == 'whole' else '/dataset.pkl'
    target = TMP_DIR[dataset] + file_name
    with open(target, 'wb') as out:
        pickle.dump(dataset_obj, out)
def load_dataset(dataset,type):
    """Load the pickled dataset object of `dataset` from its tmp directory.

    Args:
        dataset: key of TMP_DIR selecting the directory to read from.
        type: 'whole' reads 'dataset_whole.pkl'; any other value reads
            'dataset.pkl'.

    Returns:
        The unpickled object.
    """
    if type=='whole':
        dataset_file=TMP_DIR[dataset]+'/dataset_whole.pkl'
    else:
        dataset_file=TMP_DIR[dataset]+'/dataset.pkl'
    print(dataset_file)
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes the handle, which the original left open.
    with open(dataset_file,'rb') as f:
        dataset_obj = pickle.load(f)
    return dataset_obj
def compute_tfidf_fast(vocab, docs):
    """Compute TF-IDF scores of every vocabulary term for every document.

    Args:
        vocab: sized collection; only len(vocab) is used, as the number of
            matrix columns.
        docs: list of list of integers (term indices), e.g. [[0,0,1], [1,2,0,1]]
    Returns:
        sparse matrix returned by TfidfTransformer.fit_transform,
        [num_docs, num_vocab]
    """
    # (1) Build the raw term-count matrix in CSR form, one row per document.
    counts, columns, row_starts = [], [], [0]
    for doc in docs:
        freq = {}
        for term_idx in doc:
            freq[term_idx] = freq.get(term_idx, 0) + 1
        for term_idx, occurrences in freq.items():
            columns.append(term_idx)
            counts.append(occurrences)
        # Each row ends where the next one starts.
        row_starts.append(len(columns))
    shape = (len(docs), len(vocab))
    tf = sp.csr_matrix((counts, columns, row_starts), dtype=int, shape=shape)

    # (2) Turn the counts into tf-idf weights (smoothed idf, transformer defaults otherwise).
    tfidf = TfidfTransformer(smooth_idf=True).fit_transform(tf)
    return tfidf
def save_kg(dataset,kg,type):
    """Pickle the knowledge graph `kg` for `dataset`.

    Args:
        dataset: key of TMP_DIR.
        kg: object to serialize with pickle.
        type: tag embedded in the file name (converted with str()).
    """
    # NOTE(review): there is no path separator before 'kg_', so the file is
    # created beside the tmp directory (e.g. '.../tmp/elekg_<type>.pkl'),
    # not inside it. Left unchanged because load_kg builds the same name and
    # previously saved files must stay loadable.
    kg_file=TMP_DIR[dataset]+'kg_'+str(type)+'.pkl'
    # Close (and therefore flush) the file deterministically instead of
    # relying on garbage collection of an anonymous handle.
    with open(kg_file,'wb') as f:
        pickle.dump(kg,f)
def load_kg(dataset,type):
    """Load the pickled knowledge graph of `dataset`.

    Args:
        dataset: key of TMP_DIR.
        type: tag embedded in the file name; converted with str() so that
            non-string tags work the same way as in save_kg.

    Returns:
        The unpickled knowledge graph object.
    """
    # NOTE(review): there is no path separator before 'kg_', so the file is
    # looked up beside the tmp directory (e.g. '.../tmp/elekg_<type>.pkl').
    # Left unchanged because save_kg writes to exactly this name.
    kg_file=TMP_DIR[dataset]+'kg_'+str(type)+'.pkl'
    print('load kg from:',kg_file)
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(kg_file,'rb') as f:
        kg=pickle.load(f)
    return kg
# Make sure the random values are the same on every run of the code.
def set_random_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch with `seed`.

    When CUDA is available, all CUDA devices are seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def save_embed(dataset, embed,best_ep,type):
    """Pickle the embedding object `embed` into the tmp directory of `dataset`.

    Args:
        dataset: key of TMP_DIR selecting the output directory.
        embed: object to serialize with pickle.
        best_ep: epoch tag placed at the end of the file name.
        type: tag placed in the middle of the file name.

    The file is named 'transe_embed_<type>_<best_ep>.pkl'.
    """
    embed_file = '{}/transe_embed_{}_{}.pkl'.format(TMP_DIR[dataset],type,best_ep)
    # Close (and therefore flush) the file deterministically instead of
    # relying on garbage collection of an anonymous handle.
    with open(embed_file, 'wb') as f:
        pickle.dump(embed, f)
def load_embed(dataset,type,best_ep=300):
    """Load a pickled embedding from the tmp directory of `dataset`.

    Args:
        dataset: key of TMP_DIR selecting the directory to read from.
        type: tag placed in the middle of the file name.
        best_ep: epoch tag at the end of the file name, matching the
            `best_ep` passed to save_embed. Defaults to 300, the value that
            used to be hard-coded here.

    Returns:
        The unpickled embedding object.
    """
    embed_file = '{}/transe_embed_{}_{}.pkl'.format(TMP_DIR[dataset],type,best_ep)
    print('Load embedding from:', embed_file)
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(embed_file, 'rb') as f:
        embed = pickle.load(f)
    return embed
def load_labels(dataset,relation, mode='train',type='whole'):
    """Load the pickled label file for a dataset, relation and split.

    Args:
        dataset: key of LABELS.
        relation: the upper-case string 'SUB' or 'COMP' (note: not the
            lower-case SUB / COMP relation constants).
        mode: 'train' or 'test'.
        type: unused; kept so existing callers keep working.

    Returns:
        The unpickled label object.

    Raises:
        ValueError: if `mode` or `relation` is not one of the accepted
            values. An unknown relation previously surfaced as an
            UnboundLocalError.
    """
    # Position of each file inside a LABELS entry, which is ordered
    # (sub_train, sub_test, comp_train, comp_test).
    label_index = {
        'train': {'SUB': 0, 'COMP': 2},
        'test': {'SUB': 1, 'COMP': 3},
    }
    if mode not in label_index:
        raise ValueError('mode should be one of {train, test}.')
    if relation not in label_index[mode]:
        raise ValueError('relation should be one of {SUB, COMP}, got %r.' % (relation,))
    label_file = LABELS[dataset][label_index[mode][relation]]
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(label_file, 'rb') as f:
        products = pickle.load(f)
    return products
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化