我未来的意图是在训练、验证和测试期间,只在需要时将数据加载到GPU内存中。由于VRAM的容量有限,我决定以以下方式构建数据集:
-
它只包含我要使用的图像的路径列表
-
实际图像数据仅在
__getitem__
呼叫
目前,我的图像相当小(大约200x200px),每个图像都被分割成补丁(总是36x36px),每个补丁都经过某种转换。然而,在未来,我将切换到更大的图像(想想1000x1000px甚至更多),在这种情况下,补丁的数量将增加。
yield
这似乎是一个很好的选择,因为它提供了一个生成器,可以根据需要加载数据。然而,我是新手
产量
和PyTorch,所以我正在努力将两者结合起来。
以下是数据集:
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, images):
'''
Dataset which loads image from a file and upon retrieval using __getitem__
yield a pair of stacks with patches
'''
self.images = images
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
# Load image from file
img = cv2.imread(self.images[idx], cv2.IMREAD_GRAYSCALE)
# Extract patches from image (list of images)
patches = self.extract_patches(img)
# Using original patches create two separate lists of equal length
# - list "patches_x" contains patches that have underwent transformation X
# - list "patches_y" contains patches that have underwent transformation Y
# Each transformation is using OpenCV or numpy in general, so a final conversion
# to Tensor is required using "from_numpy"
patches_x = [from_numpy(transform_X(patch) for patch in patches]
patches_y = [from_numpy(transform_Y(patch) for patch in patches]
# Yield pair
yield stack(patches_x), stack(patches_y)
这里的转换并不重要(这就是为什么我没有包括这些转换)。每次检索到新项目时,都会加载相应路径中的图像(使用OpenCV),将其拆分为补丁,然后对两个补丁列表进行不同的转换(例如调整大小、模糊等)。最后,将这两个列表转换为
stack
结构和
产量
ed。
我的问题是如何使用
Dataloader
.我显然需要使用自定义
collate_fn
def custom_collate(batch):
# TODO
dataset_train = CustomDataset('./data/train/')
dataloader_train = DataLoader(dataset=dataset_train, batch_size=64, collate_fn=custom_collate)
如果我只是使用
数据加载器
具有的实例
default_collate
for sample_patches in iter(dataloader_train):
print(sample_patches)
我出错了
Traceback (most recent call last):
File "D:\Projects\networks\net\net.py", line 255, in <module>
for sample_patches in iter(dataloader_train):
File "D:\env\ml\lib\site-packages\torch\utils\data\dataloader.py", line 530, in __next__
data = self._next_data()
File "D:\env\ml\lib\site-packages\torch\utils\data\dataloader.py", line 570, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "D:\env\ml\lib\site-packages\torch\utils\data\_utils\fetch.py", line 52, in fetch
return self.collate_fn(data)
File "D:\env\satellite\lib\site-packages\torch\utils\data\_utils\collate.py", line 180, in default_collate
raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'generator'>
这是可以理解的,因为
__getitem__()
由于
产量
声明结尾,而不是张量。然后我再次使用
产量
可能完全不正确。。。
网络定义为
from torch import nn
from collections import OrderedDict
class Network(nn.Module):
def __init__(self) -> None:
super(Network, self).__init__()
self.model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1, 64, 7, stride=1, padding=3)),
('relu1', nn.ReLU(True)),
('conv2', nn.Conv2d(64, 32, 5, stride=1, padding=2)),
('relu2', nn.ReLU(True)),
('conv3', nn.Conv2d(32, 32, 3, stride=1, padding=1)),
('relu3', nn.ReLU(True)),
('conv4', nn.Conv2d(32, 1, 3, stride=1, padding=1))
]))
def forward(self, x):
x = self.model(x)
return x
更新:
我当前代码的问题在于
batch_size
只影响我加载的图像数量(如图像路径的数量),因为那些直接连接到
__len__()
和
__getitem__()
。补丁是
__getitem__()
因此数据加载器并不真正了解这些。
示例:
让我们吃以下
collate
作用
def custom_collate(batch):
return batch
它除了返回生成器之外什么都不做,由
产量
在里面
__getitem__()
。
配置
数据加载器
dataloader_train = DataLoader(dataset=dataset_train, batch_size=1, collate_fn=custom_collate)
并对其进行迭代
for sample in iter(dataloader_train):
print(sample)
给予
[<generator object TEN_DataTrain.__getitem__ at 0x000002095B154D60>]
[<generator object TEN_DataTrain.__getitem__ at 0x000002095B154DD0>]
[<generator object TEN_DataTrain.__getitem__ at 0x000002095B154D60>]
[<generator object TEN_DataTrain.__getitem__ at 0x000002095B154DD0>]
...
其中
[<generator object>]
实例等于图像路径的数量。
如果我增加
批次_大小
到
2
我会得到
[<generator object TEN_DataTrain.__getitem__ at 0x0000021E92684D60>, <generator object TEN_DataTrain.__getitem__ at 0x0000021E92684DD0>]
[<generator object TEN_DataTrain.__getitem__ at 0x0000021E92684E40>, <generator object TEN_DataTrain.__getitem__ at 0x0000021E92684EB0>]
[<generator object TEN_DataTrain.__getitem__ at 0x0000021E92684D60>, <generator object TEN_DataTrain.__getitem__ at 0x0000021E92684DD0>]
...
我想这就是批处理的工作方式,但反过来说,这不适合我,因为我可能很少有大的图像,因此会产生很多补丁。