代码之家  ›  专栏  ›  技术社区  ›  Ryan

运行时错误:在/pytorch/torch/lib/th/generic/thtenformath.c:2864处张量大小不一致

  •  0
  • Ryan  · 技术社区  · 7 年前

    我正试图构建一个数据加载器,这就是它的样子

    `class WhaleData(Dataset):
    def __init__(self, data_file, root_dir , transform = None):
        self.csv_file = pd.read_csv(data_file)
        self.root_dir = root_dir
        self.transform = transforms.Resize(224)
    
    def __len__(self):
        return len(os.listdir(self.root_dir))
    
    def __getitem__(self, index):
        image = os.path.join(self.root_dir, self.csv_file['Image'][index])
        image = Image.open(image)
        image = self.transform(image)
        image = np.array(image)
        label  = self.csv_file['Image'][index]
        sample = {'image': image, 'label':label}
        return sample
    
    trainset  = WhaleData(data_file = '/mnt/55-91e8-b2383e89165f/Ryan/1234/train.csv', 
         root_dir = '/mnt/4d55-91e8-b2383e89165f/Ryan/1234/train')
    train_loader = torch.utils.data.DataLoader(trainset , batch_size = 4, shuffle =True,num_workers= 2)
    for i, batch in enumerate(train_loader):
          (i, batch)
    

    当我尝试运行这段代码时,我得到了这个错误,我确实得到了错误的性质,即我的所有图像可能都不是同一形状,并且我的图像都不是同一形状,但是如果我没有错,那么只有当我将它们送入网络时才会出现错误,因为这些图像都是不同的。NT形状,但为什么它会在这里抛出一个错误? 任何关于我哪里出错的建议都会非常有帮助, 如果需要,我很乐意提供任何额外的信息,

    谢谢

    RuntimeError: Traceback (most recent call last):
      File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 42, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
      File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 116, in default_collate
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
      File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 116, in <dictcomp>
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
      File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 105, in default_collate
    return torch.stack([torch.from_numpy(b) for b in batch], 0)
      File "/usr/local/lib/python3.5/dist-packages/torch/functional.py", line 64, in stack
        return torch.cat(inputs, dim)
    RuntimeError: inconsistent tensor sizes at /pytorch/torch/lib/TH/generic    /THTensorMath.c:2864
    
    1 回复  |  直到 7 年前
        1
  •  1
  •   benjaminplanche    7 年前

    当pytorch试图将图像堆叠成单个批处理张量时,就会出现错误(参见。 torch.stack([torch.from_numpy(b) for b in batch], 0) 从你的轨迹)。正如你提到的,由于图像的形状不同,叠加失败(即张量 (B, H, W) 只能通过堆叠创建 B 张量,如果这些张量都有形状 (H, W) )


    注:我不完全确定,但设置 batch_size=1 对于 torch.utils.data.DataLoader(...) 可以删除此特定错误,因为它可能不需要调用 torch.stack() 再也没有了)。