> 文章列表 > 梯度下降法演示

梯度下降法演示

梯度下降法演示

# (x-2)2 = 0
from tqdm import trangeepoch = 1000
lr = 0.09
x = 5  # 初始值, 凯明初始化, # 何凯明
label = 0for e in trange(epoch):pre = (x - 2)  2loss = (pre - label)  2delta_x = 2*(pre - label) * (x - 2)x = x - delta_x * lrprint(x)
100%|███████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 386821.36it/s]
1.9474942856120552

可以使用matplotlib和FuncAnimation库来生成动态图。这里是一个简单的示例代码,可以生成(x-2)^2=0的根随迭代次数变化的动态图:

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from tqdm import trange
from matplotlib.animation import PillowWriterepoch = 1000
lr = 0.09
x = 5
label = 0fig, ax = plt.subplots()
ax.set_xlim(0, epoch)
ax.set_ylim(0, 10)x_data = []
y_data = []line, = ax.plot([], [])def update(frame):global xglobal y_dataglobal x_datapre = (x-2)  2loss = (pre - label)  2delta_x = 2*(pre - label) * (x-2)x = x - delta_x * lrx_data.append(frame)y_data.append(x)line.set_data(x_data, y_data)return line,ani = FuncAnimation(fig, update, frames=trange(epoch), blit=True)ax.set_xlabel('Epoch')
ax.set_ylabel('X')writer = PillowWriter(fps=150)
ani.save('animation.gif', writer=writer)

在上述代码中,我们首先使用matplotlib创建了一个图形窗口,并设置了x轴和y轴的范围。然后,我们定义了一个update函数,该函数在每次迭代中计算损失和更新x的值,并将当前的x和迭代次数分别添加到x_data和y_data列表中。最后,我们使用FuncAnimation函数创建了一个动画对象,并将update函数作为参数传递进去。

在调用FuncAnimation函数时,我们使用trange函数来生成一个迭代器,该迭代器会在每次迭代时更新动画。blit参数设置为True表示只更新变化的部分,可以提高动画的效率。最后,我们调用plt.show()函数显示动画图形窗口。

运行这段代码,你将会看到一个动态图,其中x的值会随着迭代次数的增加而趋近于2,即(x-2)^2=0的根。