代码之家  ›  专栏  ›  技术社区  ›  Daniel

在使用matplotlib演示梯度下降时,如何绘制典型的碗形状?

  •  1
  • Daniel  · 技术社区  · 8 年前

    在演示梯度下降时,我们通常会看到下面的碗状图。此外,据说使用log_损失而不是平方误差,我们可以更容易地找到损失的最小值,因为使用平方误差作为损失函数可能会导致多个局部最小值。

    因此,我想绘制碗形状图,如下所示。

    然而,我只成功地策划了以下内容 enter image description here

    这是我的代码,有人能帮我修改吗?谢谢

    from mpl_toolkits.mplot3d.axes3d import Axes3D
    import matplotlib.pyplot as plt
    from matplotlib import cm
    import numpy as np
    import math
    
    
    fig, ax1 = plt.subplots(1, 1, figsize=(8, 5), subplot_kw={'projection': '3d'})
    
    # Get the test data
    x1 = 1
    x2 = 1
    y = 0.8
    w = np.linspace(-10,10,100)
    # w = np.random.random(100)
    wl = np.linspace(-10,10,100)
    # wl = np.random.random(100)
    w1 = np.ones((100,100))
    w2 = np.ones((100,100))
    for idx in range(100):
        w1[idx] = w1[idx]*w
        w2[:,idx] = w2[:,idx]*wl
    
    L = []
    for i in range(w1.shape[0]):
        for j in range(w1.shape[1]):
            a = w1[i,j]*x1 + w2[i,j]*x2
            f = 1/(1+math.exp(-a))
            l = -(y*math.log(f)+(1-y)*math.log(1-f))
            # l = (1/2)*(f-y)**2
            L.append(l)
    l = np.array(L).reshape(w1.shape)
    
    
    ax1.plot_wireframe(w1,w2,l)
    ax1.set_title("plot backpropogation")
    
    
    plt.tight_layout()
    plt.show()
    
    1 回复  |  直到 8 年前
        1
  •  4
  •   ImportanceOfBeingErnest    8 年前

    下面忽略了问题中的公式,可能与任何实际问题都完全无关。它只是展示了如何绘制一个碗。

    绘制碗的一种方法是使用绕z轴旋转对称的函数。

    例如:

    from mpl_toolkits.mplot3d.axes3d import Axes3D
    import matplotlib.pyplot as plt
    import numpy as np
    
    fig, ax1 = plt.subplots(figsize=(8, 5), 
                            subplot_kw={'projection': '3d'})
    
    alpha = 0.8
    r = np.linspace(-alpha,alpha,100)
    X,Y= np.meshgrid(r,r)
    l = 1./(1+np.exp(-(X**2+Y**2)))
    
    ax1.plot_wireframe(X,Y,l)
    ax1.set_title("plot")
    
    plt.show()
    

    enter image description here