加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
KgEnv.py 25.70 KB
一键复制 编辑 原始数据 按行查看 历史
FiveKernelsMooncake 提交于 2021-09-29 16:37 . init
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741
import os
import sys
from tqdm import tqdm
import pickle
import random
import torch
from datetime import datetime
import math
from utils import *
from icecream import ic
import warnings
# Silence every warning for the whole process. Note that this also hides
# NumPy / PyTorch deprecation warnings raised by the code below.
warnings.simplefilter('ignore')
# Builds the RL state vector. Its length depends on how many past
# (node, relation) pairs are included (``history_len``).
class KGState(object):
    """Concatenate node/relation embeddings into a fixed-size state vector.

    ``history_len`` selects which embeddings are concatenated:

    * 0 -> [head_node, tail_node]                               (2 * embed_size)
    * 1 -> [head_node, tail_node, last_node, last_relation]      (4 * embed_size)
    * 2 -> [head_node, tail_node, last_node, last_relation,
            older_node, older_relation]                          (6 * embed_size)
    """

    # Number of embeddings concatenated for each supported history length.
    _NUM_PARTS = {0: 2, 1: 4, 2: 6}

    def __init__(self, embed_size, history_len=1):
        """
        Args:
            embed_size: length of every individual embedding vector.
            history_len: number of past (node, relation) pairs; 0, 1 or 2.

        Raises:
            ValueError: if ``history_len`` is not 0, 1 or 2.
        """
        if history_len not in self._NUM_PARTS:
            raise ValueError('history length should be 0/1/2')
        self.embed_size = embed_size
        self.history_len = history_len
        self.dim = self._NUM_PARTS[history_len] * embed_size

    def __call__(self, head_node_embed, tail_node_embed, last_node_embed,
                 last_relation_embed, older_node_embed, older_relation_embed):
        """Return the concatenation of the first ``dim / embed_size`` inputs.

        All six embeddings are always accepted; those beyond the configured
        history length are ignored.

        Raises:
            ValueError: if ``self.history_len`` was changed to an unsupported
                value after construction.
        """
        if self.history_len not in self._NUM_PARTS:
            raise ValueError('history length should be 0/1/2')
        parts = (head_node_embed, tail_node_embed, last_node_embed,
                 last_relation_embed, older_node_embed, older_relation_embed)
        return np.concatenate(parts[:self._NUM_PARTS[self.history_len]])
class BatchKGEnvironment(object):
    """Batched reinforcement-learning environment over the product KG.

    Every path in a batch starts at a product node and is extended by one
    edge per ``batch_step`` until it holds ``max_path_len + 1`` nodes. The
    valid actions at each step are ``(relation, next_node_id)`` pairs, pruned
    to at most ``max_acts`` entries using TransE-style scores.

    ``train_time`` selects the experiment variant. It controls:

    * which precomputed score matrices are loaded;
    * how actions are scored during pruning;
    * how terminal product nodes are rewarded.

    See ``_get_actions``, ``_get_actions_all_relation`` and ``_get_reward``.
    The original author's notes on two of the variants:

    * ``train_time == 7``: find products related to the start product.
      Reward and pruning both use the larger of the two relation scores, and
      a small set of meta-paths is used for reasoning.
    * ``train_time == 8``: same goal. The reward uses the score of a single
      relation; pruning uses the larger of the two relation scores.
    """

    # Variants that load one score matrix for ``relation`` (``self.reward``)
    # and use it directly as the reward.
    _SINGLE_SCORE_TIMES = frozenset({3, 4, 8})
    # Variants that load both the SUB and COMP score matrices
    # (``self.reward_sub`` / ``self.reward_comp``).
    _DUAL_SCORE_TIMES = frozenset({5, 7, 9, 13, 70, 71, 72})
    # Variants rewarded with the larger of the rescaled SUB/COMP scores.
    _MAX_DUAL_REWARD_TIMES = frozenset({5, 7, 9, 70, 71})
    # Variants that, when pruning, score product neighbours by the larger of
    # the SUB and COMP TransE scores.
    _DUAL_PRUNE_TIMES = frozenset({5, 7, 8, 9, 13, 70, 71, 72})
    # Variants that blend (0.7 / 0.3) the start-product-anchored score with a
    # score anchored at the current node when pruning.
    _BLENDED_PRUNE_TIMES = frozenset({71, 72})
    # Relations ignored by ``_get_actions`` when ``train_time == 1``.
    _ATTRIBUTE_RELATIONS = frozenset(
        {'purchase', 'described_as', 'belong_to', 'produced_by'})

    def __init__(self, dataset_str, max_acts, max_path_len=3, state_history=1,
                 relation=('COMP',), type='whole', train_time=1, delrel='',
                 score_root='/mnt/ssd/zjyang/KAPR/OnlyProduct/AAAITmp'):
        """Load the KG, the embeddings and the variant's score matrices.

        Args:
            dataset_str: dataset name passed to ``load_kg`` / ``load_embed``.
                It is also used as a sub-directory of ``score_root``.
            max_acts: maximum number of actions kept per step after pruning.
            max_path_len: number of edges in a complete path.
            state_history: history length passed to ``KGState`` (0, 1 or 2).
            relation: target relation, ``'COMP'`` or ``'SUB'``. This may be a
                sequence whose first element is used (original interface) or
                a plain string.
            type: forwarded to ``load_kg`` / ``load_embed``.
            train_time: experiment-variant selector (see the class docstring).
            delrel: relation excluded by ``_get_actions_all_relation``.
            score_root: root directory of the precomputed
                ``<dataset>/KEIM/<SUB|COMP|relation>/score_numpy.npy`` files.

        Raises:
            ValueError: if ``relation`` is neither ``'COMP'`` nor ``'SUB'``.
        """
        self.relation = relation if isinstance(relation, str) else relation[0]
        if self.relation not in ('COMP', 'SUB'):
            # Previously this only surfaced later, as an AttributeError on
            # ``self.relation_embedding``.
            raise ValueError(
                "relation must be 'COMP' or 'SUB', got %r" % (self.relation,))
        self.max_acts = max_acts
        # max_acts actions plus one extra slot (presumably for the self-loop).
        self.act_dim = max_acts + 1
        self.max_num_nodes = max_path_len + 1
        self.train_time = train_time
        self.delrel = delrel
        # Knowledge graph and pretrained TransE entity/relation embeddings.
        self.kg = load_kg(dataset_str, type=type)
        self.embeds = load_embed(dataset_str, type=type)
        self.embed_size = self.embeds[PRODUCT].shape[1]
        self.embeds[SELF_LOOP] = (np.zeros(self.embed_size), 0.0)
        self.state_gen = KGState(self.embed_size, history_len=state_history)
        self.state_dim = self.state_gen.dim
        if self.delrel != 'None':
            print('delrel:', self.delrel)
        self._load_score_matrices(score_root, dataset_str)
        # Embedding of the target relation. Falls back to the other product
        # relation when the dataset has no embedding for the requested one.
        if self.relation == 'COMP':
            self.relation_embedding = self._relation_embed(COMP, SUB)
        else:
            self.relation_embedding = self._relation_embed(SUB, COMP)
        self.p_p_scales = self._compute_p_p_scales()
        self.patterns = self._build_patterns()
        # Current-episode information.
        self._batch_path = None  # per path: list of (relation, node_type, node_id)
        self._batch_curr_actions = None  # per path: current valid actions
        self._batch_curr_state = None
        self._batch_curr_reward = None
        # A single 'done' flag suffices: all paths have the same length and
        # therefore finish at the same time.
        self._done = False

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------
    def _load_score_matrices(self, score_root, dataset_str):
        """Load the precomputed product-product score matrices this variant uses."""
        keim_dir = os.path.join(score_root, dataset_str, 'KEIM')
        if self.train_time in self._SINGLE_SCORE_TIMES:
            self.reward = np.load(
                os.path.join(keim_dir, self.relation, 'score_numpy.npy'))
            print('reward shape', self.reward.shape)
        elif self.train_time in self._DUAL_SCORE_TIMES:
            self.reward_sub = np.load(
                os.path.join(keim_dir, 'SUB', 'score_numpy.npy'))
            self.reward_comp = np.load(
                os.path.join(keim_dir, 'COMP', 'score_numpy.npy'))
            print('reward shape', self.reward_sub.shape)

    def _relation_embed(self, primary, fallback):
        """Return the embedding vector of ``primary``, or of ``fallback`` if
        ``self.embeds`` has no entry for ``primary``."""
        try:
            return self.embeds[primary][0]
        except KeyError:
            return self.embeds[fallback][0]

    def _compute_p_p_scales(self, num_chunks=5):
        """Return, for every product p, max over products q of (p + r) . q.

        ``r`` is ``self.relation_embedding``. Rows are processed in
        ``num_chunks`` slices to bound the size of the intermediate score
        matrix. The result normalises the default reward in ``_get_reward``.
        """
        product_embeds = self.embeds[PRODUCT]
        num_products = len(product_embeds)
        chunk = max(1, math.ceil(num_products / num_chunks))
        scales = []
        for start in range(0, num_products, chunk):
            heads = product_embeds[start:start + chunk] + self.relation_embedding
            scales.extend(np.max(np.dot(heads, product_embeds.T), axis=1))
        print('p_p_scales size:', len(scales), num_products)
        return scales

    def _build_patterns(self):
        """Return the relation sequences (meta-paths) that can earn a reward.

        Each pattern is ``(SELF_LOOP, r1, r2, ...)`` and is compared against
        the relation column of a path. ``train_time == 5`` uses both the COMP
        and SUB tables; every other variant uses the table of
        ``self.relation`` only.
        """
        if self.train_time == 5:
            tables = (PATH_PATTERN_COMP, PATH_PATTERN_SUB)
        elif self.relation == 'COMP':
            tables = (PATH_PATTERN_COMP,)
        else:
            tables = (PATH_PATTERN_SUB,)
        patterns = []
        for table in tables:
            for raw_pattern in table.values():
                pattern = tuple([SELF_LOOP] + [step[0] for step in raw_pattern[1:]])
                if pattern not in patterns:
                    patterns.append(pattern)
        return patterns

    # ------------------------------------------------------------------
    # Meta-path matching
    # ------------------------------------------------------------------
    def _has_pattern(self, path):
        """Return True if the relation sequence of ``path`` is a known pattern."""
        return tuple(step[0] for step in path) in self.patterns

    def _batch_has_pattern(self, batch_path):
        """Apply ``_has_pattern`` to every path of the batch."""
        return [self._has_pattern(path) for path in batch_path]

    # ------------------------------------------------------------------
    # Action generation
    # ------------------------------------------------------------------
    def _candidate_actions(self, path, excluded_relations):
        """Return ``(relation, node_id)`` moves to unvisited neighbours of the
        last node of ``path``, skipping relations in ``excluded_relations``."""
        _, curr_node_type, curr_node_id = path[-1]
        relation_nodes = self.kg(curr_node_type, curr_node_id)
        visited_nodes = {(node_type, node_id) for _, node_type, node_id in path}
        candidates = []
        for r in relation_nodes:
            if r in excluded_relations:
                continue
            next_node_type = KG_RELATION[curr_node_type][r]
            candidates.extend(
                (r, n) for n in relation_nodes[r]
                if (next_node_type, n) not in visited_nodes)
        return candidates

    def _target_relation_embed(self, next_node_type, relation):
        """Relation vector added to the start product's embedding when scoring
        a move into a node of ``next_node_type`` via ``relation``."""
        if next_node_type == PRODUCT:
            return self.relation_embedding
        if next_node_type == WORD:
            return self.embeds[DESCRIBED_AS][0]
        if next_node_type == BRAND:
            return self.embeds[PRODUCED_BY][0]
        if next_node_type == CATEGORY:
            return self.embeds[BELONG_TO][0]
        if next_node_type == USER:
            return self.embeds[PURCHASE][0]
        return self.embeds[relation][0]

    def _score_action(self, product_embed, curr_node_type, curr_node_id,
                      relation, next_node_id):
        """Pruning score for moving from the current node along ``relation``.

        The anchored score is ``(product_embed + r_target) . next_embed``,
        where ``r_target`` comes from ``_target_relation_embed``.

        * For product neighbours, variants in ``_DUAL_PRUNE_TIMES`` replace
          the anchored score with the larger of the SUB and COMP scores.
        * Variants in ``_BLENDED_PRUNE_TIMES`` return
          ``0.7 * anchored + 0.3 * local``, where ``local`` is anchored at
          the current node: ``(curr_embed + r) . next_embed``.
        """
        next_node_type = KG_RELATION[curr_node_type][relation]
        next_embed = self.embeds[next_node_type][next_node_id]
        blended = self.train_time in self._BLENDED_PRUNE_TIMES
        anchored = np.matmul(
            product_embed + self._target_relation_embed(next_node_type, relation),
            next_embed)
        if blended:
            cur_embed = (self.embeds[curr_node_type][curr_node_id]
                         + self.embeds[relation][0])
            local = np.matmul(cur_embed, next_embed)
        if self.train_time in self._DUAL_PRUNE_TIMES and next_node_type == PRODUCT:
            score_sub = np.matmul(product_embed + self._relation_embed(SUB, COMP), next_embed)
            score_comp = np.matmul(product_embed + self._relation_embed(COMP, SUB), next_embed)
            anchored = max(score_sub, score_comp)
        if blended:
            return 0.7 * anchored + 0.3 * local
        return anchored

    def _get_actions_all_relation(self, path, done):
        """Return the valid actions at the end of ``path``.

        All relations except ``self.delrel`` are considered. Returns
        ``[(SELF_LOOP, curr_node_id)]`` when ``done`` or when no unvisited
        neighbour exists. Otherwise it returns the candidate moves, sorted;
        when there are more than ``max_acts`` candidates, only the
        ``max_acts`` best-scored ones (``_score_action``) are kept.
        """
        _, curr_node_type, curr_node_id = path[-1]
        self_loop = (SELF_LOOP, curr_node_id)
        if done:
            return [self_loop]
        candidate_acts = self._candidate_actions(path, (self.delrel,))
        if len(candidate_acts) > self.max_acts:
            product_embed = self.embeds[PRODUCT][path[0][-1]]
            scores = [
                self._score_action(product_embed, curr_node_type, curr_node_id, r, n)
                for r, n in candidate_acts
            ]
            keep = np.argsort(scores)[-self.max_acts:]
            candidate_acts = [candidate_acts[i] for i in keep]
        # The self-loop is only offered when there is nothing else to do.
        return sorted(candidate_acts) if candidate_acts else [self_loop]

    def _prune_product_first(self, pid, curr_node_type, candidate_acts):
        """Keep at most ``max_acts`` of ``candidate_acts``, preferring products.

        Every candidate is scored as ``(p + r_target) . next_embed``, where
        ``p`` is the embedding of start product ``pid``.

        * ``sub``/``comp`` moves to products are kept first (best scores,
          up to ``max_acts``).
        * Any remaining slots are filled with the best of the other
          relations.
        * Candidates whose node type is not PRODUCT/WORD/BRAND/CATEGORY/USER
          are dropped.

        Each of the two groups is sorted separately, and the product group
        comes first.
        """
        product_embed = self.embeds[PRODUCT][pid]
        product_acts, product_scores = [], []
        other_acts, other_scores = [], []
        for r, next_node_id in candidate_acts:
            next_node_type = KG_RELATION[curr_node_type][r]
            if next_node_type not in (PRODUCT, WORD, BRAND, CATEGORY, USER):
                continue
            score = np.matmul(
                product_embed + self._target_relation_embed(next_node_type, r),
                self.embeds[next_node_type][next_node_id])
            if next_node_type == PRODUCT and r in ('sub', 'comp'):
                product_acts.append((r, next_node_id))
                product_scores.append(score)
            else:
                other_acts.append((r, next_node_id))
                other_scores.append(score)
        num_product = min(self.max_acts, len(product_acts))
        num_other = self.max_acts - num_product
        selected = sorted(
            product_acts[i] for i in np.argsort(product_scores)[-num_product:])
        if num_other > 0:
            selected += sorted(
                other_acts[i] for i in np.argsort(other_scores)[-num_other:])
        return selected

    def _get_actions(self, path, done):
        """Return the valid actions at the end of ``path`` (``train_time == 1`` mode).

        With ``train_time == 1`` the attribute relations (purchase,
        described_as, belong_to, produced_by) are ignored. When there are
        more than ``max_acts`` candidates, ``_prune_product_first`` selects
        them. The self-loop is dropped whenever a real move exists, except
        for ``train_time == 77``, where it is always the first action.
        """
        _, curr_node_type, curr_node_id = path[-1]
        self_loop = (SELF_LOOP, curr_node_id)
        if done:
            return [self_loop]
        excluded = self._ATTRIBUTE_RELATIONS if self.train_time == 1 else ()
        candidate_acts = self._candidate_actions(path, excluded)
        if not candidate_acts:
            return [self_loop]
        if len(candidate_acts) <= self.max_acts:
            selected = sorted(candidate_acts)
        else:
            selected = self._prune_product_first(path[0][-1], curr_node_type, candidate_acts)
        if self.train_time == 77 or not selected:
            return [self_loop] + selected
        return selected

    def _batch_get_actions(self, batch_path, done):
        """Return the action list of every path in the batch."""
        if self.train_time == 1:
            return [self._get_actions(path, done) for path in batch_path]
        return [self._get_actions_all_relation(path, done) for path in batch_path]

    # ------------------------------------------------------------------
    # State / observation helpers
    # ------------------------------------------------------------------
    def _get_state(self, path):
        """Return the state vector of ``path`` built by ``self.state_gen``.

        The components are:

        * the start product and the current node;
        * the previous node and the relation leading to the current node;
        * once the path has 3+ nodes, the node before that and the relation
          leading to the previous node.

        Missing history is zero-filled.
        """
        product_embed = self.embeds[PRODUCT][path[0][-1]]
        zero_embed = np.zeros(self.embed_size)
        if len(path) == 1:
            return self.state_gen(product_embed, product_embed, zero_embed,
                                  zero_embed, zero_embed, zero_embed)
        older_relation, last_node_type, last_node_id = path[-2]
        last_relation, curr_node_type, curr_node_id = path[-1]
        curr_node_embed = self.embeds[curr_node_type][curr_node_id]
        last_node_embed = self.embeds[last_node_type][last_node_id]
        last_relation_embed, _ = self.embeds[last_relation]  # may be the self-loop
        if len(path) == 2:
            return self.state_gen(product_embed, curr_node_embed, last_node_embed,
                                  last_relation_embed, zero_embed, zero_embed)
        _, older_node_type, older_node_id = path[-3]
        older_node_embed = self.embeds[older_node_type][older_node_id]
        older_relation_embed, _ = self.embeds[older_relation]
        return self.state_gen(product_embed, curr_node_embed, last_node_embed,
                              last_relation_embed, older_node_embed,
                              older_relation_embed)

    def _batch_get_state(self, batch_path):
        """Stack the state vectors of all paths into one array (one row each)."""
        return np.vstack([self._get_state(path) for path in batch_path])

    def _batch_get_cur_node_type(self, batch_path):
        """Return the current node type of every path as a column array."""
        return np.vstack([self._get_cur_node_type(path) for path in batch_path])

    def _get_cur_node_type(self, path):
        """Return the type of the last node of ``path``."""
        return path[-1][1]

    def _batch_get_cur_node_id(self, batch_path):
        """Return the current node id of every path as a column array."""
        return np.vstack([self._get_cur_node_id(path) for path in batch_path])

    def _get_cur_node_id(self, path):
        """Return the id of the last node of ``path``."""
        return path[-1][2]

    def batch_action_embedding(self, batch_cur_node_type, batch_cur_node_id,
                               batch_curr_actions):
        """Return the stacked action-embedding matrices of the whole batch.

        ``batch_cur_node_type`` / ``batch_cur_node_id`` are indexed with
        ``[0]`` per row, as produced by the ``_batch_get_cur_node_*`` helpers.
        """
        batch_act_emb = [
            self._get_act_emb(node_type[0], node_id[0], acts)
            for node_type, node_id, acts in zip(
                batch_cur_node_type, batch_cur_node_id, batch_curr_actions)
        ]
        return np.stack(batch_act_emb, axis=0)

    def _get_act_emb(self, Type, Id, Act):
        """Return one ``[relation_embed | next_node_embed]`` row per action in
        ``Act``, zero-padded with ``max_acts - len(Act) + 1`` extra rows.

        ``Type`` is the current node type; ``Id`` is unused.
        """
        rows = []
        for relation, node_id in Act:
            if relation == SELF_LOOP:
                next_node_type = Type
            else:
                next_node_type = KG_RELATION[Type][relation]
            rows.append(np.hstack((self.embeds[relation][0],
                                   self.embeds[next_node_type][node_id])))
        padding = self.max_acts - len(Act) + 1
        rows.extend([np.zeros(self.embed_size * 2)] * max(padding, 0))
        # A single vstack instead of one per row (the old loop was quadratic).
        return np.vstack(rows)

    # ------------------------------------------------------------------
    # Reward
    # ------------------------------------------------------------------
    def _get_reward(self, path):
        """Return the (non-negative) reward of ``path``.

        Only paths with at least one edge, whose relation sequence is in
        ``self.patterns`` and that end on a product node, are rewarded. The
        score then depends on ``train_time``:

        * 3/4/8: ``self.reward[pid, product]``.
        * 5/7/9/70/71: the larger of the SUB and COMP scores, each mapped
          ``s -> 2*s - 1``.
        * 72: the mapped score of ``self.relation`` only.
        * otherwise: the TransE score ``(p + r) . q`` divided by
          ``p_p_scales[pid]``.

        Negative scores are clipped to 0.
        """
        if len(path) <= 1 or not self._has_pattern(path):
            return 0.0
        _, curr_node_type, curr_node_id = path[-1]
        if curr_node_type != PRODUCT:
            return 0.0
        pid = path[0][-1]
        if self.train_time in self._SINGLE_SCORE_TIMES:
            score = self.reward[pid, curr_node_id]
        elif self.train_time in self._MAX_DUAL_REWARD_TIMES or self.train_time == 72:
            # Linear map onto [Nmin, Nmax] = [-1, 1]; this assumes the stored
            # scores lie in [0, 1] (not verified here).
            s_sub = 2 * self.reward_sub[pid, curr_node_id] - 1
            s_comp = 2 * self.reward_comp[pid, curr_node_id] - 1
            if self.train_time == 72:
                score = s_sub if self.relation == 'SUB' else s_comp
            else:
                score = max(s_sub, s_comp)
        else:
            head_p_vec = self.embeds[PRODUCT][pid] + self.relation_embedding
            tail_p_vec = self.embeds[PRODUCT][curr_node_id]
            score = np.dot(head_p_vec, tail_p_vec) / self.p_p_scales[pid]
        return max(score, 0.0)

    def _batch_get_reward(self, batch_path):
        """Return the rewards of all paths as a NumPy array."""
        return np.array([self._get_reward(path) for path in batch_path])

    # ------------------------------------------------------------------
    # Episode control
    # ------------------------------------------------------------------
    def _is_done(self):
        """Episode ends once the paths reach ``max_num_nodes`` nodes."""
        return self._done or len(self._batch_path[0]) >= self.max_num_nodes

    def _update_batch_observations(self):
        """Recompute state, actions, reward and current-node info of all paths."""
        self._batch_curr_state = self._batch_get_state(self._batch_path)
        self._batch_curr_actions = self._batch_get_actions(self._batch_path, self._done)
        self._batch_curr_reward = self._batch_get_reward(self._batch_path)
        self._cur_node_type = self._batch_get_cur_node_type(self._batch_path)
        self._cur_node_id = self._batch_get_cur_node_id(self._batch_path)

    def reset(self, pids=None):
        """Start a new episode from ``pids`` (one random product if None).

        Returns:
            (state, current node types, current node ids) of the batch.
        """
        if pids is None:
            all_pids = list(self.kg(PRODUCT).keys())
            pids = [random.choice(all_pids)]
        self._batch_path = [[(SELF_LOOP, PRODUCT, pid)] for pid in pids]
        self._done = False
        self._update_batch_observations()
        return self._batch_curr_state, self._cur_node_type, self._cur_node_id

    def batch_step(self, batch_act_idx):
        """Extend every path by the action at index ``batch_act_idx[i]``.

        Returns:
            (state, reward, done, paths, current node types, current node ids).
        """
        assert len(batch_act_idx) == len(self._batch_path)
        for i, act_idx in enumerate(batch_act_idx):
            path = self._batch_path[i]
            curr_node_type = path[-1][1]
            relation, next_node_id = self._batch_curr_actions[i][act_idx]
            if relation == SELF_LOOP:
                next_node_type = curr_node_type
            else:
                next_node_type = KG_RELATION[curr_node_type][relation]
            path.append((relation, next_node_type, next_node_id))
        self._done = self._is_done()
        self._update_batch_observations()
        return (self._batch_curr_state, self._batch_curr_reward, self._done,
                self._batch_path, self._cur_node_type, self._cur_node_id)

    def batch_action_mask(self, dropout=0.0):
        """Return a boolean ``(batch, act_dim)`` mask of the valid actions.

        With ``dropout > 0``, and for paths with at least 5 actions, a random
        ``1 - dropout`` fraction of the actions after the first is kept; the
        first action is always kept.
        """
        batch_mask = []
        for actions in self._batch_curr_actions:
            act_idxs = list(range(len(actions)))
            if dropout > 0 and len(act_idxs) >= 5:
                keep_size = int(len(act_idxs[1:]) * (1.0 - dropout))
                tmp = np.random.choice(act_idxs[1:], keep_size, replace=False).tolist()
                act_idxs = [act_idxs[0]] + tmp
            # ``np.bool`` was removed in NumPy 1.24; the builtin works everywhere.
            act_mask = np.zeros(self.act_dim, dtype=bool)
            act_mask[act_idxs] = True
            batch_mask.append(act_mask)
        return np.vstack(batch_mask)

    def print_path(self):
        """Print every path of the batch as ``type(id) ==rel=> type(id) ...``."""
        for path in self._batch_path:
            msg = 'Path: {}({})'.format(path[0][1], path[0][2])
            for node in path[1:]:
                msg += ' =={}=> {}({})'.format(node[0], node[1], node[2])
            print(msg)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化