数据集Dataset

主要内容:数据集加载、数据集迭代、数据集常用操作和自定义数据集
请添加图片描述

数据集加载

from mindspore.dataset import vision
from mindspore.dataset import MnistDataset, GeneratorDataset
train_dataset = MnistDataset("MNIST_Data/train", shuffle=False)
print(type(train_dataset))

数据集迭代

访问数据类型默认为Tensor,若output_numpy=True,方位数据类型为Numpy

def visualize(dataset):
    figure = plt.figure(figsize=(4, 4))
    cols, rows = 3, 3

    plt.subplots_adjust(wspace=0.5, hspace=0.5)

    for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):
        figure.add_subplot(rows, cols, idx + 1)
        plt.title(int(label))
        plt.axis("off")
        plt.imshow(image.asnumpy().squeeze(), cmap="gray")
        if idx == cols * rows - 1:
            break
    plt.show()
visualize(train_dataset)

数据集常用操作

shuffle

train_dataset = train_dataset.shuffle(buffer_size=64)

visualize(train_dataset)

map

# map 前
image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)
# map 后
train_dataset = train_dataset.map(vision.Rescale(1.0 / 255.0, 0), input_columns='image')
image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)

batch

train_dataset = train_dataset.batch(batch_size=32)
image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)

自定义数据集

class RandomAccessDataset:
    def __init__(self):
        self._data = np.ones((5, 2))
        self._label = np.zeros((5, 1))

    def __getitem__(self, index):
        return self._data[index], self._label[index]

    def __len__(self):
        return len(self._data)

loader = RandomAccessDataset()
dataset = GeneratorDataset(source=loader, column_names=["data", "label"])

for data in dataset:
    print(data)

可迭代数据集

class IterableDataset():
    def __init__(self, start, end):
        '''init the class object to hold the data'''
        self.start = start
        self.end = end
    def __next__(self):
        '''iter one data and return'''
        return next(self.data)
    def __iter__(self):
        '''reset the iter'''
        self.data = iter(range(self.start, self.end))
        return self

loader = IterableDataset(1, 5)
dataset = GeneratorDataset(source=loader, column_names=["data"])

for d in dataset:
    print(d)

生成器


def my_generator(start, end):
    for i in range(start, end):
        yield
        
# since a generator instance can be only iterated once, we need to wrap it by lambda to generate multiple instances
dataset = GeneratorDataset(source=lambda: my_generator(3, 6), column_names=["data"])

for d in dataset:
    print(d)
Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐