链接1:
处理数据样本的代码可能会变得杂乱无章且难以维护;理想情况下,我们希望数据集代码与模型训练代码分离,以获得更好的可读性和模块化。PyTorch 提供了两个数据原语: torch.utils.data.DataLoader
和 torch.utils.data.Dataset
允许您使用预加载的数据集和您自己的数据。 Dataset
存储样本及其相应的标签,而 DataLoader
则在 Dataset
周围封装了一个可迭代器,以方便访问样本。
自定义 Dataset 类必须实现三个函数:init、len 和 **getitem** 三个函数。
__init__
__len__
len 函数返回数据集中的样本数。
__getitem__
__getitem__函数从数据集中加载并返回给定索引 idx
的样本。根据索引,它确定图像在磁盘上的位置,使用 read_image
将其转换为张量图像,从 self.img_labels
中的 csv 数据中获取相应的标签,调用转换函数(如果适用),并以元组形式返回张量图像和相应的标签。
链接2:
torch.utils.data.Dataset是代表自定义数据集方法的类,用户可以通过继承该类来自定义自己的数据集类,在继承时要求用户重载__len__()和__getitem__()这两个魔法方法。
len():返回的是数据集的大小。我们构建的数据集是一个对象,而数据集不像序列类型(列表、元组、字符串)那样可以直接用len()来获取序列的长度,魔法方法__len__()的目的就是方便像序列那样直接获取对象的长度。如果A是一个类,a是类A的实例化对象,当A中定义了魔法方法__len__(),len(a)则返回对象的大小。
getitem():实现索引数据集中的某一个数据。我们知道,序列可以通过索引的方法获取序列中的任意元素,getitem()则实现了能够通过索引的方法获取对象中的任意元素。此外,我们可以在__getitem__()中实现数据预处理。
举例:
import torch
from torch.utils.data import Dataset
class TensorDataset(Dataset):
"""
TensorDataset继承Dataset, 重载了__init__(), __getitem__(), __len__()
实现将一组Tensor数据对封装成Tensor数据集
能够通过index得到数据集的数据,能够通过len,得到数据集大小
"""
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
def __len__(self):
return self.data_tensor.size(0)
# 生成数据
data_tensor = torch.randn(4, 3)
target_tensor = torch.rand(4)
# 将数据封装成Dataset
tensor_dataset = TensorDataset(data_tensor, target_tensor)
# 可使用索引调用数据
print(tensor_dataset[1])
# 输出:(tensor([-1.0351, -0.1004, 0.9168]), tensor(0.4977))
# 获取数据集大小
print(len(tensor_dataset))
# 输出:4
因篇幅问题不能全部显示,请点此查看更多更全内容
Copyright © 2019- awee.cn 版权所有 湘ICP备2023022495号-5
违法及侵权请联系:TEL:199 1889 7713 E-MAIL:2724546146@qq.com
本站由北京市万商天勤律师事务所王兴未律师提供法律服务