代码之家  ›  专栏  ›  技术社区  ›  bit_scientist

data.norm()<1000在pytorch中做什么?

  •  13
  • bit_scientist  · 技术社区  · 7 年前

    我正在学习pytorch教程 here . 上面写着

    x = torch.randn(3, requires_grad=True)
    
    y = x * 2
    while y.data.norm() < 1000:
        y = y * 2
    
    print(y)
    
    Out:    
    tensor([-590.4467,   97.6760,  921.0221])
    

    有人能解释一下data.norm()在这里做什么吗? 当我改变 .randn .ones 它的输出是 tensor([ 1024., 1024., 1024.]) .

    3 回复  |  直到 7 年前
        1
  •  10
  •   kmario23 Mazdak    7 年前

    这只是张量的l2范数(也就是欧几里德范数)。下面是一个可复制的说明:

    In [15]: x = torch.randn(3, requires_grad=True)
    
    In [16]: y = x * 2
    
    In [17]: y.data
    Out[17]: tensor([-1.2510, -0.6302,  1.2898])
    
    In [18]: y.data.norm()
    Out[18]: tensor(1.9041)
    
    # computing the norm using elementary operations
    In [19]: torch.sqrt(torch.sum(torch.pow(y, 2)))
    Out[19]: tensor(1.9041)
    

    首先,它将张量中的每个元素平方。 y ,然后求和,最后求平方根。这些运算计算所谓的 L2 or Euclidean norm .

        2
  •  1
  •   Jonathan    7 年前

    基于@kmario23的说法,它将向量的元素乘以2,直到向量的欧几里德距离/大小至少为1000。

    以向量(1,1,1)为例:它增加到(512,512,512),其中l2范数约为886。这小于1000,所以它再次乘以2,变成(1024,1024,1024)。它的震级大于1000,所以它停止了。

        3
  •  0
  •   aimuch    7 年前
    y.data.norm() 
    

    相当于

    torch.sqrt(torch.sum(torch.pow(y, 2)))