加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test.py 2.97 KB
一键复制 编辑 原始数据 按行查看 历史
cyt 提交于 2020-11-04 10:22 . 修改了梯度下降
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.font_manager import FontProperties
import matplotlib as mpl
from matplotlib import animation
# NOTE(review): `font` is not referenced by the code shown here, and the path
# is Windows-specific (msyh.ttc is presumably Microsoft YaHei, used for Chinese
# labels) -- confirm whether it is still needed.
font = FontProperties(size=10, fname=r"c:\windows\fonts\msyh.ttc")
def sigmoid(z):
    """Return the logistic function 1 / (1 + exp(-z)) evaluated element-wise.

    Uses the numerically stable split: 1 / (1 + exp(-z)) for z >= 0 and
    exp(z) / (1 + exp(z)) for z < 0, so exp() is only ever called on a
    non-positive argument and cannot overflow for large |z|.

    Parameters
    ----------
    z : array_like
        Scores of any shape (0-d, 1-d, 2-d, ...); integer input is promoted
        to float.

    Returns
    -------
    ndarray of float
        A new array with the same shape as ``z``; the input is not modified.
    """
    z = np.asarray(z, dtype=float)
    out = np.empty_like(z)
    nonneg = z >= 0
    out[nonneg] = 1.0 / (1.0 + np.exp(-z[nonneg]))
    # For z < 0 (and NaN, which fails `z >= 0`), exp(z) stays in (0, 1].
    ez = np.exp(z[~nonneg])
    out[~nonneg] = ez / (1.0 + ez)
    return out
# Prediction function (the model's hypothesis).
def h(x, theta):
    """Return sigmoid(theta . x), the predicted probability for each sample.

    In this file theta is a parameter row and x holds one sample per column,
    so the dot product gives one score per sample before squashing.
    """
    scores = np.dot(theta, x)
    return sigmoid(scores)
def gradient(x, y, theta):
    """Return (y - h(x, theta)) . x^T, summed over the samples (columns) of x.

    This is the gradient of the logistic log-likelihood with respect to theta,
    laid out like theta; callers step theta in the +gradient direction.
    """
    residual = y - h(x, theta)
    return np.dot(residual, x.T)
def reduce(alpha):
    """Return the next learning rate.

    alpha is quartered, except once it has dropped below 1e-6, at which point
    it is returned unchanged (acts as a floor on further decay).
    """
    at_floor = alpha < 1e-6
    return alpha if at_floor else alpha / 4
def GD(x, y, theta, alpha=0.5, maxNum=1000000, threshold=1e-3, decay_every=1000):
    """Fit theta by full-batch gradient ascent on the logistic log-likelihood.

    Each step computes g = gradient(x, y, theta) over all samples and sets
    theta <- theta + alpha * g.

    Parameters
    ----------
    x : ndarray
        Design matrix with one sample per column (as built by getData).
    y : ndarray
        Labels, one per column of x.
    theta : array_like
        Initial parameters; not modified (a new array is returned).
    alpha : float
        Initial step size, used for the first ``decay_every`` updates.
    maxNum : int
        Maximum number of updates.
    threshold : float
        Stop early once every component of g satisfies |g| <= threshold.
    decay_every : int
        After every ``decay_every`` updates, alpha is passed through reduce().

    Returns
    -------
    ndarray
        The final parameters.
    """
    for i in range(maxNum):
        g = gradient(x, y, theta)
        theta = theta + alpha * g
        if np.all(np.abs(g) <= threshold):
            print('跳出来', g)  # "broke out": converged before maxNum
            break
        # (i + 1) so the decay fires after each full block of updates; testing
        # `i % decay_every == 0` would shrink alpha right after the first step.
        if (i + 1) % decay_every == 0:
            alpha = reduce(alpha)
    return theta
def SGD(x, y, theta, alpha=0.5, maxNum=1000000, threshold=1e-6, decay_every=1000):
    """Fit theta by stochastic gradient ascent, one random sample per update.

    Each step draws a sample index uniformly at random (numpy's global RNG),
    computes that sample's gradient and sets theta <- theta + alpha * g.

    Parameters
    ----------
    x : ndarray
        Design matrix with one sample per column (as built by getData).
    y : ndarray
        Labels, one per column of x.
    theta : array_like
        Initial parameters; not modified (a new array is returned).
    alpha : float
        Initial step size, used for the first ``decay_every`` updates.
    maxNum : int
        Number of updates; always runs to completion.
    threshold : float
        Accepted for signature compatibility with GD but currently unused:
        no convergence test is performed.
    decay_every : int
        After every ``decay_every`` updates, alpha is passed through reduce().

    Returns
    -------
    ndarray
        The final parameters.
    """
    for i in range(maxNum):
        cur = np.random.randint(0, y.size)
        # x[:, [cur]] keeps sample `cur` as a 2-D (n_features, 1) column.
        g = gradient(x[:, [cur]], y[cur], theta)
        theta = theta + alpha * g
        # (i + 1) so the decay fires after each full block of updates; testing
        # `i % decay_every == 0` would shrink alpha right after the first step.
        if (i + 1) % decay_every == 0:
            alpha = reduce(alpha)
    return theta
def getData(data_dir=None):
    """Load the ex4 training set and prepend a bias row of ones.

    Parameters
    ----------
    data_dir : str or os.PathLike, optional
        Directory containing ``ex4x.dat`` and ``ex4y.dat``. Defaults to
        ``../data/ex4Data`` relative to the current working directory.

    Returns
    -------
    x : ndarray, shape (1 + n_columns_of_ex4x, n_samples)
        One sample per column; row 0 is all ones (bias term), the remaining
        rows are the columns of ``ex4x.dat``.
    y : ndarray, shape (n_samples,)
        Values read from ``ex4y.dat``.
    """
    from pathlib import Path

    # Build paths with pathlib so they resolve on any OS; a backslash-separated
    # literal such as r'..\data\...' is only a valid path on Windows.
    root = Path('..', 'data', 'ex4Data') if data_dir is None else Path(data_dir)
    # ndmin keeps the shapes above even for a single-row or single-column file.
    x = np.loadtxt(root / 'ex4x.dat', ndmin=2).T
    y = np.loadtxt(root / 'ex4y.dat', ndmin=1)
    x = np.insert(x, 0, np.ones_like(x[0]), axis=0)
    return x, y
def predict(x1, x2, theta):
    """Classify points by the sign of the linear score theta0 + theta1*x1 + theta2*x2.

    Parameters
    ----------
    x1, x2 : ndarray
        Feature values of any shape, as long as they broadcast together
        (e.g. the two meshgrid arrays built by fill).
    theta : sequence
        Intercept followed by the weights for x1 and x2 (only the first three
        entries are read).

    Returns
    -------
    ndarray
        1 where the score is strictly positive and 0 otherwise (a score of
        exactly 0, or NaN, gives 0), in the dtype of the score array.
    """
    p = theta[0] * np.ones_like(x1) + x1 * theta[1] + x2 * theta[2]
    # Vectorized threshold: avoids a Python-level loop over every grid point.
    return (p > 0).astype(p.dtype)
def fill(x, theta=None, resolution=1000):
    """Shade the plane by predicted class over the (padded) range of the data.

    Parameters
    ----------
    x : ndarray
        Design matrix as built by getData: rows 1 and 2 are the two features
        that span the plot axes.
    theta : array_like, optional
        Model parameters (intercept, w1, w2) passed to predict. When omitted,
        the module-level ``theta`` is used, which keeps the original
        ``fill(x)`` call style working.
    resolution : int
        Number of grid samples along each axis.

    Returns
    -------
    The artist returned by ``plt.pcolormesh``.

    Raises
    ------
    NameError
        If ``theta`` is omitted and no module-level ``theta`` exists.
    """
    if theta is None:
        try:
            theta = globals()['theta']
        except KeyError:
            raise NameError(
                "fill() needs theta: pass it explicitly or define module-level theta"
            ) from None
    cm_light = mpl.colors.ListedColormap(['#e4f6f5', '#ffcc00'])
    # Extend the data range by 1 on every side of each axis.
    x1_min, x1_max = np.min(x[1]) - 1, np.max(x[1]) + 1
    x2_min, x2_max = np.min(x[2]) - 1, np.max(x[2]) + 1
    t1 = np.linspace(x1_min, x1_max, resolution)
    t2 = np.linspace(x2_min, x2_max, resolution)
    x1, x2 = np.meshgrid(t1, t2)  # grid of sample points
    return plt.pcolormesh(x1, x2, predict(x1, x2, theta), shading='auto', cmap=cm_light)
def update(num):
    """Placeholder, apparently meant as a per-frame callback for
    matplotlib.animation.FuncAnimation; currently ignores ``num`` and does
    nothing."""
    return None
if __name__ == '__main__':
    x, y = getData()
    # Random initial parameters: a (1, n) row, one entry per row of x
    # (bias + features).
    theta = np.random.randn(1, x.shape[0])
    print('初始theta=', theta)  # "initial theta"
    # SGD returns a (1, n) row; flatten it so theta[0], theta[1], theta[2]
    # are scalars, as predict() indexes them.
    theta = SGD(x, y, theta).flatten()
    print('最终theta=', theta)  # "final theta"

    plt.figure()
    plt.xlabel('x1')
    plt.ylabel('x2')

    # Decision regions: shades where theta0 + theta1*x1 + theta2*x2 > 0.
    # fill(x) reads the module-level `theta` assigned above.
    fill(x)
    # TODO: animate training with matplotlib.animation.FuncAnimation and the
    # update() stub.

    # Scatter the samples by label.
    admitted = y == 1
    not_admitted = y == 0
    plt.scatter(x[1][admitted], x[2][admitted], marker='+', label='Admitted')
    plt.scatter(x[1][not_admitted], x[2][not_admitted], marker='^', label='Not admitted')
    plt.legend(loc='best')
    plt.show()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化