import pandas as pd
import _pickle as pickle
from csv import DictReader
from math import exp, copysign, log, sqrt
from datetime import datetime
class ftrl_proximal(object):
    ''' Our main algorithm: Follow the regularized leader - proximal

    In short, this is an adaptive-learning-rate sparse logistic regression
    with efficient L1-L2 regularization.

    Reference:
    http://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf
    '''
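    # A compact restatement of the per-coordinate recipe from the paper above
    # (its Algorithm 1, FTRL-Proximal with per-coordinate learning rates),
    # written with the same symbol names used in this class:
    #
    #   w_i = 0                                               if |z_i| <= L1
    #   w_i = -(z_i - sgn(z_i) * L1)
    #         / ((beta + sqrt(n_i)) / alpha + L2)             otherwise
    #
    #   g       = p - y          (gradient of the logloss; features are 0/1,
    #                             so it is the same for every active index i)
    #   sigma_i = (sqrt(n_i + g * g) - sqrt(n_i)) / alpha
    #   z_i    += g - sigma_i * w_i
    #   n_i    += g * g
    #
    # predict() below evaluates the first two lines lazily; update() applies
    # the remaining three.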
    def __init__(self, alpha, beta, L1, L2, D):
        # parameters
        self.alpha = alpha
        self.beta = beta
        self.L1 = L1
        self.L2 = L2

        # feature related parameters
        self.D = D

        # model
        # n: sum of squared past gradients
        # z: weights
        # w: lazy weights
        self.n = [0.] * D
        self.z = [0.] * D
        self.w = {}
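    # Data-layout note: n and z are dense lists of length D and are the only
    # state that persists across instances; w is rebuilt inside predict() as a
    # small dict holding just the coordinates of the current instance, which is
    # what "lazy weights" refers to.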
    def _indices(self, x):
        ''' A helper generator that yields the indices in x

        The purpose of this generator is to make the following
        code a bit cleaner when doing feature interaction.
        '''
        # # first yield index of the bias term
        # yield 0

        # then yield the normal indices
        for index in x:
            yield index
    def predict(self, x):
        ''' Get probability estimation on x

        INPUT:
            x: features
        OUTPUT:
            probability of p(y = 1 | x; w)
        '''
        # parameters
        alpha = self.alpha
        beta = self.beta
        L1 = self.L1
        L2 = self.L2

        # model
        n = self.n
        z = self.z
        w = {}

        # wTx is the inner product of w and x
        wTx = 0.
        for i in self._indices(x):
            sign = -1. if z[i] < 0 else 1.  # get sign of z[i]

            # build w on the fly using z and n, hence the name - lazy weights;
            # we do this at prediction instead of update time because it
            # allows us to avoid storing the complete w
            if sign * z[i] <= L1:
                # w[i] vanishes due to L1 regularization
                w[i] = 0.
            else:
                # apply prediction-time L1, L2 regularization to z and get w
                w[i] = (sign * L1 - z[i]) / ((beta + sqrt(n[i])) / alpha + L2)

            wTx += w[i]

        # cache the current w for the update stage
        self.w = w

        # bounded sigmoid: wTx is clipped to [-35, 35] to avoid overflow in exp;
        # this is the probability estimation
        return 1. / (1. + exp(-max(min(wTx, 35.), -35.)))
    def update(self, x, p, y):
        ''' Update model using x, p, y

        INPUT:
            x: feature, a list of indices
            p: click probability prediction of our model
            y: answer (the true label)
        MODIFIES:
            self.n: increase by squared gradient
            self.z: weights
        '''
        # parameter
        alpha = self.alpha

        # model
        n = self.n
        z = self.z
        w = self.w

        # gradient under logloss
        g = p - y

        # update z and n
        for i in self._indices(x):
            sigma = (sqrt(n[i] + g * g) - sqrt(n[i])) / alpha
            z[i] += g - sigma * w[i]
            n[i] += g * g
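# A tiny self-check sketch (not part of the original pipeline): the function
# below builds a throw-away ftrl_proximal model with assumed hyper-parameters
# and runs a few predict/update rounds on hand-made indices, to illustrate how
# the lazy weights flow from predict() into update().  The name _demo_ftrl and
# all numbers here are illustrative; the function is defined but never called.
def _demo_ftrl():
    toy = ftrl_proximal(alpha=.1, beta=1., L1=1., L2=1., D=2 ** 4)
    x = [1, 5, 9]              # hashed feature indices of a single instance
    for _ in range(10):        # a few rounds push |z_i| past the L1 threshold
        p = toy.predict(x)     # starts at 0.5 while every weight is zero
        toy.update(x, p, 1.)   # pretend the true label was always a click
    return toy.predict(x)      # the estimate has now moved above 0.5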
def logloss(p, y):
    ''' FUNCTION: Bounded logloss

    INPUT:
        p: our prediction
        y: real answer
    OUTPUT:
        logarithmic loss of p given y
    '''
    p = max(min(p, 1. - 10e-15), 10e-15)
    return -log(p) if y == 1. else -log(1. - p)
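# A quick numerical illustration of the bound above (values computed by hand):
# logloss(0.9, 1.) = -ln(0.9) ~= 0.105, a well-calibrated confident prediction;
# logloss(0., 1.) would be infinite, but p is first clipped to 10e-15, so the
# loss is capped at -ln(10e-15) ~= 32.2 and a single terrible prediction cannot
# blow up the averaged training loss.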
def data(path, D):
    ''' GENERATOR: Apply hash-trick to the original csv row
    and for simplicity, we one-hot-encode everything

    INPUT:
        path: path to training or testing file
        D: the max index that we can hash to
    YIELDS:
        t: row counter (0-based)
        date: day of week (0 = Monday) derived from the 'hour' column
        ID: id of the instance, mainly useless
        x: a list of hashed and one-hot-encoded 'indices'
           we only need the index since all values are either 0 or 1
        y: y = 1 if we have a click, else we have y = 0
    '''
    for t, row in enumerate(DictReader(open(path))):
        # process id
        ID = row['id']
        del row['id']

        # process clicks
        y = 0.
        if 'click' in row:
            if row['click'] == '1':
                y = 1.
            del row['click']

        # extract the day of week from the timestamp
        # date = int(row['hour'][4:6])
        date = datetime.strptime(row['hour'], '%y%m%d%H').weekday()

        # turn hour really into hour, it was originally YYMMDDHH
        row['hour'] = row['hour'][6:]

        # build x
        x = []
        for key in row:
            value = row[key]

            # one-hot encode everything with hash trick
            index = abs(hash(key + '_' + value)) % D
            x.append(index)

        yield t, date, ID, x, y
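# How the hash trick above behaves, on an assumed example row (the column name
# 'site_id' and its value are illustrative, not taken from this file): the pair
# is one-hot encoded simply by hashing the string 'site_id_1fbe01fe' and taking
# the result modulo D, so no vocabulary ever has to be stored and unseen values
# still map somewhere.  Note that Python 3 randomises str hashes per process;
# for predictions to be consistent with a model pickled by an earlier run,
# PYTHONHASHSEED should be fixed to the same value (e.g. PYTHONHASHSEED=0) for
# both training and prediction.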
D = 2 ** 20  # number of weights to use

# Load the model that was trained for 5 epochs
model = pickle.load(open("model/model_ftrl_e4", 'rb'))

p_list = []
ids = []
test = "./data/test"

# Run prediction over the test data
for t, date, ID, x, y in data(test, D):  # data is a generator
    if t % 100000 == 0:
        print(t)
    pred = model.predict(x)  # predict once per instance
    p_list.append(pred)
    ids.append(ID)

# Merge the predictions with their IDs into the required submission format.
# Submitted to Kaggle, this scored a Private Score of 0.47310.
res = pd.concat([pd.Series(ids), pd.Series(p_list)], axis=1)
res = res.rename(columns={0: "id", 1: "click"})
res.to_csv("pred_ftrl_1", index=False)
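# The pickle loaded above ("model/model_ftrl_e4") is assumed to have been
# produced by a training loop along the lines of the sketch below; the training
# path, the hyper-parameters and the epoch count are assumptions and not taken
# from this file.  It sits behind a TRAIN flag so that running this prediction
# script never retrains anything.
TRAIN = False
if TRAIN:
    learner = ftrl_proximal(alpha=.1, beta=1., L1=1., L2=1., D=D)
    for e in range(5):  # one pass over the training file per epoch
        loss, count = 0., 0
        for t, date, ID, x, y in data("./data/train", D):
            p = learner.predict(x)
            loss += logloss(p, y)  # progressive validation loss
            count += 1
            learner.update(x, p, y)
        print('epoch %d finished, average logloss: %f' % (e, loss / count))
        # keep a checkpoint per epoch, matching the naming used above
        pickle.dump(learner, open("model/model_ftrl_e%d" % e, 'wb'))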