假设,我有一个PyTorch张量
a
形状的
ba,c,h,w
并希望将几行
一
在另一个张量中给出索引
b
形状的
ba,2
属于
dtype=torch.int16
和
b[batch, 0] <= b[batch, 1]
。
for循环的方法是:
for batch in range(ba):
a[batch,:,0:b[batch,0],:] = 0 # stmnt 1
a[batch,:,b[batch,1]:,:] = 0 # stmnt 2
PyTorch有更快的方法吗?
具体来说,首先将stmnt1和stmnt2组合成一行,告诉PyTorch将
一
除了
a[batch,:,b[batch,0]:b[batch,1],:]
零其次,如果可以在不需要使用
for
环