什么是FCN模型?

因为模型网络中所有的层都是卷积层,故称为全卷积网络。

全卷积神经网络主要使用了三种技术:

  1. 卷积化(Convolutional)
  2. 上采样(Upsample)
  3. 跳跃结构(Skip Layer)

语义分割的目的是对图像中每个像素点进行分类。与普通的分类任务只输出某个类别不同,语义分割任务输出与输入大小相同的图像,输出图像的每个像素对应了输入图像每个像素的类别。

数据处理

数据预处理

由于PASCAL VOC 2012数据集中图像的分辨率大多不一致,无法放在一个tensor中,故输入前需做标准化处理。

数据加载

将PASCAL VOC 2012数据集与SBD数据集进行混合。

import cv2
import mindspore.dataset as ds
import numpy as np


class SegDataset:
    """Build a MindSpore training pipeline for semantic segmentation.

    Each record of the MindRecord file is expected to carry an encoded image
    in the ``data`` column and an encoded label mask in the ``label`` column.
    Every sample is randomly rescaled, normalised, padded, randomly cropped
    to ``crop_size`` x ``crop_size`` and randomly flipped horizontally.

    Args:
        image_mean: Per-channel mean subtracted from the decoded image.
        image_std: Per-channel standard deviation the image is divided by.
        data_file: Path of the MindRecord file to read.
        batch_size: Number of samples per batch.
        crop_size: Side length of the square crop that is returned.
        max_scale: Upper bound of the random rescale factor.
        min_scale: Lower bound of the random rescale factor.
        ignore_label: Label value used to fill padded label pixels.
        num_classes: Number of segmentation classes (stored, not used here).
        num_readers: Parallel workers used to read the MindRecord file.
        num_parallel_calls: Parallel workers used for the preprocessing map.

    Raises:
        ValueError: If ``max_scale`` is not strictly greater than
            ``min_scale``.
    """

    def __init__(
        self,
        image_mean,
        image_std,
        data_file="",
        batch_size=32,
        crop_size=512,
        max_scale=2.0,
        min_scale=0.5,
        ignore_label=255,
        num_classes=21,
        num_readers=2,
        num_parallel_calls=4,
    ):
        # The original code evaluated ``max_scale > min_scale`` and threw the
        # result away, so an inverted range was never reported. Raise instead
        # of assert so the check survives ``python -O``.
        if max_scale <= min_scale:
            raise ValueError(
                f"max_scale ({max_scale}) must be greater than "
                f"min_scale ({min_scale})"
            )

        self.data_file = data_file
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.ignore_label = ignore_label
        self.num_classes = num_classes
        self.num_readers = num_readers
        self.num_parallel_calls = num_parallel_calls

    def preprocess_dataset(self, image, label):
        """Decode and augment one (image, label) pair.

        Args:
            image: Encoded image bytes, decoded as a colour image.
            label: Encoded label-mask bytes, decoded as a grayscale image.

        Returns:
            Tuple ``(image_out, label_out)``: the normalised image crop in
            channel-first layout and the matching ``int32`` label crop, both
            ``crop_size`` pixels high and wide.
        """
        image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        label_out = cv2.imdecode(
            np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE
        )

        # Random rescale; nearest-neighbour keeps label values discrete.
        sc = np.random.uniform(self.min_scale, self.max_scale)
        new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
        image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        label_out = cv2.resize(
            label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST
        )

        image_out = (image_out - self.image_mean) / self.image_std

        # Pad on the bottom/right so a full crop always fits. Padded label
        # pixels get ignore_label; padded image pixels get 0.
        out_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)
        pad_h, pad_w = out_h - new_h, out_w - new_w
        if pad_h > 0 or pad_w > 0:
            image_out = cv2.copyMakeBorder(
                image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0
            )
            label_out = cv2.copyMakeBorder(
                label_out,
                0,
                pad_h,
                0,
                pad_w,
                cv2.BORDER_CONSTANT,
                value=self.ignore_label,
            )

        # Random crop, using the same window for image and label.
        offset_h = np.random.randint(0, out_h - self.crop_size + 1)
        offset_w = np.random.randint(0, out_w - self.crop_size + 1)
        image_out = image_out[
            offset_h : offset_h + self.crop_size,
            offset_w : offset_w + self.crop_size,
            :,
        ]
        label_out = label_out[
            offset_h : offset_h + self.crop_size, offset_w : offset_w + self.crop_size
        ]

        # Random horizontal flip applied to both arrays.
        if np.random.uniform(0.0, 1.0) > 0.5:
            image_out = image_out[:, ::-1, :]
            label_out = label_out[:, ::-1]

        # HWC -> CHW; copy() turns the sliced/flipped views into contiguous
        # arrays.
        image_out = image_out.transpose((2, 0, 1))
        image_out = image_out.copy()
        label_out = label_out.copy()
        label_out = label_out.astype("int32")
        return image_out, label_out

    def get_dataset(self):
        """Create the shuffled, batched MindSpore dataset.

        Enables the NUMA option in the MindSpore dataset config, reads the
        ``data`` and ``label`` columns from ``self.data_file``, maps
        ``preprocess_dataset`` over them, shuffles with a buffer of
        ``batch_size * 10`` samples and batches with incomplete batches
        dropped.

        Returns:
            The batched dataset object.
        """
        ds.config.set_numa_enable(True)
        dataset = ds.MindDataset(
            self.data_file,
            columns_list=["data", "label"],
            shuffle=True,
            num_parallel_workers=self.num_readers,
        )
        transforms_list = self.preprocess_dataset
        dataset = dataset.map(
            operations=transforms_list,
            input_columns=["data", "label"],
            output_columns=["data", "label"],
            num_parallel_workers=self.num_parallel_calls,
        )
        dataset = dataset.shuffle(buffer_size=self.batch_size * 10)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        return dataset


# Parameters used to create the dataset: per-channel normalisation statistics
# and the MindRecord file location.
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# Training parameters.
train_batch_size, crop_size = 4, 512
min_scale, max_scale = 0.5, 2.0
ignore_label, num_classes = 255, 21

# Configure the segmentation pipeline and build the batched dataset directly.
dataset = SegDataset(
    IMAGE_MEAN,
    IMAGE_STD,
    data_file=DATA_FILE,
    batch_size=train_batch_size,
    crop_size=crop_size,
    max_scale=max_scale,
    min_scale=min_scale,
    ignore_label=ignore_label,
    num_classes=num_classes,
    num_readers=2,
    num_parallel_calls=4,
).get_dataset()

训练集可视化

运行以下代码观察载入的数据集图片(数据处理过程中已做归一化处理)。

import matplotlib.pyplot as plt
import numpy as np

# Show eight training samples in a 2 x 4 grid.
plt.figure(figsize=(16, 8))

# Create the iterator once. The original called create_dict_iterator() inside
# the loop, building a fresh iterator over the pipeline for every subplot.
data_iterator = dataset.create_dict_iterator()

# zip() stops quietly if the dataset yields fewer than eight batches.
for i, show_data in zip(range(1, 9), data_iterator):
    plt.subplot(2, 4, i)
    show_images = show_data["data"].asnumpy()
    # The data is already mean/std-normalised; clip to [0, 1] so that
    # imshow accepts the float values.
    show_images = np.clip(show_images, 0, 1)
    # Convert the first image of the batch from CHW to HWC for display.
    plt.imshow(show_images[0].transpose(1, 2, 0))
    plt.axis("off")

# Layout only needs adjusting once, after all subplots exist.
plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()

网络构建

网络流程

FCN网络的流程如下图所示:

  1. 输入图像image,经过pool1池化后,尺寸变为原始尺寸的1/2。
  2. 经过pool2池化,尺寸变为原始尺寸的1/4。
  3. 接着经过pool3、pool4、pool5池化,大小分别变为原始尺寸的1/8、1/16、1/32。
  4. 经过conv6-7卷积,输出的尺寸依然是原图的1/32。
  5. FCN-32s是最后使用反卷积,使得输出图像大小与输入图像相同。
  6. FCN-16s是将conv7的输出进行反卷积,使其尺寸扩大两倍至原图的1/16,并将其与pool4输出的特征图进行融合,后通过反卷积扩大到原始尺寸。
  7. FCN-8s是将conv7的输出进行反卷积扩大4倍,将pool4输出的特征图反卷积扩大2倍,并将pool3输出特征图拿出,三者融合后通过反卷积扩大到原始尺寸。
import mindspore.nn as nn


class FCN8s(nn.Cell):
    """FCN-8s segmentation network.

    Five Conv-BatchNorm-ReLU stages, each followed by a 2x2 stride-2 max
    pool, feed two further Conv-BatchNorm-ReLU layers (``conv6``, ``conv7``).
    Class scores computed from ``conv7``, from the ``pool4`` output and from
    the ``pool3`` output are merged through transposed convolutions.

    Args:
        n_class: Number of output channels of every scoring and upsampling
            layer.
    """

    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class

        def conv_stage(widths, kernel_size=3):
            # One Conv2d -> BatchNorm2d -> ReLU triple per consecutive pair
            # in `widths`, wrapped in a single SequentialCell.
            layers = []
            for n_in, n_out in zip(widths, widths[1:]):
                layers.append(
                    nn.Conv2d(
                        in_channels=n_in,
                        out_channels=n_out,
                        kernel_size=kernel_size,
                        weight_init="xavier_uniform",
                    )
                )
                layers.append(nn.BatchNorm2d(n_out))
                layers.append(nn.ReLU())
            return nn.SequentialCell(*layers)

        def halve():
            # 2x2 max pooling with stride 2.
            return nn.MaxPool2d(kernel_size=2, stride=2)

        def score(n_in):
            # 1x1 convolution mapping features to per-class scores.
            return nn.Conv2d(
                in_channels=n_in,
                out_channels=self.n_class,
                kernel_size=1,
                weight_init="xavier_uniform",
            )

        def upsample(kernel_size, stride):
            # Transposed convolution over the class-score maps.
            return nn.Conv2dTranspose(
                in_channels=self.n_class,
                out_channels=self.n_class,
                kernel_size=kernel_size,
                stride=stride,
                weight_init="xavier_uniform",
            )

        # Encoder: layers are created and registered in the same order as
        # before so parameter names and ordering stay unchanged.
        self.conv1 = conv_stage((3, 64, 64))
        self.pool1 = halve()
        self.conv2 = conv_stage((64, 128, 128))
        self.pool2 = halve()
        self.conv3 = conv_stage((128, 256, 256, 256))
        self.pool3 = halve()
        self.conv4 = conv_stage((256, 512, 512, 512))
        self.pool4 = halve()
        self.conv5 = conv_stage((512, 512, 512, 512))
        self.pool5 = halve()
        self.conv6 = conv_stage((512, 4096), kernel_size=7)
        self.conv7 = conv_stage((4096, 4096), kernel_size=1)

        # Scoring and upsampling heads.
        self.score_fr = score(4096)
        self.upscore2 = upsample(4, 2)
        self.score_pool4 = score(512)
        self.upscore_pool4 = upsample(4, 2)
        self.score_pool3 = score(256)
        self.upscore8 = upsample(16, 8)

    def construct(self, x):
        """Run the forward pass and return the per-class score map."""
        # Encoder; the pool3 and pool4 outputs are kept for the skip paths.
        pooled1 = self.pool1(self.conv1(x))
        pooled2 = self.pool2(self.conv2(pooled1))
        pooled3 = self.pool3(self.conv3(pooled2))
        pooled4 = self.pool4(self.conv4(pooled3))
        pooled5 = self.pool5(self.conv5(pooled4))
        deep = self.conv7(self.conv6(pooled5))

        # Fuse the upsampled deep scores with the pool4 scores.
        deep_up = self.upscore2(self.score_fr(deep))
        fused4 = self.score_pool4(pooled4) + deep_up

        # Fuse again with the pool3 scores, then apply the final upsampling.
        fused4_up = self.upscore_pool4(fused4)
        fused3 = self.score_pool3(pooled3) + fused4_up
        return self.upscore8(fused3)

训练准备

导入VGG-16部分预训练权重

FCN使用VGG-16作为骨干网络,用于实现图像编码。使用下面代码导入VGG-16预训练模型的部分预训练权重。

from download import download
from mindspore import load_checkpoint, load_param_into_net

# Location of the VGG-16 pretrained checkpoint used to initialise FCN-8s.
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
# Fetch the checkpoint into the working directory; replace=True is passed, so
# presumably an existing local copy is overwritten on every run — confirm
# against the `download` package documentation.
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)


def load_vgg16(ckpt_path="fcn8s_vgg16_pretrain.ckpt", network=None):
    """Load the VGG-16 pretrained checkpoint into a network.

    Args:
        ckpt_path: Path of the checkpoint file. Defaults to the file name
            used by the download step.
        network: Network that receives the parameters. When omitted, the
            module-level ``net`` is used, which keeps the original
            zero-argument call ``load_vgg16()`` working.

    Returns:
        Whatever ``load_param_into_net`` returns (its report of parameters
        that could not be loaded), so callers can check the result instead
        of having it silently discarded.
    """
    if network is None:
        # Resolved at call time, so `net` only has to exist before the call.
        network = net
    param_vgg = load_checkpoint(ckpt_path)
    return load_param_into_net(network, param_vgg)

模型训练

导入VGG-16预训练参数后,实例化损失函数、优化器,使用Model接口编译网络,训练FCN-8s网络。

import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.train import (
    CheckpointConfig,
    LossMonitor,
    Model,
    ModelCheckpoint,
    TimeMonitor,
)

device_target = "Ascend"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)

train_batch_size = 4
num_classes = 21
# Build the network; use num_classes instead of repeating the literal 21.
net = FCN8s(n_class=num_classes)
# Load the VGG-16 pretrained parameters into `net`.
load_vgg16()
# Learning-rate settings.
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochs

lr_scheduler = mindspore.nn.cosine_decay_lr(
    min_lr, base_lr, total_step, iters_per_epoch, decay_epoch=2
)
# NOTE(review): only the final value of the cosine schedule is used, so the
# optimizer runs with a constant learning rate. Kept as in the original —
# confirm whether the full schedule was meant to be passed instead.
lr = Tensor(lr_scheduler[-1])

# Loss function; label value 255 is excluded from the loss.
loss = nn.CrossEntropyLoss(ignore_index=255)
# Optimizer.
optimizer = nn.Momentum(
    params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001
)
# Loss scaling.
scale_factor = 4
scale_window = 3000
# Fixes: (1) the original referenced `ms`, which is never imported (only
# `mindspore` is), raising NameError; (2) the values were passed positionally
# and therefore landed in init_loss_scale / scale_factor instead of
# scale_factor / scale_window.
loss_scale_manager = mindspore.amp.DynamicLossScaleManager(
    scale_factor=scale_factor, scale_window=scale_window
)

# Evaluation metrics, shared by both device branches.
# NOTE(review): the four metric classes are not defined in the visible part
# of this file; they must be defined or imported before this point.
eval_metrics = {
    "pixel accuracy": PixelAccuracy(),
    "mean pixel accuracy": PixelAccuracyClass(),
    "mean IoU": MeanIntersectionOverUnion(),
    "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion(),
}

# Build the model; the loss-scale manager is only passed on Ascend.
model_kwargs = {"loss_fn": loss, "optimizer": optimizer, "metrics": eval_metrics}
if device_target == "Ascend":
    model_kwargs["loss_scale_manager"] = loss_scale_manager
model = Model(net, **model_kwargs)

# Callbacks and checkpoint-saving settings.
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]
# NOTE(review): save_steps is defined but unused; checkpoints are saved every
# 10 steps below. Confirm which interval is intended.
save_steps = 330
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(
    save_checkpoint_steps=10, keep_checkpoint_max=keep_checkpoint_max
)
ckpt_callback = ModelCheckpoint(prefix="FCN8s", directory="./ckpt", config=config_ckpt)
callbacks.append(ckpt_callback)
model.train(train_epochs, dataset, callbacks=callbacks)

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐