我像这样计算我的准确性
(outputs.round() == targets).all(dim=2).all(dim=1).sum().item() / outputs.shape[0]
具有 outputs 和 targets 有形状的 NxAxB 。 N 是批量大小。 剩下的部分是预测/真值,我想看看它们是否相同。
outputs
targets
NxAxB
N
目前我正在使用 .all(dim=2).all(dim=1) 。 现在的问题是,如果我有一个不同的模型,形状会有所不同。他们会 NxA ,所以我目前的方法不起作用,因为 dim=2 不存在。
.all(dim=2).all(dim=1)
NxA
dim=2
(outputs.round() == targets).all(dim=1).sum().item() / outputs.shape[0]
,会起作用,但再次仅适用于第二个模型。
理想情况下,我想申请 .all 除了第一个维度(批量维度)之外的所有维度。 我该怎么做?
.all
要推广到任意数量的维度,可以从 dim=1 向外使用 torch.flatten ,然后应用 all 和 mean :
dim=1
torch.flatten
all
mean
>>> (outputs.round() == targets).flatten(1).all(1).float().mean()
笔记 : torch.flatten(dim=1) 将使张量变平 dim=1 到 dim=-1 。
torch.flatten(dim=1)
dim=-1