代码之家  ›  专栏  ›  技术社区  ›  ThatQuantDude

经过最后一个数据点的曲线拟合

  •  1
  • ThatQuantDude  · 技术社区  · 7 年前

    我试图将曲线拟合到一组数据点,但希望保留某些特征。

    就像在这张图中,我得到的曲线几乎都是线性的,有些曲线不是。我需要一个函数形式在给定的数据点之间插值或超过最后一个给定点。

    曲线是用简单的回归法创建的

    def func(x, d, b, c):
        return c + b * np.sqrt(x) + d * x
    

    enter image description here

    我现在的问题是,确保正斜率超过最后一个数据点的最佳方法是什么????在我的应用程序中,在增加容量的同时降低成本是没有意义的,即使数据表明是这样。

    我想把订单尽量放低,也许3点还可以。

    用于创建负斜率曲线的数据是

    x_data = [     100,      560,      791,     1117,     1576,     2225,
           3141,     4434,     6258,     8834,    12470,    17603,
          24848,    35075,    49511,    69889,    98654,   139258,
         196573,   277479,   391684,   552893,   780453,  1101672,
        1555099,  2195148,  3098628,  4373963,  6174201,  8715381,
       12302462, 17365915]
    y_data = [  7,   8,   9,  10,  11,  12,  14,  16,  21,  27,  32,  30,  31,
        38,  49,  65,  86, 108, 130, 156, 183, 211, 240, 272, 307, 346,
       389, 436, 490, 549, 473, 536]
    

    积极的一面

    x_data = [     100,      653,      950,     1383,     2013,     2930,
           4265,     6207,     9034,    13148,    19136,    27851,
          40535,    58996,    85865,   124969,   181884,   264718,
         385277,   560741,   816117,  1187796,  1728748,  2516062,
        3661939,  5329675,  7756940, 11289641, 16431220, 23914400,
       34805603, 50656927]
    y_data = [  6,   6,   7,   7,   8,   8,   9,  10,  11,  12,  14,  16,  18,
        21,  25,  29,  35,  42,  50,  60,  72,  87, 105, 128, 156, 190,
       232, 284, 347, 426, 522, 640]
    

    曲线拟合通过使用

    popt, pcov = curve_fit(func, x_data, y_data)
    

    为了阴谋

    plt.plot(xdata, func(xdata, *popt), 'g--', label='fit: a=%5.3f, b=%5.3f, c=%5.3f' % tuple(popt))
    plt.plot(x_data, y_data, 'ro')
    plt.xlabel('Volume')
    plt.ylabel('Costs')
    plt.show()
    
    1 回复  |  直到 7 年前
        1
  •  1
  •   mikuszefski    7 年前

    一个简单的解决方案可能就是这样:

    import matplotlib.pyplot as plt
    import numpy as np
    from scipy.optimize import least_squares
    
    def fit_function(x, a, b, c, d):
        return a**2 + b**2 * x + c**2 * abs(x)**d 
    
    def residuals( params, xData, yData):
        diff = [ fit_function(x, *params ) - y for x, y in zip( xData, yData ) ]
        return diff
    
    fit1 = least_squares( residuals, [ .1, .1, .1, .5 ], loss='soft_l1', args=( x1Data, y1Data ) )
    print fit1.x
    fit2 = least_squares( residuals, [ .1, .1, .1, .5 ], loss='soft_l1', args=( x2Data, y2Data ) )
    print fit2.x
    
    testX1 = np.linspace(0, 1.1 * max( x1Data ), 100 )
    testX2 = np.linspace(0, 1.1 * max( x2Data ), 100 )
    testY1 = [ fit_function( x, *( fit1.x ) ) for x in testX1 ]
    testY2 = [ fit_function( x, *( fit2.x ) ) for x in testX2 ]
    
    fig = plt.figure()
    ax = fig.add_subplot( 1, 1, 1 )
    ax.scatter( x1Data, y1Data )
    ax.scatter( x2Data, y2Data )
    ax.plot( testX1, testY1 )
    ax.plot( testX2, testY2 )
    plt.show()
    

    提供

    >>[ 1.00232004e-01 -1.10838455e-04  2.50434266e-01  5.73214256e-01]
    >>[ 1.00104293e-01 -2.57749592e-05  1.83726191e-01  5.55926678e-01]
    

    soft fit

    它只是将参数作为平方,从而确保正斜率。当然,如果禁止跟随数据集1末尾的递减点,则拟合会变得更差。关于这一点,我认为这些只是统计上的异常值。因此,我用 least_squares ,它可以处理软损失。见 this doc 详细情况。根据实际数据集的情况,我会考虑删除它们。最后,我希望零容量产生零成本,所以fit函数中的常数项似乎没有意义。

    所以如果函数只是类型 a**2 * x + b**2 * sqrt(x) 它看起来像:

    simplified

    其中绿色图形是 leastsq ,即没有 f_scale 选择权 最小二乘法 .