代码之家  ›  专栏  ›  技术社区  ›  zlon

线性分段拟合(3条线)对初始节点位置过于敏感

  •  1
  • zlon  · 技术社区  · 8 年前

    我有一些实验数据(信号与时间): enter image description here

    我想把这些曲线拟合成线性分段函数。我需要它,因为:

    1) 我对“跳跃”的时间感兴趣(中间部分)

    2) 我想整理一下右上角的曲线。

    要进行拟合,我尝试使用解决方案,如中所述 stackoverflow ,但这个方法(或至少我的实现)对初始节点位置非常敏感。我的代码和结果:

     %% Generate dummy data
    x=linspace(-2,2,100); y=sinh(x)./cosh(x).^2;
    % add noise
    y=y+(rand(1,numel(y))-0.5)*0.1;
    %% Estimate knots
    d=(max(x)-min(x));
    X0_badEstimation=[min(x)+d/5,  min(x)+4*d/5];
    X0_goodEstimation=[min(x)+d/3,  min(x)+2*d/3];
    %% Estimate piecewise fit
    idx=1:min(10,numel(x));
    p1=polyfit(x(idx),y(idx),1);
    idx=max(0,round(numel(x)/2-10)):min(round(numel(x)/2+10),numel(x));
    p2=polyfit(x(idx),y(idx),1);
    idx=max(0,(numel(x)-10)):numel(x);
    p3=polyfit(x(idx),y(idx),1);
    %% estimate slopes
    s1=p1(1); s2=p2(1); s3=p3(1);
    %% estimate offsets
    o1=p1(2); o2=p2(2); o3=p3(2);
    %% model in form
    % y=(o1+s1*x)*((x0-x)>=0)+(o2+s2*x)*((x0-x)<0)*((x0+delta)-x>=0)+(o3+s3*x)*(((x0+delta)-x)<0)
    model=@(P,x) (P(6)+P(3).*x).*((P(1)-x)>=0)+...
        (P(7)+P(4).*x).*((P(1)-x)<0).*((P(1)+P(2))-x>=0)+...
        (P(8)+P(5).*x).*(((P(1)+P(2))-x)<0);
    %% Initial parameters:
    x0=X0_goodEstimation(1);
    delta=diff(X0_goodEstimation);
    P0_goodEstimation=[x0,delta,s1,s2,s3,o1,o2,o3];
    x0=X0_badEstimation(1);
    delta=diff(X0_badEstimation);
    P0_badEstimation=[x0,delta,s1,s2,s3,o1,o2,o3];
    %% fit it!
    Pfit_goodEstimation = lsqcurvefit(model,P0_goodEstimation,x,y);
    Pfit_badEstimation = lsqcurvefit(model,P0_badEstimation,x,y);
    %% plot results
    hold all
    plot(x,y,'LineWidth',3)
    plot(x,model(Pfit_badEstimation,x))
    plot(x,model(Pfit_goodEstimation,x))
    legend({'data','bad knots estimation','good knots estimation'},'FontSize',24)
    

    enter image description here

    关于代码的几点评论:

    1) 我用来生成虚拟数据的函数: y=sinh(x)./cosh(x).^2 对我来说没有意义,我是凭经验发现的,只用于提供运行的代码。

    2) 我尝试了不同的方法来更好地估计结的位置,但在实际数据上我没有成功,所以这里我使用一些简单的方法JU进行演示。

    问题:

    你能帮我试穿一下吗?我在实现中是否有错误,或者在我的情况下应该使用不同的方法?

    1 回复  |  直到 8 年前
        1
  •  2
  •   mikuszefski    8 年前

    使用连续经验拟合函数怎么样? 超噪声数据要么不能正确收敛,产生误差,要么在一个或多个参数中存在较大的拟合误差。您还可以检查拟合结果的卡方误差。

    # -*- coding: utf-8 -*-
    import matplotlib.pyplot as plt
    import numpy as np
    from scipy.optimize import curve_fit
    
    def fit_fun( x, a, b, c, d, e ):
        return a + b * x + c * np.tanh( d * ( x - e ) )
    
    # some data
    xData= np.linspace( -2, 7 ,37 )
    yData = dict()
    yData[ 1 ] = np.fromiter( ( np.random.normal(scale = 0.25 ) -10 + 0.3 * x + 3 * np.tanh( 4.5 * ( x - 2.3 ) ) for x in xData ), np.float )
    yData[ 2 ] = np.fromiter( ( np.random.normal(scale = 2.50 ) + 3 - 1.0 * x + 0.3 * np.tanh( 4.5 * ( x - 2.3 ) ) for x in xData ), np.float )
    yData[ 3 ] = np.fromiter( ( np.random.normal(scale = 0.25 ) -10 - 0.3 * x + 2 * np.tanh( 3.5 * ( x - 2.8 ) )+ 2 * np.tanh( 4.8 * ( x - 1.8 ) ) for x in xData ), np.float )
    yData[ 4 ] = np.fromiter( ( np.random.normal(scale = 0.25 ) -10 + 0.3 * x + 3 * np.tanh( -.85 * ( x - 2.3 ) ) for x in xData ), np.float )
    
    # plotting 
    fig = plt.figure()
    ax = dict()
    for i in range( 1, 5 ):
        ax[i] = fig.add_subplot( 2, 2, i)
        ax[i].plot( xData, yData[i] )
    
    # fitting and plotting
    sol = dict()
    pcov = dict()
    for i in range( 1, 5 ):
        aStart = np.mean( yData[i] )
        bStart = 0
        cStart = max( yData[i] ) - min( yData[i] )
        dStart = 1
        eStart = ( max( xData ) + min( xData ) ) / 2.
        try:
            sol[i], pcov[i]  = curve_fit( fit_fun, xData, yData[i] , ( aStart, bStart, cStart, dStart, eStart), maxfev=5000 )
            ax[i].plot( xData, fit_fun( xData, *sol[i] ) )
            print pcov[i].diagonal()
        except RuntimeError:
            print "could not fit data {}".format(i)
    plt.show()
    

    提供以下任一选项:

    With fail

    >>[8.67452955e-03 1.20595620e-03 9.66526422e-03 3.29572838e-01 3.02484309e-04]
    >>could not fit data 2
    >>[0.06241828 0.00990755 0.1092549  0.02299717 0.00160222]
    >>[0.07924285 0.01274575 0.19154392 0.00806664 0.00194379]
    

    或: enter image description here

    >>[8.29671770e-03 1.15356171e-03 9.10475233e-03 6.58578994e-01 2.20470826e-04]
    >>[3.02704593e+02 1.57603966e+01 8.43995977e+02 4.56342636e-01 2.72302001e+00]
    >>[0.0467695  0.00741241 0.08193284 0.01793236 0.0012683 ]
    >>[0.0475788  0.00819175 0.11240722 0.01152346 0.00221799]
    
    推荐文章