加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
ctr_fm_ftrl.py 1.98 KB
一键复制 编辑 原始数据 按行查看 历史
pan 提交于 2020-12-03 16:47 . init
from FM_FTRL_machine import *
from datetime import datetime
import random
import _pickle as pickle
import psutil
# %load runmodel_example.py
# time pypy-2.4 -u runmodel.py | tee output_0.txt

# ---- reproducibility -------------------------------------------------
random.seed(5)  # fixed seed so runs are repeatable

# ---- run / data settings ---------------------------------------------
reportFrequency = 100000        # print progressive loss every N training rows
trainingFile = "./data/train"   # path to the training data file
hashSalt = "salty"              # salt mixed into the feature hash

# ---- model size ------------------------------------------------------
fm_dim = 4                      # FM latent-factor dimensionality
fm_initDev = 0.1                # stddev for latent-factor initialization
p_D = 22
D = 2 ** p_D                    # hashed feature-space size (2^22 buckets)

# ---- FTRL learning rates ---------------------------------------------
alpha = 0.05                    # linear part
beta = 1.0
alpha_fm = 0.05                 # FM (factorization) part
beta_fm = 1.0

# ---- regularization --------------------------------------------------
L1 = 0.01                       # linear part
L2 = 5.0
L1_fm = 0.01                    # FM part
L2_fm = 5.0

# ---- training --------------------------------------------------------
n_epochs = 5                    # full passes over the training data
####
# ----------------------------------------------------------------------
# Training driver: build an FM+FTRL learner, run `n_epochs` passes over
# the training file, report progressive log-loss every `reportFrequency`
# rows, and checkpoint the model to disk after each epoch.
# (Indentation reconstructed: the source paste had the loop bodies
# flattened to column 0.)
# ----------------------------------------------------------------------
learner = FM_FTRL_machine(fm_dim, fm_initDev, L1, L2, L1_fm, L2_fm, D,
                          alpha, beta, alpha_fm=alpha_fm, beta_fm=beta_fm)

print("Start Training:")
start = datetime.now()

for e in range(n_epochs):
    start1 = datetime.now()

    # Warm-up: disable FM-part regularization on the first epoch so the
    # latent factors can move away from their random initialization;
    # restore the configured L1_fm/L2_fm on later epochs.
    if e == 0:
        learner.L1_fm = 0.
        learner.L2_fm = 0.
    else:
        learner.L1_fm = L1_fm
        learner.L2_fm = L2_fm

    # NOTE(review): cvLoss/cvCount are printed below but never
    # accumulated anywhere in this script, so they always report 0 --
    # presumably a held-out validation pass was removed. Confirm against
    # the original runmodel.py before relying on those numbers.
    cvLoss = 0.
    cvCount = 0.
    progressiveLoss = 0.
    progressiveCount = 0.

    for t, date, ID, x, y in data(trainingFile, D, hashSalt):
        p = learner.predict(x)   # predicted probability for this row
        loss = logLoss(p, y)     # log-loss of prediction vs. label
        learner.update(x, p, y)  # FTRL-proximal weight update

        progressiveLoss += loss
        progressiveCount += 1.
        if t % reportFrequency == 0:
            print("Epoch %d\tcount: %d\t Loss: %f, time: %s, memory %s%%" % (
                e, t, progressiveLoss / progressiveCount,
                str(datetime.now() - start1),
                psutil.virtual_memory().percent))
            start1 = datetime.now()

    print(progressiveLoss / progressiveCount)
    # Bug fix: the original format string used "\c" (an invalid escape
    # sequence) where a "\t" separator was clearly intended.
    print("Epoch %d finished.\tcvLoss: %f\t cvCount: %f \t elapsed time: %s" % (
        e, cvLoss, cvCount, str(datetime.now() - start)))

    # Checkpoint after each epoch. Using a context manager guarantees the
    # file is flushed and closed (the original leaked the handle from a
    # bare open() passed straight to pickle.dump).
    with open("model/model_fm_e" + str(e), 'wb') as model_file:
        pickle.dump(learner, model_file)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化