代码之家  ›  专栏  ›  技术社区  ›  mxcx

将torch.all应用于除第一个维度外的所有维度

  •  0
  • mxcx  · 技术社区  · 2 年前

    我像这样计算我的准确性

    (outputs.round() == targets).all(dim=2).all(dim=1).sum().item() / outputs.shape[0]
    

    具有 outputs targets 有形状的 NxAxB N 是批量大小。 剩下的部分是预测/真值,我想看看它们是否相同。

    目前我正在使用 .all(dim=2).all(dim=1) 。 现在的问题是,如果我有一个不同的模型,形状会有所不同。他们会 NxA ,所以我目前的方法不起作用,因为 dim=2 不存在。

    (outputs.round() == targets).all(dim=1).sum().item() / outputs.shape[0]
    

    ,会起作用,但再次仅适用于第二个模型。

    理想情况下,我想申请 .all 除了第一个维度(批量维度)之外的所有维度。 我该怎么做?

    1 回复  |  直到 2 年前
        1
  •  1
  •   Ivan Jeffrey Zhao    2 年前

    要推广到任意数量的维度,可以从 dim=1 向外使用 torch.flatten ,然后应用 all mean :

    >>> (outputs.round() == targets).flatten(1).all(1).float().mean()
    

    笔记 : torch.flatten(dim=1) 将使张量变平 dim=1 dim=-1