数据集-Mindspore 25天打卡
昇思25天打卡-数据集
·
数据集Dataset
主要内容:数据集加载、数据集迭代、数据集常用操作和自定义数据集
数据集加载
from mindspore.dataset import vision
from mindspore.dataset import MnistDataset, GeneratorDataset
train_dataset = MnistDataset("MNIST_Data/train", shuffle=False)
print(type(train_dataset))
数据集迭代
访问数据类型默认为Tensor,若output_numpy=True,方位数据类型为Numpy
def visualize(dataset):
figure = plt.figure(figsize=(4, 4))
cols, rows = 3, 3
plt.subplots_adjust(wspace=0.5, hspace=0.5)
for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):
figure.add_subplot(rows, cols, idx + 1)
plt.title(int(label))
plt.axis("off")
plt.imshow(image.asnumpy().squeeze(), cmap="gray")
if idx == cols * rows - 1:
break
plt.show()
visualize(train_dataset)
数据集常用操作
shuffle
train_dataset = train_dataset.shuffle(buffer_size=64)
visualize(train_dataset)
map
# map 前
image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)
# map 后
train_dataset = train_dataset.map(vision.Rescale(1.0 / 255.0, 0), input_columns='image')
image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)
batch
train_dataset = train_dataset.batch(batch_size=32)
image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)
自定义数据集
class RandomAccessDataset:
def __init__(self):
self._data = np.ones((5, 2))
self._label = np.zeros((5, 1))
def __getitem__(self, index):
return self._data[index], self._label[index]
def __len__(self):
return len(self._data)
loader = RandomAccessDataset()
dataset = GeneratorDataset(source=loader, column_names=["data", "label"])
for data in dataset:
print(data)
可迭代数据集
class IterableDataset():
def __init__(self, start, end):
'''init the class object to hold the data'''
self.start = start
self.end = end
def __next__(self):
'''iter one data and return'''
return next(self.data)
def __iter__(self):
'''reset the iter'''
self.data = iter(range(self.start, self.end))
return self
loader = IterableDataset(1, 5)
dataset = GeneratorDataset(source=loader, column_names=["data"])
for d in dataset:
print(d)
生成器
def my_generator(start, end):
for i in range(start, end):
yield
# since a generator instance can be only iterated once, we need to wrap it by lambda to generate multiple instances
dataset = GeneratorDataset(source=lambda: my_generator(3, 6), column_names=["data"])
for d in dataset:
print(d)
鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。
更多推荐


所有评论(0)