代码之家  ›  专栏  ›  技术社区  ›  Izzo

用线性模型逼近平方函数时,PyTorch不收敛

  •  -1
  • Izzo  · 技术社区  · 6 年前

    我正在努力学习一些Pytork,并参考了这次讨论 here

    作者提供了一段最短的代码,说明了如何使用Pytork来求解被随机噪声污染的未知线性函数。

    这段代码对我来说运行良好。

    然而,当我改变函数,希望t=X^2时,参数似乎不会收敛。

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.autograd import Variable
    
    # Let's make some data for a linear regression.
    A = 3.1415926
    b = 2.7189351
    error = 0.1
    N = 100 # number of data points
    
    # Data
    X = Variable(torch.randn(N, 1))
    
    # (noisy) Target values that we want to learn.
    t = X * X + Variable(torch.randn(N, 1) * error)
    
    # Creating a model, making the optimizer, defining loss
    model = nn.Linear(1, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.05)
    loss_fn = nn.MSELoss()
    
    # Run training
    niter = 50
    for _ in range(0, niter):
        optimizer.zero_grad()
        predictions = model(X)
        loss = loss_fn(predictions, t)
        loss.backward()
        optimizer.step()
    
        print("-" * 50)
        print("error = {}".format(loss.data[0]))
        print("learned A = {}".format(list(model.parameters())[0].data[0, 0]))
        print("learned b = {}".format(list(model.parameters())[1].data[0]))
    

    当我执行这段代码时,新的A和b参数似乎是随机的,因此不会收敛。我认为这应该收敛,因为你可以用斜率和偏移函数来近似任何函数。我的理论是我没有正确使用Pytork。

    有人能确定我的电脑有问题吗 t = X * X + Variable(torch.randn(N, 1) * error) 代码行?

    0 回复  |  直到 6 年前
        1
  •  3
  •   Shai    6 年前

    不能用线性函数拟合二次多项式。你不能期望更多的是随机的(因为你有来自多项式的随机样本)。
    你能做的就是尝试两种输入, x x^2 并从中得到满足:

    model = nn.Linear(2, 1)  # you have 2 inputs now
    X_input = torch.cat((X, X**2), dim=1)  # have 2 inputs per entry
    # ...
    
        predictions = model(X_input)  # 2 inputs -> 1 output
        loss = loss_fn(predictions, t)
        # ...
        # learning t = c*x^2 + a*x + b
        print("learned a = {}".format(list(model.parameters())[0].data[0, 0]))
        print("learned c = {}".format(list(model.parameters())[0].data[0, 1])) 
        print("learned b = {}".format(list(model.parameters())[1].data[0]))