Error: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object

  • Rocket  · Tech Community  · 2 years ago

    I am creating a DataLoader from a dataset.

    train_dl = torch.utils.data.DataLoader(train_ds, batch_size=8, shuffle=True)
    

    Each index of the dataset holds two elements: a matrix and an array.

    a, b = train_ds.__getitem__(0)
    print(type(a))
    print(a)
    print(type(b))
    print(b)
    

    This returns:

    <class 'numpy.ndarray'>
    
    [[ 44.9329    46.08967   44.9329   ...  99.2188    99.735664  99.17029 ]
     [ 44.9329    44.9329    44.9329   ... 114.164474 114.292244 114.11395 ]
     [ 44.9329    45.03071   44.9329   ... 114.57378  114.56599  114.49552 ]
     ...
     [ 44.9329    44.9329    44.9329   ...  52.242996  50.12293   44.9329  ]
     [ 44.9329    44.9329    44.9329   ...  44.9329    44.9329    44.9329  ]
     [ 44.9329    44.9329    44.9329   ...  44.9329    44.9329    44.9329  ]]
    
    <class 'numpy.ndarray'>
    
    [-0.002963486 -0.003033393 0.00371422 2.02e-06 0.004402838 -0.002704915
     0.003289625 -0.002551801 -0.003632823 -0.003408553 -0.002707387
     0.00278949 0.000828761 0.000849513 0.003992096 -0.002692624 0.001183484
     9.43e-05 0.003836168 2.24e-05 0.003944455 -0.001950883 -0.000877485
     0.001734729 -0.003225849 -0.000537016 6.53e-05 -0.003643878 -0.002444321
     0.002499692 0.001538219 0.002263657 0.003073046 0.004134932 -0.002500862
     -0.001662471 0.002273667 0.00375025 0.001866289 -0.002027481 0.002197658
     -0.002243473 0.000943156 -0.000643054 -0.003169563 -0.003424202
     0.00118924 -0.003570424 0.002273526]
    

    But when I try to iterate over my DataLoader with:

    for i, data in enumerate(train_dl):
    

    I get this error:

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
     in 
          1 num_epochs=1   # Just for demo, adjust this higher.
    ----> 2 training(myModel, train_dl, num_epochs)
    
     in training(model, train_dl, num_epochs)
         18 
         19     # Repeat for each batch in the training set
    ---> 20     for i, data in enumerate(train_dl):
         21         # Get the input features and target labels, and put them on the GPU
         22         inputs, labels = data[0].to(device), data[1].to(device)
    
    ~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\dataloader.py in __next__(self)
        631                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
        632                 self._reset()  # type: ignore[call-arg]
    --> 633             data = self._next_data()
        634             self._num_yielded += 1
        635             if self._dataset_kind == _DatasetKind.Iterable and \
    
    ~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\dataloader.py in _next_data(self)
        675     def _next_data(self):
        676         index = self._next_index()  # may raise StopIteration
    --> 677         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        678         if self._pin_memory:
        679             data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
    
    ~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
         52         else:
         53             data = self.dataset[possibly_batched_index]
    ---> 54         return self.collate_fn(data)
    
    ~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in default_collate(batch)
        263             >>> default_collate(batch)  # Handle `CustomType` automatically
        264     """
    --> 265     return collate(batch, collate_fn_map=default_collate_fn_map)
    
    ~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in collate(batch, collate_fn_map)
        140 
        141         if isinstance(elem, tuple):
    --> 142             return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
        143         else:
        144             try:
    
    ~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in (.0)
        140 
        141         if isinstance(elem, tuple):
    --> 142             return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
        143         else:
        144             try:
    
    ~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in collate(batch, collate_fn_map)
        117     if collate_fn_map is not None:
        118         if elem_type in collate_fn_map:
    --> 119             return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
        120 
        121         for collate_type in collate_fn_map:
    
    ~\AppData\Roaming\Python\Python38\site-packages\torch\utils\data\_utils\collate.py in collate_numpy_array_fn(batch, collate_fn_map)
        167     # array of string classes and object
        168     if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
    --> 169         raise TypeError(default_collate_err_msg_format.format(elem.dtype))
        170 
        171     return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)
    
    TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object
    

    Why does it throw this error and report that it found an object?

  •   Greg Uretzky    2 years ago

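    The error is raised because one of the numpy arrays returned by your dataset has dtype object, which PyTorch's default collate function ( default_collate ) cannot convert to a tensor.

    The "found object" at the end of the message is the dtype of the offending array, not a generic Python object. The last frame of your traceback shows this: collate_numpy_array_fn inspects elem.dtype.str and raises the TypeError, with elem.dtype in the message, whenever the array holds strings or objects.

    Plain numeric numpy arrays are not a problem: default_collate converts them with torch.as_tensor and stacks them into a batch. In your case type(b) is numpy.ndarray, but its dtype is most likely object. The printout hints at this, because values such as 2.02e-06 and 0.004402838 are each formatted on their own instead of sharing one float format. This usually happens when an array is built from mixed or non-numeric values (for example, a pandas column of dtype object). Printing a.dtype and b.dtype will confirm which array is affected.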
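
    To check the diagnosis, the dtypes can be inspected and the error reproduced outside the DataLoader. This is a minimal sketch: the arrays a and b below are hypothetical stand-ins for what the dataset returns, and it assumes a PyTorch version that exposes default_collate under torch.utils.data (1.11 or later).

```python
import numpy as np
import torch
from torch.utils.data import default_collate

# Hypothetical stand-ins for the two arrays returned by the dataset.
a = np.full((4, 6), 44.9329, dtype=np.float32)
b = np.array([-0.002963486, 2.02e-06, 0.004402838], dtype=object)

print(a.dtype)  # numeric dtype: collates fine
print(b.dtype)  # object dtype: rejected by default_collate

# Collating a batch of object-dtype arrays reproduces the TypeError.
try:
    default_collate([b, b])
    collate_error = None
except TypeError as exc:
    collate_error = str(exc)
print(collate_error)

# After casting to a numeric dtype, the same batch collates into one tensor.
b_fixed = b.astype(np.float32)
batch = default_collate([b_fixed, b_fixed])
print(batch.shape, batch.dtype)
```

    If the cast itself fails, the array contains values that are not numbers, and those need to be cleaned up where the data is loaded.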

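    Here is how to fix it:

    1. Cast the numpy arrays to a numeric dtype (for example float32) in the dataset's __getitem__ method, then convert them to PyTorch tensors
    2. Return those tensors from __getitem__

    A modified version of your dataset might look like this:

    import numpy as np
    import torch
    
    class YourDataset(torch.utils.data.Dataset):
        def __init__(self, samples):  # hypothetical argument: a list of (matrix, array) pairs
            self.samples = samples
    
        def __len__(self):
            return len(self.samples)
    
        def __getitem__(self, index):
            a, b = self.samples[index]  # replace with however you currently get your numpy arrays
            # Cast first: torch.from_numpy rejects arrays with dtype=object
            a_tensor = torch.from_numpy(np.asarray(a, dtype=np.float32))
            b_tensor = torch.from_numpy(np.asarray(b, dtype=np.float32))
            return a_tensor, b_tensor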
    

    With this change, every item the dataset returns is already a tensor with a numeric dtype, so DataLoader can collate it without problems.
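
    An end-to-end check might look like the sketch below. The class name ArrayPairDataset, the array shapes, and the synthetic data are assumptions for illustration only; the second array is deliberately stored with dtype object to mimic the failing case.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ArrayPairDataset(Dataset):
    """Hypothetical dataset holding (matrix, array) pairs as numpy arrays."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        a, b = self.samples[index]
        # Cast to float32 before building tensors so object arrays are accepted.
        return (
            torch.from_numpy(np.asarray(a, dtype=np.float32)),
            torch.from_numpy(np.asarray(b, dtype=np.float32)),
        )


# Synthetic data: 16 samples, the second array stored as dtype=object.
rng = np.random.default_rng(0)
samples = [
    (rng.random((5, 7), dtype=np.float32), rng.random(10).astype(object))
    for _ in range(16)
]

train_ds = ArrayPairDataset(samples)
train_dl = DataLoader(train_ds, batch_size=8, shuffle=True)

# One batch now collates into two stacked float32 tensors.
inputs, labels = next(iter(train_dl))
print(inputs.shape, labels.shape)
print(inputs.dtype, labels.dtype)
```

    Casting inside __getitem__ keeps the fix local to the dataset. If the arrays are loaded once up front, casting them at load time avoids repeating the conversion on every access.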