# 这部分内容是了解数据集以及在机器学习中的用法
# mindspore.dataset提供数据集加载
# 这里主要学习不同数据集的加载方式,常见操作,自定义数据集的方法

% %%capture captured_output
# # 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
# !pip uninstall mindspore -y
# !pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
 


import numpy as np
from mindspore.dataset import vision
from mindspore.dataset import MnistDataset, GeneratorDataset
import matplotlib.pyplot as plt

# 以Mnist为例,学习用mindspore.dataset加载

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

# shuffle=False 表示在加载数据时不进行随机打乱。
train_dataset = MnistDataset("MNIST_Data/train", shuffle=False)
print(type(train_dataset))

# 数据加载后,一般【以迭代方式】获取数据,然后再送入net中训练,这是要用create_tuple_iterator  或者
# create_dict_iterator接口创建“数据迭代器”, 迭代访问数据
# 简单说,就是数据一条一条依次读入并用于训练
# 注意:访问的数据类型默认为Tensor。如果数据类型是Numpy,设置output_numpy=True

# 下面是迭代9张图片并可视化
def visualize(dataset):
    figure = plt.figure(figsize=(4, 4))
    cols, rows = 3, 3

    plt.subplots_adjust(wspace=0.5, hspace=0.5)

    for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):
        figure.add_subplot(rows, cols, idx + 1)
        plt.title(int(label))
        plt.axis("off")
        plt.imshow(image.asnumpy().squeeze(), cmap="gray")
        if idx == cols * rows - 1:
            break
    plt.show()

##  下面是数据集的常用操作

# shuffle: 随机功能, 可以消除数据排列造成的分布不均问题
# mindspore.dataset提供的数据集在加载时可配置shuffle=True,或使用如下操作:
train_dataset = train_dataset.shuffle(buffer_size=64)

visualize(train_dataset)

## 继续数据集的常用操作

# map: 预处理数据,针对数据集指定列(column)添加数据变换(Transforms),将数据变换应用于该列数据的每个元素,并返回包含变换后元素的新数据集。

image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)

# 对Mnist数据集做数据缩放处理,将图像统一除以255,数据类型由uint8转为了float32。
train_dataset = train_dataset.map(vision.Rescale(1.0 / 255.0, 0), input_columns='image')

# 对比map前后的数据的数据类型
image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)

## batch操作: 将数据集打包为固定大小的batch(保证梯度下降的随机性和优化计算量)
train_dataset = train_dataset.batch(batch_size=32)

image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)


## 自定义数据集
# mindspore.dataset提供了常用的公开数据集和标准格式数据集
# 暂不支持直接加载数据集,可以构造自定义数据加载类或自定义数据集生成函数,再用GeneratorDataset接口实现数据加载。
# GeneratorDataset支持通过可随机访问数据集对象、可迭代数据集对象和生成器generator构造自定义数据集

# 可随机访问数据集,如当用dataset[idx]访问时,可以读取dataset内容中第idx个样本或label。
# Random-accessible object as input source
class RandomAccessDataset:
    def __init__(self):
        self._data = np.ones((5, 2))
        self._label = np.zeros((5, 1))

    def __getitem__(self, index):
        return self._data[index], self._label[index]

    def __len__(self):
        return len(self._data)


loader = RandomAccessDataset()
dataset = GeneratorDataset(source=loader, column_names=["data", "label"])

for data in dataset:
    print(data)

# list, tuple are also supported.
loader = [np.array(0), np.array(1), np.array(2)]
dataset = GeneratorDataset(source=loader, column_names=["data"])

for data in dataset:
    print(data)

## 可迭代数据集
#  使用iter(dataset)时,可以读取从数据库,远程服务器返回的数据流
# 下面示例, 一个简单迭代器,加载到GeneratorDataset
# Iterator as input source
class IterableDataset():
    def __init__(self, start, end):
        '''init the class object to hold the data'''
        self.start = start
        self.end = end
    def __next__(self):
        '''iter one data and return'''
        return next(self.data)
    def __iter__(self):
        '''reset the iter'''
        self.data = iter(range(self.start, self.end))
        return self

loader = IterableDataset(1, 5)
dataset = GeneratorDataset(source=loader, column_names=["data"])

for d in dataset:
    print(d)

## 生成器: 可迭代的数据集类型
# Generator
"""
生成器的作用是为机器学习模型提供更多的数据,以增加训练集的规模和多样性,从而提高模型的泛化能力和准确性。通过使用生成器,可以生成与原始数据相似但又不完全相同的数据,以扩充数据集,减少过拟合的风险。
生成器可以基于不同的方法和技术,例如基于统计学原理、随机过程、深度学习等。一些常见的生成器包括高斯分布生成器、混合模型生成器、生成对抗网络(GAN)等。
例如,在图像识别任务中,可以使用生成器生成逼真的图像数据,以增加训练集的数量和多样性;在自然语言处理任务中
"""
def my_generator(start, end):
    for i in range(start, end):
        yield i
        
# since a generator instance can be only iterated once, we need to wrap it by lambda to generate multiple instances
dataset = GeneratorDataset(source=lambda: my_generator(3, 6), column_names=["data"])

for d in dataset:
    print(d)   


Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐