代码之家  ›  专栏  ›  技术社区  ›  rbaleksandar

如何使用数据加载器在PyTorch __getitem__()中使用收益率?

  •  0
  • rbaleksandar  · 技术社区  · 4 年前

    我未来的意图是在训练、验证和测试期间,只在需要时将数据加载到GPU内存中。由于VRAM的容量有限,我决定以以下方式构建数据集:

    • 它只包含我要使用的图像的路径列表
    • 实际图像数据仅在 __getitem__ 呼叫

    目前,我的图像相当小(大约200x200px),每个图像都被分割成补丁(总是36x36px),每个补丁都经过某种转换。然而,在未来,我将切换到更大的图像(想想1000x1000px甚至更多),在这种情况下,补丁的数量将增加。

    yield 这似乎是一个很好的选择,因为它提供了一个生成器,可以根据需要加载数据。然而,我是新手 产量 和PyTorch,所以我正在努力将两者结合起来。

    以下是数据集:

    class CustomDataset(torch.utils.data.Dataset):
      def __init__(self, images):
        '''
        Dataset which loads image from a file and upon retrieval using __getitem__
        yield a pair of stacks with patches
        '''
        self.images = images
    
      def __len__(self):
        return len(self.images)
    
    
      def __getitem__(self, idx):
        # Load image from file
        img = cv2.imread(self.images[idx], cv2.IMREAD_GRAYSCALE)
        # Extract patches from image (list of images)
        patches = self.extract_patches(img)
        # Using original patches create two separate lists of equal length
        #  - list "patches_x" contains patches that have underwent transformation X
        #  - list "patches_y" contains patches that have underwent transformation Y
        # Each transformation is using OpenCV or numpy in general, so a final conversion
        # to Tensor is required using "from_numpy"
        patches_x = [from_numpy(transform_X(patch) for patch in patches]
        patches_y = [from_numpy(transform_Y(patch) for patch in patches]
    
        # Yield pair
        yield stack(patches_x), stack(patches_y)
    

    这里的转换并不重要(这就是为什么我没有包括这些转换)。每次检索到新项目时,都会加载相应路径中的图像(使用OpenCV),将其拆分为补丁,然后对两个补丁列表进行不同的转换(例如调整大小、模糊等)。最后,将这两个列表转换为 stack 结构和 产量 ed。

    我的问题是如何使用 Dataloader .我显然需要使用自定义 collate_fn

    def custom_collate(batch):
      # TODO
    
    dataset_train = CustomDataset('./data/train/')
    dataloader_train = DataLoader(dataset=dataset_train, batch_size=64, collate_fn=custom_collate)
    

    如果我只是使用 数据加载器 具有的实例 default_collate

    for sample_patches in iter(dataloader_train):
        print(sample_patches)
    

    我出错了

    Traceback (most recent call last):
      File "D:\Projects\networks\net\net.py", line 255, in <module>
        for sample_patches in iter(dataloader_train):
      File "D:\env\ml\lib\site-packages\torch\utils\data\dataloader.py", line 530, in __next__
        data = self._next_data()
      File "D:\env\ml\lib\site-packages\torch\utils\data\dataloader.py", line 570, in _next_data
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
      File "D:\env\ml\lib\site-packages\torch\utils\data\_utils\fetch.py", line 52, in fetch
        return self.collate_fn(data)
      File "D:\env\satellite\lib\site-packages\torch\utils\data\_utils\collate.py", line 180, in default_collate
        raise TypeError(default_collate_err_msg_format.format(elem_type))
    TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'generator'>
    

    这是可以理解的,因为 __getitem__() 由于 产量 声明结尾,而不是张量。然后我再次使用 产量 可能完全不正确。。。

    网络定义为

    from torch import nn
    from collections import OrderedDict
    
    class Network(nn.Module):
        def __init__(self) -> None:
            super(Network, self).__init__()
            self.model = nn.Sequential(OrderedDict([
                ('conv1', nn.Conv2d(1, 64, 7, stride=1, padding=3)),
                ('relu1', nn.ReLU(True)),
                ('conv2', nn.Conv2d(64, 32, 5, stride=1, padding=2)),
                ('relu2', nn.ReLU(True)),
                ('conv3', nn.Conv2d(32, 32, 3, stride=1, padding=1)),
                ('relu3', nn.ReLU(True)),
                ('conv4', nn.Conv2d(32, 1, 3, stride=1, padding=1))
            ]))
    
        def forward(self, x):
            x = self.model(x)
            return x
    

    更新:

    我当前代码的问题在于 batch_size 只影响我加载的图像数量(如图像路径的数量),因为那些直接连接到 __len__() __getitem__() 。补丁是 __getitem__() 因此数据加载器并不真正了解这些。

    示例:

    让我们吃以下 collate 作用

    def custom_collate(batch):
        return batch
    

    它除了返回生成器之外什么都不做,由 产量 在里面 __getitem__()

    配置 数据加载器

    dataloader_train = DataLoader(dataset=dataset_train, batch_size=1, collate_fn=custom_collate)
    

    并对其进行迭代

    for sample in iter(dataloader_train):
        print(sample)
    

    给予

    [<generator object TEN_DataTrain.__getitem__ at 0x000002095B154D60>]
    [<generator object TEN_DataTrain.__getitem__ at 0x000002095B154DD0>]
    [<generator object TEN_DataTrain.__getitem__ at 0x000002095B154D60>]
    [<generator object TEN_DataTrain.__getitem__ at 0x000002095B154DD0>]
    ...
    

    其中 [<generator object>] 实例等于图像路径的数量。

    如果我增加 批次_大小 2 我会得到

    [<generator object TEN_DataTrain.__getitem__ at 0x0000021E92684D60>, <generator object TEN_DataTrain.__getitem__ at 0x0000021E92684DD0>]
    [<generator object TEN_DataTrain.__getitem__ at 0x0000021E92684E40>, <generator object TEN_DataTrain.__getitem__ at 0x0000021E92684EB0>]
    [<generator object TEN_DataTrain.__getitem__ at 0x0000021E92684D60>, <generator object TEN_DataTrain.__getitem__ at 0x0000021E92684DD0>]
    ...
    

    我想这就是批处理的工作方式,但反过来说,这不适合我,因为我可能很少有大的图像,因此会产生很多补丁。

    0 回复  |  直到 4 年前
    推荐文章