代码之家  ›  专栏  ›  技术社区  ›  Ahmad

如何根据Pytorch中列表的值定义掩码函数

  •  0
  • Ahmad  · 技术社区  · 3 年前

    我想根据张量的值来屏蔽它。在下面的函数中,如果我传递一个范围(第二部分),它可以工作,但我希望有一个包含各种值的列表 prompt_ids

    RuntimeError: Boolean value of Tensor with more than one value is ambiguous
    

    职能:

       def get_prompt_token_fn(self):
            if self.prompt_ids:
                return lambda x: x in self.prompt_ids
            else:
                return lambda x: (x>=self.id_offset)&(x<self.id_offset+self.length)
    

    问题是什么?我如何解决?

    0 回复  |  直到 3 年前
        1
  •  0
  •   Ahmad    3 年前

    在里面 pytorch 1.10 有一个 isin

    def isin(ar1, ar2):
        return (ar1[..., None] == ar2).any(-1)