Welcome, everyone, to the MindSpore community — come build with us!

1. ICT Concepts and Basic Principles

The core goal of image inpainting (ICT) is to algorithmically reconstruct missing or damaged regions of an image so that the result looks natural and consistent with the original. The field combines traditional image processing with deep learning into a multi-level toolbox. Traditional diffusion-based methods analyze the pixel gradients and color information along the boundary of the missing region and progressively propagate content inward, much like extending texture on a canvas from known areas into unknown ones; they work well for small scratches and spots. Texture-synthesis methods instead sample texture features from other parts of the image and stitch similar structures into the missing region, which makes them effective for large areas of missing texture, such as worn patches in old photographs.
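The diffusion idea above can be sketched in a few lines of numpy: missing pixels are repeatedly replaced by the average of their four neighbors, so information flows inward from the mask boundary. This is an illustrative toy, not part of the ICT pipeline; all names here are invented for the example.

```python
import numpy as np

def diffuse_inpaint(img, mask, iters=200):
    """Fill masked pixels by repeatedly averaging the 4-neighborhood.

    img  : 2D float array (grayscale); values inside the mask are ignored
    mask : 2D bool array, True where pixels are missing
    """
    out = img.copy()
    out[mask] = 0.0
    for _ in range(iters):
        # average of up/down/left/right neighbors (edge-padded)
        p = np.pad(out, 1, mode='edge')
        avg = (p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:]) / 4.0
        out[mask] = avg[mask]  # only the missing pixels are updated
    return out

# A constant image with a small hole is recovered almost exactly.
img = np.full((8, 8), 0.5)
mask = np.zeros((8, 8), dtype=bool)
mask[3:5, 3:5] = True
restored = diffuse_inpaint(img, mask)
```

Real diffusion-based methods (e.g. PDE-driven inpainting) use more careful boundary handling, but the fixed-point character is the same: the hole settles into values consistent with its surroundings.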

The introduction of deep learning fundamentally changed the inpainting paradigm. Convolutional neural networks (CNNs) capture local image features through stacked convolution kernels and, combined with an encoder-decoder structure, learn feature mappings from complete images. For example, when restoring a partially occluded face, a CNN first extracts key features such as the facial contour, eyes, and nose, then generates the missing pixels from those features. Generative adversarial networks (GANs) push detail realism and texture continuity further through the game between a generator and a discriminator: the generator produces the fill-in content, the discriminator judges whether it blends naturally with the surrounding pixels, and adversarial training optimizes both until the repaired result is hard to distinguish from the real thing.

Transformer models use self-attention to break through the local receptive field of traditional convolutional networks. They can capture long-range dependencies in an image: when restoring a photo of a building, for instance, the model consults both the arrangement of distant windows and the texture direction of nearby brickwork, keeping the repaired region geometrically and stylistically consistent with the whole. This gives Transformers an advantage in complex scenes, such as fabrics with periodic textures or building facades with perspective.
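The global receptive field of self-attention can be illustrated in plain numpy: every position's attention weights form a distribution over the entire sequence, unlike a convolution whose kernel only touches a local window. A toy sketch with random data, not the model built later in this article:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
T, d = 16, 8                          # sequence length, head dimension
q = rng.normal(size=(T, d))           # queries
k = rng.normal(size=(T, d))           # keys

att = softmax(q @ k.T / np.sqrt(d))   # (T, T) attention matrix

# Each row is a probability distribution over ALL T positions:
# position 0 can attend to position 15 directly, with no notion of distance.
```

In a convolution with kernel size 3, position 0 would see only positions 0-1; here every entry of `att` is strictly positive, which is exactly the long-range dependency the text describes.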

Applications of image inpainting are broad and deep. In cultural-heritage preservation, it can restore cracks and fading in ancient murals caused by age, analyzing the distribution of remaining pigments and brushstroke directions to reconstruct the original appearance of the missing parts. In medical imaging, inpainting algorithms can remove noise and motion artifacts from CT or MRI scans, and even fill in slices lost to patient movement, supporting more accurate diagnosis. Video restoration additionally requires temporal consistency: inter-frame interpolation and motion compensation ensure that the repaired video plays back without flicker or jumps.

Despite this progress, several challenges remain. High-resolution inpainting is computationally expensive, and current models struggle to run in real time on mobile devices. Complex structures, such as a missing region that mixes regular textures with irregular shapes, often come out blurred or broken because of conflicting features. Video repair must also solve inter-frame coherence so that generated content does not feel disjointed over time. Future directions include fusing multimodal information (for example, using text descriptions to guide the repair), applying neural rendering to 3D scenes, and pursuing lightweight deployment through model compression and quantization so that inpainting can reach consumer-grade scenarios.

The overall pipeline is roughly as shown in the figure:

2. Training Dataset and Environment Preparation

2.1 Datasets

This experiment uses two datasets: the ImageNet dataset (download: ImageNet) and the mask dataset for inpainting (download: www.dropbox.com).

2.2 Parameter Settings

The parameter settings used for training are shown in the following code.

import os
import argparse

import mindspore
import numpy as np

from mindspore import context


def parse_args():
    """Parse args."""
    parser = argparse.ArgumentParser()
    # Parameter of train
    parser.add_argument('--data_path', type=str, default='/data0/imagenet2012/train',
                        help='Indicate where is the training set')
    parser.add_argument('--mask_path', type=str, default='/home/ict/ICT/mask/testing_mask_dataset',
                        help='Where is the mask')
    parser.add_argument('--n_layer', type=int, default=35)
    parser.add_argument('--n_head', type=int, default=8)
    parser.add_argument('--n_embd', type=int, default=1024)
    parser.add_argument('--GELU_2', action='store_true', help='use the new activation function')
    parser.add_argument('--use_ImageFolder', action='store_true', help='Using the original folder for ImageNet dataset')
    parser.add_argument('--random_stroke', action='store_true', help='Use the generated mask')
    parser.add_argument('--train_epoch', type=int, default=5, help='How many epochs')
    parser.add_argument('--learning_rate', type=float, default=3e-4, help='Value of learning rate.')
    parser.add_argument('--input', type=str, default='/data0/imagenet2012/train',
                        help='path to the input images directory or an input image')
    parser.add_argument('--mask', type=str, default='/home/ict/ICT/mask/testing_mask_dataset',
                        help='path to the masks directory or a mask file')
    parser.add_argument('--prior', type=str, default='', help='path to the edges directory or an edge file')
    parser.add_argument('--kmeans', type=str, default='../kmeans_centers.npy', help='path to the kmeans')
    # Adjust this to the actual path of VGG19.ckpt; the VGG19 weights can be downloaded from the link given in the inference section
    parser.add_argument('--vgg_path', type=str, default='../ckpts_ICT/VGG19.ckpt', help='path to the VGG')
    parser.add_argument('--image_size', type=int, default=256, help='the size of origin image')
    parser.add_argument('--prior_random_degree', type=int, default=1, help='during training, how far deviate from')
    parser.add_argument('--use_degradation_2', action='store_true', help='use the new degradation function')
    parser.add_argument('--mode', type=int, default=1, help='1:train, 2:test')
    parser.add_argument('--mask_type', type=int, default=2)
    parser.add_argument('--max_iteration', type=int, default=25000, help='How many run iteration')
    parser.add_argument('--lr', type=float, default=0.0001, help='Value of learning rate.')
    parser.add_argument('--D2G_lr', type=float, default=0.1,
                        help='Value of discriminator/generator learning rate ratio')
    parser.add_argument("--l1_loss_weight", type=float, default=1.0)
    parser.add_argument("--style_loss_weight", type=float, default=250.0)
    parser.add_argument("--content_loss_weight", type=float, default=0.1)
    parser.add_argument("--inpaint_adv_loss_weight", type=float, default=0.1)
    parser.add_argument('--ckpt_path', type=str, default='', help='model checkpoints path')
    parser.add_argument('--save_path', type=str, default='./checkpoint', help='save checkpoints path')
    parser.add_argument('--device_id', type=int, default=0)
    parser.add_argument('--device_target', type=str, default='GPU', help='GPU or Ascend')
    parser.add_argument('--prior_size', type=int, default=32, help='Input sequence length = prior_size*prior_size')
    parser.add_argument('--batch_size', type=int, default=2, help='The number of train batch size')
    parser.add_argument("--beta1", type=float, default=0.9, help="Value of beta1")
    parser.add_argument("--beta2", type=float, default=0.95, help="Value of beta2")

    # Parameter of infer
    parser.add_argument('--image_url', type=str, default='/data0/imagenet2012/val', help='the folder of image')
    parser.add_argument('--mask_url', type=str, default='/home/ict/ICT/mask/testing_mask_dataset',
                        help='the folder of mask')
    parser.add_argument('--top_k', type=int, default=40)
    parser.add_argument('--save_url', type=str, default='./sample', help='save the output results')
    parser.add_argument('--condition_num', type=int, default=1, help='Use how many BERT output')
    args = parser.parse_known_args()[0]
    return args


# Use PyNative (dynamic graph) mode on a GPU
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
opts = parse_args()

3. Training and Evaluation

3.1 Training

Because the Transformer is computationally heavy, images are downsampled before entering the network; on ImageNet, images are resized to 32×32 to obtain a low-resolution image prior. The Transformer model is built as follows:

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore.common.initializer import Normal


class CausalSelfAttention(nn.Cell):
    """
    The CausalSelfAttention part of transformer.

    Args:
        n_embd (int): The size of the vector space in which words are embedded.
        n_head (int): The number of multi-head.
        block_size (int): The context size(Input sequence length).
        resid_pdrop (float): The probability of resid_pdrop. Default: 0.1
        attn_pdrop (float): The probability of attn_pdrop. Default: 0.1

    Returns:
        Tensor, output tensor.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, resid_pdrop: float = 0.1, attn_pdrop: float = 0.1):
        super().__init__()
        # key, query, value projections for all heads
        self.key = nn.Dense(in_channels=n_embd, out_channels=n_embd,
                            weight_init=Normal(sigma=0.02, mean=0.0))
        self.query = nn.Dense(in_channels=n_embd, out_channels=n_embd,
                              weight_init=Normal(sigma=0.02, mean=0.0))
        self.value = nn.Dense(in_channels=n_embd, out_channels=n_embd,
                              weight_init=Normal(sigma=0.02, mean=0.0))
        # regularization
        self.attn_drop = nn.Dropout(keep_prob=1.0 - attn_pdrop)
        self.resid_drop = nn.Dropout(keep_prob=1.0 - resid_pdrop)
        # output projection
        self.proj = nn.Dense(in_channels=n_embd, out_channels=n_embd,
                             weight_init=Normal(sigma=0.02, mean=0.0))

        tril = nn.Tril()
        self.mask = mindspore.Parameter(
            tril(P.Ones()((block_size, block_size), mindspore.float32)).view(1, 1, block_size, block_size),
            requires_grad=False)

        self.n_head = n_head

    def construct(self, x):
        B, T, C = P.Shape()(x)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head)
        k = k.transpose(0, 2, 1, 3)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head)
        q = q.transpose(0, 2, 1, 3)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head)
        v = v.transpose(0, 2, 1, 3)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        k_shape = k.shape[-1]
        sz = 1.0 / (Tensor(k_shape) ** Tensor(0.5))
        q = mindspore.ops.Cast()(q, mindspore.float16)
        k = mindspore.ops.Cast()(k, mindspore.float16)
        sz = mindspore.ops.Cast()(sz, mindspore.float16)
        att = (ops.matmul(q, k.transpose(0, 1, 3, 2)) * sz)
        att = mindspore.ops.Cast()(att, mindspore.float32)
        att = P.Softmax()(att)
        att = self.attn_drop(att)
        att = mindspore.ops.Cast()(att, mindspore.float16)
        v = mindspore.ops.Cast()(v, mindspore.float16)
        y = mindspore.ops.matmul(att, v)
        y = mindspore.ops.Cast()(y, mindspore.float32)
        y = y.transpose(0, 2, 1, 3).view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_drop(self.proj(y))
        return y


class GELU2(nn.Cell):
    """
    The new gelu2 activation function.

    Returns:
        Tensor, output tensor.
    """

    def construct(self, x):
        return x * P.Sigmoid()(1.702 * x)


class Block_2(nn.Cell):
    """
    Transformer block with the GELU2 activation.

    Args:
        n_embd (int): The size of the vector space in which words are embedded.
        n_head (int): The number of multi-head.
        block_size (int): The context size(Input sequence length).
        resid_pdrop (float): The probability of resid_pdrop. Default: 0.1
        attn_pdrop (float): The probability of attn_pdrop. Default: 0.1

    Returns:
        Tensor, output tensor.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, resid_pdrop: float = 0.1, attn_pdrop: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
        self.ln2 = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, resid_pdrop, attn_pdrop)
        self.mlp = nn.SequentialCell([
            nn.Dense(in_channels=n_embd, out_channels=4 * n_embd,
                     weight_init=Normal(sigma=0.02, mean=0.0)),
            GELU2(),
            nn.Dense(in_channels=4 * n_embd, out_channels=n_embd,
                     weight_init=Normal(sigma=0.02, mean=0.0)),
            nn.Dropout(keep_prob=1.0 - resid_pdrop),
        ])

    def construct(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class Block(nn.Cell):
    """
    Transformer block with original GELU.

    Args:
        n_embd (int): The size of the vector space in which words are embedded.
        n_head (int): The number of multi-head.
        block_size (int): The context size(Input sequence length).
        resid_pdrop (float): The probability of resid_pdrop. Default: 0.1
        attn_pdrop (float): The probability of attn_pdrop. Default: 0.1

    Returns:
        Tensor, output tensor.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, resid_pdrop: float = 0.1, attn_pdrop: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
        self.ln2 = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, resid_pdrop, attn_pdrop)
        self.mlp = nn.SequentialCell([
            nn.Dense(in_channels=n_embd, out_channels=4 * n_embd,
                     weight_init=Normal(sigma=0.02, mean=0.0)),
            nn.GELU(),
            nn.Dense(in_channels=4 * n_embd, out_channels=n_embd,
                     weight_init=Normal(sigma=0.02, mean=0.0)),
            nn.Dropout(keep_prob=1.0 - resid_pdrop),
        ])

    def construct(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPT(nn.Cell):
    """
    The full GPT language model, with a context size of block_size.

    Args:
        vocab_size (int): The size of the vocabulary in the embedded data.
        n_embd (int): The size of the vector space in which words are embedded.
        n_layer (int): The number of attention layer.
        n_head (int): The number of multi-head.
        block_size (int): The context size(Input sequence length).
        use_gelu2 (bool): Use the new gelu2 activation function.
        embd_pdrop (float): The probability of embd_pdrop. Default: 0.1
        resid_pdrop (float): The probability of resid_pdrop. Default: 0.1
        attn_pdrop (float): The probability of attn_pdrop. Default: 0.1

    Returns:
        Tensor, output tensor.
    """

    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, block_size: int, use_gelu2: bool,
                 embd_pdrop: float = 0.1, resid_pdrop: float = 0.1, attn_pdrop: float = 0.1):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, n_embd, embedding_table=Normal(sigma=0.02, mean=0.0))
        self.pos_emb = mindspore.Parameter(P.Zeros()((1, block_size, n_embd), mindspore.float32))
        self.drop = nn.Dropout(keep_prob=1.0 - embd_pdrop)
        # transformer
        if use_gelu2:
            self.blocks = nn.SequentialCell(
                [*[Block_2(n_embd, n_head, block_size, resid_pdrop, attn_pdrop) for _ in range(n_layer)]])
        else:
            self.blocks = nn.SequentialCell(
                [*[Block(n_embd, n_head, block_size, resid_pdrop, attn_pdrop) for _ in range(n_layer)]])
        # decoder head
        self.ln_f = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
        self.head = nn.Dense(in_channels=n_embd, out_channels=vocab_size, has_bias=False,
                             weight_init=Normal(sigma=0.02, mean=0.0))

        self.block_size = block_size

    def get_block_size(self):
        return self.block_size

    def construct(self, idx, masks):
        _, t = idx.shape
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
        masks = P.ExpandDims()(masks, 2)
        token_embeddings = token_embeddings * (1 - masks)
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
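The view/transpose steps in `CausalSelfAttention.construct` only split the embedding into heads and later re-assemble it; the same index bookkeeping in numpy shows the round trip is lossless. Shapes follow the code above (batch B, sequence T, embedding C, `n_head` heads); the arrays here are toy data:

```python
import numpy as np

B, T, C, n_head = 2, 4, 8, 2
x = np.arange(B * T * C, dtype=np.float32).reshape(B, T, C)

# split: (B, T, C) -> (B, n_head, T, C // n_head), as done for q, k, v
h = x.reshape(B, T, n_head, C // n_head).transpose(0, 2, 1, 3)

# merge: (B, n_head, T, hs) -> (B, T, C), as done for the output y
y = h.transpose(0, 2, 1, 3).reshape(B, T, C)

# The transformation is pure reindexing; no values are lost or mixed,
# so attention can run per head and the result still lines up per token.
```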

Loss function definition

class TransformerWithLoss(nn.Cell):
    """
    Wrap the network with loss function to return Transformer with loss.

    Args:
        backbone (Cell): The target network to wrap.
    """

    def __init__(self, backbone):
        super(TransformerWithLoss, self).__init__(auto_prefix=False)
        self.backbone = backbone
        self.loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)

    def construct(self, x, targets, masks):
        logits = self.backbone(x, masks)
        loss = self.loss_fn(logits.view(-1, logits.shape[-1]), targets.view(-1))
        masks = P.ExpandDims()(masks, 2)
        masks = masks.view(-1)
        loss *= masks
        loss = P.ReduceMean()(loss)
        return loss
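`TransformerWithLoss` is a masked cross-entropy: the per-token loss is computed everywhere, then zeroed wherever the mask is 0, so only the masked-out tokens drive training. A numpy sketch of the same computation (toy logits and targets, not real data):

```python
import numpy as np

def masked_ce(logits, targets, masks):
    """Per-token cross-entropy, kept only where masks == 1, then averaged."""
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]  # one loss per token
    return (nll * masks).mean()

logits = np.array([[4.0, 0.0], [0.0, 4.0], [4.0, 0.0]])
targets = np.array([0, 1, 1])        # the last prediction is wrong...
masks = np.array([1.0, 1.0, 0.0])    # ...but it is unmasked, so no penalty
loss = masked_ce(logits, targets, masks)
```

The wrong-but-unmasked token contributes nothing, mirroring how the model is only graded on the missing regions it must reconstruct.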

Model definition and dataset loading

# Change the working directory to ICT/Transformer
if os.getcwd().endswith('ICT') or os.getcwd().endswith('ict'):
    os.chdir("./Transformer")

# Enable the ImageNet dataset option
opts.use_ImageFolder = True

from datasets.dataset import load_dataset
from transformer_utils.util import AverageMeter

kmeans = np.load('../kmeans_centers.npy')
kmeans = np.rint(127.5 * (kmeans + 1.0))

# Define the dataset
train_dataset = load_dataset(opts.data_path, kmeans, mask_path=opts.mask_path, is_train=True,
                             use_imagefolder=opts.use_ImageFolder, prior_size=opts.prior_size,
                             random_stroke=opts.random_stroke)
train_dataset = train_dataset.batch(opts.batch_size)
step_size = train_dataset.get_dataset_size()

# Define the model
block_size = opts.prior_size * opts.prior_size
transformer = GPT(vocab_size=kmeans.shape[0], n_embd=opts.n_embd, n_layer=opts.n_layer, n_head=opts.n_head,
                  block_size=block_size, use_gelu2=opts.GELU_2, embd_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0)
model = TransformerWithLoss(backbone=transformer)

# Define the optimizer
optimizer = nn.Adam(model.trainable_params(), learning_rate=opts.learning_rate, beta1=opts.beta1, beta2=opts.beta2)

train_net = nn.TrainOneStepCell(model, optimizer)
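The kmeans centers loaded above discretize each low-resolution pixel into one of `kmeans.shape[0]` token indices, which is what gives the Transformer a finite vocabulary. A toy numpy sketch of that quantization step (hypothetical 4-color palette; the real centers come from `kmeans_centers.npy`):

```python
import numpy as np

# toy palette of 4 RGB centers in 0..255 (the real file has many more)
centers = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0], [255, 255, 255]],
                   dtype=np.float32)

pixels = np.array([[250, 5, 3], [10, 240, 12]], dtype=np.float32)  # (N, 3)

# nearest-center lookup: each pixel becomes a single integer token
d2 = ((pixels[:, None, :] - centers[None, :, :]) ** 2).sum(-1)  # (N, K)
tokens = d2.argmin(axis=1)  # near-red pixel -> 1, near-green pixel -> 2
```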

Training then begins; here 100 epochs are configured.

After obtaining the low-resolution reconstructed prior, we need to learn a deterministic mapping that rescales it back to the original resolution. Since convolutional networks excel at modeling texture, we introduce a CNN-based guided upsampling network that renders high-fidelity details of the reconstructed appearance prior under the guidance of the masked input. The upsampling generator is implemented as follows:

class Generator(nn.Cell):
    def __init__(self, residual_blocks=8):
        super(Generator, self).__init__()

        self.encoder = nn.SequentialCell(
            nn.Pad(((0, 0), (0, 0), (3, 3), (3, 3)), mode='REFLECT'),
            nn.Conv2d(in_channels=6, out_channels=64, kernel_size=7, pad_mode='pad', padding=0, has_bias=True),
            nn.ReLU(),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, pad_mode='pad', padding=1,
                      has_bias=True),
            nn.ReLU(),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, pad_mode='pad', padding=1,
                      has_bias=True),
            nn.ReLU()
        )

        blocks = []
        for _ in range(residual_blocks):
            block = ResnetBlock_remove_IN(256, 2)
            blocks.append(block)
        self.middle = nn.SequentialCell(*blocks)
        self.decoder = nn.SequentialCell(
            nn.Conv2dTranspose(in_channels=256, out_channels=128, kernel_size=4, stride=2, pad_mode='pad', padding=1,
                               has_bias=True),
            nn.ReLU(),

            nn.Conv2dTranspose(in_channels=128, out_channels=64, kernel_size=4, stride=2, pad_mode='pad', padding=1,
                               has_bias=True),
            nn.ReLU(),

            nn.Pad(((0, 0), (0, 0), (3, 3), (3, 3)), mode='REFLECT'),
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=7, pad_mode='pad', padding=0, has_bias=True),
        )

    def construct(self, images, edges, masks):
        images_masked = (images * P.Cast()((1 - masks), mindspore.float32)) + masks
        x = P.Concat(axis=1)((images_masked, edges))
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        x = (P.Tanh()(x) + 1) / 2

        return x


class ResnetBlock_remove_IN(nn.Cell):
    def __init__(self, dim, dilation=1):
        super(ResnetBlock_remove_IN, self).__init__()
        self.conv_block = nn.SequentialCell([
            nn.Pad(((0, 0), (0, 0), (dilation, dilation), (dilation, dilation)), mode='REFLECT'),
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, pad_mode='pad', dilation=dilation,
                      has_bias=True),
            nn.ReLU(),

            nn.Pad(((0, 0), (0, 0), (1, 1), (1, 1)), mode='SYMMETRIC'),
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, pad_mode='pad', dilation=1, has_bias=True)
        ])

    def construct(self, x):
        out = x + self.conv_block(x)
        return out

Next, define the dataset and networks, and add the loss definitions.

# Change the working directory to ICT/Guided_Upsample
if os.getcwd().endswith('Transformer'):
    os.chdir("../Guided_Upsample")
elif os.getcwd().endswith('ICT') or os.getcwd().endswith('ict'):
    os.chdir("./Guided_Upsample")

from Guided_Upsample.datasets.dataset import load_dataset
from Guided_Upsample.models.loss import GeneratorWithLoss, DiscriminatorWithLoss

train_dataset = load_dataset(image_flist=opts.input, edge_flist=opts.prior, mask_filst=opts.mask,
                             image_size=opts.image_size, prior_size=opts.prior_size, mask_type=opts.mask_type,
                             kmeans=opts.kmeans, use_degradation_2=opts.use_degradation_2,
                             prior_random_degree=opts.prior_random_degree,
                             augment=True, training=True)
train_dataset = train_dataset.batch(opts.batch_size)
step_size = train_dataset.get_dataset_size()

generator = Generator()
# Discriminator comes from the project's model definitions (not shown in this article)
discriminator = Discriminator(in_channels=3)
model_G = GeneratorWithLoss(generator, discriminator, opts.vgg_path, opts.inpaint_adv_loss_weight, opts.l1_loss_weight,
                            opts.content_loss_weight, opts.style_loss_weight)
model_D = DiscriminatorWithLoss(generator, discriminator)

class PSNR(nn.Cell):
    def __init__(self, max_val):
        super(PSNR, self).__init__()

        base10 = P.Log()(mindspore.Tensor(10.0, mindspore.float32))
        max_val = P.Cast()(mindspore.Tensor(max_val), mindspore.float32)

        self.base10 = mindspore.Parameter(base10, requires_grad=False)
        self.max_val = mindspore.Parameter(20 * P.Log()(max_val) / base10, requires_grad=False)

    def __call__(self, a, b):
        a = P.Cast()(a, mindspore.float32)
        b = P.Cast()(b, mindspore.float32)
        mse = P.ReduceMean()((a - b) ** 2)

        if mse == 0:
            return mindspore.Tensor(0)

        return self.max_val - 10 * P.Log()(mse) / self.base10


psnr_func = PSNR(255.0)

# Define the optimizer
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=opts.lr, beta1=opts.beta1, beta2=opts.beta2)
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=opts.lr * opts.D2G_lr, beta1=opts.beta1,
                      beta2=opts.beta2)
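The PSNR cell defined above implements 20·log10(MAX) − 10·log10(MSE); the same formula in plain numpy makes a quick sanity check (toy arrays with a known error):

```python
import numpy as np

def psnr(a, b, max_val=255.0):
    """Peak signal-to-noise ratio in dB for images in [0, max_val]."""
    mse = np.mean((a - b) ** 2)
    return 20 * np.log10(max_val) - 10 * np.log10(mse)

a = np.zeros((4, 4))
b = np.full((4, 4), 16.0)      # constant error of 16  ->  MSE = 256
value = psnr(a, b)             # 20*log10(255) - 10*log10(256) ~= 24.05 dB
```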

Then start training.

3.2 Inference

After training, inference proceeds in two stages. The first stage is generation with the Transformer model:

from mindspore.train import Model

# Define the model
block_size = opts.prior_size * opts.prior_size
transformer = GPT(vocab_size=kmeans.shape[0], n_embd=opts.n_embd, n_layer=opts.n_layer, n_head=opts.n_head,
                  block_size=block_size, use_gelu2=opts.GELU_2, embd_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0)
# Adjust the path to the corresponding checkpoint file; if you change it, prefer absolute paths to avoid unnecessary errors
opts.ckpt_path = './ckpts_ICT/ms_train/Transformer/ImageNet_best.ckpt'
if os.path.exists(opts.ckpt_path):
    print('Start loading the model parameters from %s' % (opts.ckpt_path))
    checkpoint = mindspore.load_checkpoint(opts.ckpt_path)
    mindspore.load_param_into_net(transformer, checkpoint)
    print('Finished load the model')
transformer.set_train(False)
model = Model(transformer)

The second stage is upsampling:

from PIL import Image


def stitch_images(inputs, *outputs, img_per_row=2):
    gap = 5
    columns = len(outputs) + 1

    width, height = inputs[0][:, :, 0].shape
    img = Image.new('RGB',
                    (width * img_per_row * columns + gap * (img_per_row - 1), height * int(len(inputs) / img_per_row)))
    images = [inputs, *outputs]

    for ix in range(len(inputs)):
        xoffset = int(ix % img_per_row) * width * columns + int(ix % img_per_row) * gap
        yoffset = int(ix / img_per_row) * height

        for cat in range(len(images)):
            im = images[cat][ix].asnumpy().astype(np.uint8).squeeze()
            im = Image.fromarray(im)
            img.paste(im, (xoffset + cat * width, yoffset))

    return img


from upsample_utils.util import postprocess, imsave

opts.mask_type = 3
opts.mode = 2
opts.input = './input/'
opts.kmeans = './kmeans_centers.npy'
# Make sure the second stage's input is the first stage's output
opts.prior = opts.save_url

generator = Generator()
generator.set_train(False)
# Adjust the path to the corresponding checkpoint file; if you change it, prefer absolute paths to avoid unnecessary errors
opts.ckpt_path = './ckpts_ICT/ms_train/Upsample/InpaintingModel_gen_best.ckpt'
if os.path.exists(opts.ckpt_path):
    print('Start loading the model parameters from %s' % (opts.ckpt_path))
    checkpoint = mindspore.load_checkpoint(opts.ckpt_path)
    mindspore.load_param_into_net(generator, checkpoint)
    print('Finished load the model')

psnr_func = PSNR(255.0)

test_dataset = load_dataset(image_flist=opts.input, edge_flist=opts.prior, mask_filst=opts.mask,
                            image_size=opts.image_size, prior_size=opts.prior_size, mask_type=opts.mask_type,
                            kmeans=opts.kmeans, condition_num=opts.condition_num,
                            augment=False, training=False)

index = 0
psnr = AverageMeter()
mae = AverageMeter()
test_batch_size = 1
test_dataset = test_dataset.batch(test_batch_size)
for sample in test_dataset.create_dict_iterator():
    name = sample['name'].asnumpy()[0]
    images = sample['images']
    edges = sample['edges']
    masks = sample['masks']
    inputs = (images * (1 - masks)) + masks
    index += test_batch_size
    outputs = generator(images, edges, masks)
    outputs_merged = (outputs * masks) + (images * (1 - masks))
    psnr.update(psnr_func(postprocess(images), postprocess(outputs_merged)), 1)
    mae.update((P.ReduceSum()(P.Abs()(images - outputs_merged)) / P.ReduceSum()(images)), 1)
    result_merge = stitch_images(
        postprocess(images),
        postprocess(inputs),
        postprocess(outputs_merged),
        img_per_row=1
    )
    result_merge.show()
    output = postprocess(outputs_merged)[0]
    path = os.path.join(opts.save_url, name[:-4] + "_%d" % (index % opts.condition_num) + '.png')
    imsave(output, path)
print('PSNR: {}, MAE: {}'.format(psnr.avg, mae.avg))
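The key line in the loop above is the compositing `outputs_merged = (outputs * masks) + (images * (1 - masks))`: generated pixels are used only inside the hole, while the untouched originals are kept everywhere else. The same operation in numpy (toy single-channel data):

```python
import numpy as np

images = np.full((4, 4), 7.0)      # ground-truth pixels
outputs = np.full((4, 4), 3.0)     # network prediction
masks = np.zeros((4, 4))
masks[1:3, 1:3] = 1.0              # 1 marks the missing region

merged = outputs * masks + images * (1 - masks)
# merged equals the original outside the hole and the prediction inside it
```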

The final inference results are shown in the figure below. A mask perturbation is applied to the image, and the original, the input, and the result are processed and output side by side: the first image is the undamaged original, the second is the masked image that serves as the network input, and the third is the network output, i.e., the completed image.
