Pix2Pix is easy to understand if you already know DCGAN.

Prepare and visualize the data.

# Download the pix2pix dataset archive and unpack it under ./dataset.
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar'
download(url, './dataset', kind='tar', replace=True)

# Visualize the first ten input images of one record.
dataset = ds.MindDataset(
    './dataset/dataset_pix2pix/train.mindrecord',
    columns_list=['input_images', 'target_images'],
    shuffle=True,
)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
preview = data_iter['input_images'][:10]
plt.figure(figsize=(10, 3), dpi=140)
for pos, img in enumerate(preview, start=1):
    plt.subplot(3, 10, pos)
    plt.axis('off')
    # CHW -> HWC for imshow; (x + 1) / 2 rescales to [0, 1]
    # (assumes the stored images are normalized to [-1, 1] — confirm).
    plt.imshow((img.transpose(1, 2, 0) + 1) / 2)
plt.show()

Next, we build the basic components of pix2pix. Some of them replace their DCGAN counterparts (for example, the generator is a U-Net), so focus on the underlying principles.

class UnetSkipConnectionBlock(nn.Cell):
    """One level of the U-Net generator, with an optional skip connection.

    The block downsamples its input with a 4x4 stride-2 convolution, passes the
    result through ``submodule`` (the next inner level), upsamples back with a
    4x4 stride-2 transposed convolution and, for every level except the
    outermost, concatenates the block input onto the output along axis 1.

    Args:
        outer_nc: Channels produced by the up (transposed) convolution.
        inner_nc: Channels produced by the down convolution.
        inplanes: Channels of the block input; defaults to ``outer_nc``.
        dropout: If True, append ``nn.Dropout(0.5)`` (intermediate levels only).
        submodule: Inner block; required unless ``innermost`` is True.
        outermost: Outermost level: no normalization, no skip concat, ends in Tanh.
        innermost: Bottleneck level: has no submodule.
        alpha: Negative slope of the down-path LeakyReLU.
        norm_mode: 'instance' selects InstanceNorm2d and enables conv biases;
            any other value selects BatchNorm2d.

    Raises:
        ValueError: if ``submodule`` is None for a block that needs one.
    """

    def __init__(self, outer_nc, inner_nc, inplanes=None, dropout=False,
                 submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UnetSkipConnectionBlock, self).__init__()
        use_instance = norm_mode == 'instance'
        norm_layer = nn.InstanceNorm2d if use_instance else nn.BatchNorm2d
        use_bias = use_instance
        if inplanes is None:
            inplanes = outer_nc
        if submodule is None and (outermost or not innermost):
            raise ValueError('submodule is required unless innermost=True')

        down_conv = nn.Conv2d(inplanes, inner_nc, kernel_size=4, stride=2, padding=1,
                              has_bias=use_bias, pad_mode='pad')
        down_relu = nn.LeakyReLU(alpha)
        up_relu = nn.ReLU()
        if outermost:
            # A transposed conv upsamples back to the input size (a plain stride-2
            # Conv2d would downsample again). No normalization follows, so a bias
            # is not redundant. Tanh bounds the generated image to [-1, 1].
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc, kernel_size=4, stride=2,
                                         padding=1, has_bias=True, pad_mode='pad')
            model = [down_conv, submodule, up_relu, up_conv, nn.Tanh()]
        elif innermost:
            # Bottleneck: there is no submodule (and hence no concat below it),
            # so the up path sees inner_nc channels rather than 2 * inner_nc.
            up_conv = nn.Conv2dTranspose(inner_nc, outer_nc, kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            model = [down_relu, down_conv, up_relu, up_conv, norm_layer(outer_nc)]
        else:
            # The submodule concatenates its input onto its output, so the up path
            # receives 2 * inner_nc channels.
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc, kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            model = [down_relu, down_conv, norm_layer(inner_nc), submodule,
                     up_relu, up_conv, norm_layer(outer_nc)]
            if dropout:
                model.append(nn.Dropout(0.5))
        self.model = nn.SequentialCell(model)
        # The outermost level produces the final image, so it does not concatenate its input.
        self.skip_connections = not outermost

    def construct(self, x):
        """Run this level and, unless outermost, concatenate ``x`` onto the result along axis 1."""
        out = self.model(x)
        if self.skip_connections:
            out = ops.concat((out, x), axis=1)
        return out
class UnetGenerator(nn.Cell):
    """U-Net generator built from ``n_layers`` nested UnetSkipConnectionBlocks.

    Args:
        in_planes: Channels of the input image.
        out_planes: Channels of the generated image.
        ngf: Base channel width; the deepest levels use ``8 * ngf`` channels.
        n_layers: Number of U-Net levels (each halves the spatial size); must be >= 5.
        norm_mode: Forwarded to every block; only 'instance' changes the
            normalization, any other value (including the default 'bn') means batch norm.
        dropout: Enables dropout in the ``n_layers - 5`` intermediate ngf*8 levels.

    Raises:
        ValueError: if ``n_layers`` is smaller than 5.
    """

    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
        super(UnetGenerator, self).__init__()
        if n_layers < 5:
            # The fixed part of the architecture (innermost + 3 tapering levels +
            # outermost) already has 5 levels.
            raise ValueError('n_layers must be >= 5, got %d' % n_layers)
        # Build from the bottleneck outward; each new block wraps the previous one.
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=None, innermost=True,
                                             norm_mode=norm_mode)
        for _ in range(n_layers - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=unet_block,
                                                 norm_mode=norm_mode, dropout=dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=unet_block, norm_mode=norm_mode)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=unet_block, norm_mode=norm_mode)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=unet_block, norm_mode=norm_mode)
        # Exactly one outermost level: it maps in_planes -> out_planes.
        self.model = UnetSkipConnectionBlock(out_planes, ngf, inplanes=in_planes, submodule=unet_block,
                                             outermost=True, norm_mode=norm_mode)

    def construct(self, x):
        """Translate ``x`` through the full U-Net."""
        return self.model(x)

# Discriminator building blocks.
class ConvNormRelu(nn.Cell):
    """Conv2d -> normalization -> optional (Leaky)ReLU.

    Args:
        in_planes: Input channels.
        out_planes: Output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        alpha: LeakyReLU negative slope; ``alpha <= 0`` selects a plain ReLU.
        norm_mode: 'instance' selects a BatchNorm2d without affine parameters
            plus a conv bias; any other value selects a standard BatchNorm2d.
        pad_mode: 'CONSTANT' pads inside the convolution; any other value is
            passed as ``mode`` to an explicit ``nn.Pad`` layer.
        use_relu: Whether to append the activation.
        padding: Spatial padding per side; ``None`` means ``(kernel_size - 1) // 2``.
    """

    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=4,
                 stride=2,
                 alpha=0.2,
                 norm_mode='batch',
                 pad_mode='CONSTANT',
                 use_relu=True,
                 padding=None):
        super(ConvNormRelu, self).__init__()
        if norm_mode == 'instance':
            # Non-affine BatchNorm2d; presumably a stand-in for instance norm — confirm.
            norm = nn.BatchNorm2d(out_planes, affine=False)
        else:
            norm = nn.BatchNorm2d(out_planes)
        has_bias = norm_mode == 'instance'
        # Only substitute the default for a missing value; an explicit 0 must be kept.
        if padding is None:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
                             has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            # Pad only the two spatial axes with the requested mode; the conv then uses padding 0.
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            layers.append(nn.LeakyReLU(alpha) if alpha > 0 else nn.ReLU())
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        """Apply the conv/norm/activation stack to ``x``."""
        return self.features(x)
class Discriminator(nn.Cell):
    """Conditional discriminator scoring (input, output) image pairs.

    The two images are concatenated along axis 1 and passed through a stack of
    4x4 convolutions that ends in a one-channel map of real/fake logits.

    Args:
        in_planes: Channels of the concatenated pair (e.g. 3 + 3).
        ndf: Base channel width; deeper layers use up to ``8 * ndf`` channels.
        n_layers: Number of stride-2 downsampling convolutions.
        alpha: LeakyReLU negative slope.
        norm_mode: Normalization mode forwarded to ConvNormRelu.
        norm_mmode: Deprecated misspelling of ``norm_mode``; if given it overrides it.
    """

    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch', norm_mmode=None):
        super(Discriminator, self).__init__()
        if norm_mmode is not None:
            # Accept the old misspelled keyword for backward compatibility.
            norm_mode = norm_mmode
        kernel_size = 4
        layers = [
            nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
            nn.LeakyReLU(alpha),
        ]
        out_ch = ndf
        for i in range(1, n_layers):
            in_ch = out_ch
            out_ch = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(in_ch, out_ch, kernel_size, 2, alpha, norm_mode, padding=1))
        # One more normalized conv at stride 1, then a projection to a single logit channel.
        in_ch = out_ch
        out_ch = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(in_ch, out_ch, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(out_ch, 1, kernel_size, 1, pad_mode='pad', padding=1))
        self.features = nn.SequentialCell(layers)

    def construct(self, x, y):
        """Score the pair ``(x, y)`` after concatenating it along axis 1."""
        x_y = ops.concat((x, y), axis=1)
        return self.features(x_y)

        

Note that the U-Net has skip connections, so pixel information from different levels is preserved.

At each level, the block's input feature map is concatenated with its output feature map along the channel axis.

Next, we instantiate the pix2pix generator and the discriminator.

# Network hyper-parameters.
g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
# The discriminator scores an (input, target-or-generated) pair concatenated on
# the channel axis, hence 3 + 3 input channels.
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'


def _init_weights(net, method, gain):
    """Re-initialize ``net``'s conv and BatchNorm2d parameters in place.

    Conv2d / Conv2dTranspose weights are drawn from ``init.Normal(gain)``
    ('normal') or ``init.XavierUniform(gain)`` ('xavier'), or set to the
    constant 0.001 ('constant'). BatchNorm2d gamma is set to ones and beta to zeros.

    Raises:
        NotImplementedError: if ``method`` is unknown and ``net`` contains a conv layer.
    """
    for _, cell in net.cells_and_names():
        if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
            if method == 'normal':
                weight_init = init.Normal(gain)
            elif method == 'xavier':
                weight_init = init.XavierUniform(gain)
            elif method == 'constant':
                weight_init = 0.001
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % method)
            cell.weight.set_data(init.initializer(weight_init, cell.weight.shape))
        elif isinstance(cell, nn.BatchNorm2d):
            cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
            cell.beta.set_data(init.initializer('zeros', cell.beta.shape))


net_generator = UnetGenerator(in_planes=g_in_planes, out_planes=g_out_planes, ngf=g_ngf, n_layers=g_layers)
_init_weights(net_generator, init_type, init_gain)

net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf, alpha=alpha, n_layers=d_layers)
_init_weights(net_discriminator, init_type, init_gain)

class Pix2Pix(nn.Cell):
    """Wrapper holding both pix2pix networks.

    The discriminator is registered as a sub-cell, but only the generator
    takes part in the forward pass.
    """

    def __init__(self, discriminator, generator):
        super().__init__(auto_prefix=True)
        self.net_discriminator = discriminator
        self.net_generator = generator

    def construct(self, reala):
        """Return the generator's translation of ``reala``."""
        return self.net_generator(reala)

Nothing more to explain here.

# Training configuration.
epoch_num = 3               # epochs run by the training loop
ckpt_dir = 'results/ckpt'
dataset_size = 400          # schedule entries per epoch used by get_lr()
val_pic_size = 256
lr = 0.0002                 # base learning rate
n_epochs = 100              # schedule epochs held at the base rate
n_epochs_decay = 100        # schedule epochs of linear decay


def get_lr():
    """Build the per-step learning-rate schedule as a float32 Tensor.

    Each epoch contributes ``dataset_size`` entries. The first ``n_epochs``
    epochs use ``lr``; decay epoch k then uses
    ``lr * (n_epochs_decay - k) / n_epochs_decay``. The schedule is finally
    padded with the last decay value (0 if there are no decay epochs) for any
    of the ``epoch_num`` epochs not covered yet.
    """
    schedule = [lr] * dataset_size * n_epochs
    decay_values = [lr * (n_epochs_decay - k) / n_epochs_decay for k in range(n_epochs_decay)]
    for value in decay_values:
        schedule += [value] * dataset_size
    tail_value = decay_values[-1] if decay_values else 0
    remaining_epochs = epoch_num - n_epochs_decay - n_epochs
    schedule += [tail_value] * dataset_size * remaining_epochs
    return Tensor(np.array(schedule).astype(np.float32))
# Training data: paired input/target images read from the MindRecord file.
dataset = ds.MindDataset(
    './dataset/dataset_pix2pix/train.mindrecord',
    columns_list=['input_images', 'target_images'],
    shuffle=True,
    num_parallel_workers=1,
)
steps_per_epoch = dataset.get_dataset_size()
# Adversarial loss on raw discriminator logits, plus L1 reconstruction loss.
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()
def forward_fn(reala, realb):
    """Discriminator loss for one batch.

    Generated pairs ``(reala, G(reala))`` are pushed toward label 0 and real
    pairs ``(reala, realb)`` toward label 1 with BCE-with-logits; the sum of
    the two terms is scaled by 0.5.
    """
    lambda_dis = 0.5
    generated = net_generator(reala)
    # Fake pair is scored before the real pair (call order kept as-is).
    fake_logits = net_discriminator(reala, generated)
    real_logits = net_discriminator(reala, realb)
    real_term = loss_f(real_logits, ops.ones_like(real_logits))
    fake_term = loss_f(fake_logits, ops.zeros_like(fake_logits))
    return (real_term + fake_term) * lambda_dis
def forward_gan(reala, realb):
    """Generator loss for one batch.

    Combines an adversarial term (the discriminator should label the
    *generated* pair ``(reala, G(reala))`` as real) weighted by 0.5 with an
    L1 reconstruction term between ``G(reala)`` and ``realb`` weighted by 100.
    """
    lambda_gan = 0.5
    lambda_l1 = 100
    fakeb = net_generator(reala)
    # Score the generated pair: if the real pair were scored here, the adversarial
    # term would not depend on the generator and would contribute no gradient.
    pred_fake = net_discriminator(reala, fakeb)
    loss_adv = loss_f(pred_fake, ops.ones_like(pred_fake))
    loss_l1 = l1_loss(fakeb, realb)
    return loss_adv * lambda_gan + loss_l1 * lambda_l1
# Each gradient function differentiates its loss only w.r.t. one network's parameters.
# forward_fn is the discriminator loss defined above (there is no forward_dis).
grad_d = value_and_grad(forward_fn, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forward_gan, None, net_generator.trainable_params())
# Optimizers applied by train_step; both follow the per-step schedule from get_lr().
d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(), beta1=0.5, beta2=0.999)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(), beta1=0.5, beta2=0.999)

def train_step(reala, realb):
    """Run one optimization step for both networks.

    Returns:
        ``(discriminator_loss, generator_loss)`` for this batch.
    """
    dis_value, dis_grads = grad_d(reala, realb)
    gan_value, gan_grads = grad_g(reala, realb)
    # Both gradients are taken before either network's weights are updated.
    for optimizer, grads in ((d_opt, dis_grads), (g_opt, gan_grads)):
        optimizer(grads)
    return dis_value, gan_value
os.makedirs(ckpt_dir, exist_ok=True)
# MindSpore Cells are not in training mode until set_train() is called; switch
# both networks so that BatchNorm and Dropout use their training behavior.
net_generator.set_train()
net_discriminator.set_train()
g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)
for epoch in range(epoch_num):
    for i, data in enumerate(data_loader):
        start_time = datetime.datetime.now()
        input_image = Tensor(data['input_images'])
        target_image = Tensor(data['target_images'])
        dis_loss, gen_loss = train_step(input_image, target_image)
        end_time = datetime.datetime.now()
        # total_seconds() covers the whole duration; .microseconds is only the
        # sub-second component and silently drops whole seconds.
        step_ms = (end_time - start_time).total_seconds() * 1000
        if i % 2 == 0:
            print("ms per step: {:.2f} epoch:{}/{} step:{}/{} Dloss: {:.4f} Gloss: {:.4f} ".format(
                step_ms, epoch + 1, epoch_num, i, steps_per_epoch, float(dis_loss), float(gen_loss)))
        d_losses.append(dis_loss.asnumpy())
        g_losses.append(gen_loss.asnumpy())
    if (epoch + 1) == epoch_num:
        # NOTE: plain concatenation gives 'results/ckptGenerator.ckpt' (beside, not
        # inside, ckpt_dir). The loading code uses the same expression, so
        # change both together if the location is fixed.
        mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")
# Inference: reload the trained generator, then show input images (top row)
# above the generated images (bottom row). Samples come from the training file.
param_g = load_checkpoint(ckpt_dir + 'Generator.ckpt')
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset('./dataset/dataset_pix2pix/train.mindrecord',
                         columns_list=['input_images', 'target_images'], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter['input_images'])
plt.figure(figsize=(10, 3), dpi=140)

# Show at most 10 pairs; a record holding fewer images no longer raises IndexError.
n_show = min(10, len(predict_show))
for i in range(n_show):
    plt.subplot(2, 10, i + 1)
    # CHW -> HWC; (x + 1) / 2 rescales [-1, 1] to [0, 1] for display.
    plt.imshow((data_iter['input_images'][i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis('off')
    plt.subplot(2, 10, i + 11)
    plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis('off')
# Spacing applies to the whole figure, so one call after the loop is enough.
plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐