加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
config.py 1.79 KB
一键复制 编辑 原始数据 按行查看 历史
marlin-codes 提交于 2022-10-19 15:23 . Add files via upload
import argparse
from hgcn_utils.train_utils import add_flags_from_config
# Default hyperparameters, grouped by concern. Every entry maps a flag name to
# a ``(default_value, help_text)`` tuple; the groups are registered on the
# module-level argparse parser below.
config_args = {
    'training_config': {
        'log': (None, 'None for no logging'),
        'lr': (0.001, 'learning rate'),
        'batch-size': (10000, 'batch size'),
        'epochs': (500, 'maximum number of epochs to train for'),
        'weight-decay': (0.005, 'l2 regularization strength'),
        'momentum': (0.95, 'momentum in optimizer'),
        # Two separate seeds: one for the data split, one for training.
        'seed': (1234, 'seed for data split'),
        'train_seed': (1234, 'seed for training'),
        'log-freq': (1, 'how often to compute and print train/val metrics (in epochs)'),
        'eval-freq': (20, 'how often to compute val metrics (in epochs)'),
        'r': (2., 'fermi-dirac decoder parameter for lp'),
        't': (1., 'fermi-dirac decoder parameter for lp'),
        'i': (1, 'the number of iteration times')
    },
    'model_config': {
        'embedding_dim': (50, 'user item embedding dimension'),
        'scale': (0.1, 'scale for init'),
        'dim': (50, 'embedding dimension'),
        'network': ('resSumGCN', 'choice of StackGCNs, plainGCN, resSumGCN'),
        # Float default (was int 1): if the flag's converter is inferred from
        # the default's type, an int default would reject values like 0.5.
        'c': (1., 'hyperbolic radius, set to None for trainable curvature'),
        'num-layers': (4, 'number of hidden layers in encoder'),
        'margin': (0.1, 'margin value in the metric learning loss'),
        'alpha': (20, "scale factor for geometric regularization")
    },
    'data_config': {
        'dataset': ('Amazon-CD', 'which dataset to use'),
        'num_neg': (1, 'number of negative samples'),
        'test_ratio': (0.2, 'proportion of test edges for link prediction'),
        # NOTE(review): the default is the *string* 'True', not a bool, so a
        # plain truthiness check on it is always true (including for 'False'
        # if passed as a string) -- confirm how downstream code reads it.
        'norm_adj': ('True', 'whether to row-normalize the adjacency matrix'),
    }
}
# Build one flat command-line parser from all config groups. Registration of
# the individual flags is delegated to hgcn_utils.train_utils.add_flags_from_config
# (not shown here); presumably one flag per key with the tuple's default/help.
parser = argparse.ArgumentParser()
for config_dict in config_args.values():
    # Group names are irrelevant to the parser; only the per-group dicts are used.
    parser = add_flags_from_config(parser, config_dict)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化