昇思25天学习打卡营第15天|Pix2Pix
【代码】昇思25天学习打卡营第15天|Pix2Pix。
·

Pix2Pix may be easier to understand if you are already familiar with DCGAN.
Data.
# Fetch the Pix2Pix dataset archive and unpack it under ./dataset.
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar'
download(url, './dataset', kind='tar', replace=True)

# Visualize the first ten input images of one record.
dataset = ds.MindDataset(
    './dataset/dataset_pix2pix/train.mindrecord',
    columns_list=['input_images', 'target_images'],
    shuffle=True,
)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
plt.figure(figsize=(10, 3), dpi=140)
preview = data_iter['input_images'][:10]
for i, image in enumerate(preview, 1):
    plt.subplot(3, 10, i)
    plt.axis('off')
    # Channels-last for imshow; presumably the pixels are stored in [-1, 1],
    # hence the (x + 1) / 2 rescale.
    channels_last = image.transpose(1, 2, 0)
    plt.imshow((channels_last + 1) / 2)
plt.show()

Next we define the basic elements of Pix2Pix. It has now been replaced by DCGAN, so the principle is what we should focus on.
class UnetSkipConnectionBlock(nn.Cell):
    """One nesting level of the U-Net generator.

    The block downsamples its input with a stride-2 convolution, runs the
    nested ``submodule`` (if any), and upsamples back with a stride-2
    transposed convolution. Unless the block is the outermost one, the
    block's input is concatenated to its output along axis 1 (the skip
    connection).

    Args:
        outer_nc: Number of channels produced by the up-sampling path.
        inner_nc: Number of channels produced by the down-sampling path.
        inplanes: Number of input channels; defaults to ``outer_nc``.
        dropout: Append a dropout layer (intermediate blocks only).
        submodule: The nested block placed between the down and up paths.
        outermost: True for the top-level block (no skip connection).
        innermost: True for the bottleneck block (no submodule).
        alpha: Negative slope of the LeakyReLU on the down path.
        norm_mode: ``'instance'`` selects instance normalisation; any other
            value selects batch normalisation.
    """

    def __init__(self, outer_nc, inner_nc, inplanes=None, dropout=False,
                 submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UnetSkipConnectionBlock, self).__init__()
        if norm_mode == 'instance':
            down_norm = nn.InstanceNorm2d(inner_nc)
            up_norm = nn.InstanceNorm2d(outer_nc)
            use_bias = True
        else:
            down_norm = nn.BatchNorm2d(inner_nc)
            up_norm = nn.BatchNorm2d(outer_nc)
            use_bias = False
        if inplanes is None:
            inplanes = outer_nc

        down_conv = nn.Conv2d(inplanes, inner_nc, kernel_size=4, stride=2, padding=1,
                              has_bias=use_bias, pad_mode='pad')
        down_relu = nn.LeakyReLU(alpha)
        up_relu = nn.ReLU()

        if outermost:
            # The up path must undo the stride-2 downsampling, so it needs a
            # transposed convolution (a plain Conv2d would halve the size again).
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1,
                                         has_bias=use_bias, pad_mode='pad')
            down = [down_conv]
            # Tanh bounds the generated image to [-1, 1].
            up = [up_relu, up_conv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # Bottleneck: there is no submodule, so the up path only sees
            # inner_nc channels (nothing was concatenated below it).
            up_conv = nn.Conv2dTranspose(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1,
                                         has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv]
            up = [up_relu, up_conv, up_norm]
            model = down + up
        else:
            # The submodule returns its output concatenated with its input,
            # hence inner_nc * 2 input channels on the up path.
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1,
                                         has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv, down_norm]
            up = [up_relu, up_conv, up_norm]
            model = down + [submodule] + up
            if dropout:
                model.append(nn.Dropout(0.5))

        self.model = nn.SequentialCell(model)
        self.skip_connections = not outermost

    def construct(self, x):
        """Run the block and, unless outermost, concatenate the input on axis 1."""
        out = self.model(x)
        if self.skip_connections:
            out = ops.concat((out, x), axis=1)
        return out
class UnetGenerator(nn.Cell):
    """U-Net generator assembled from nested ``UnetSkipConnectionBlock`` cells.

    The network is built from the inside out: the innermost (bottleneck)
    block first, then ``n_layers - 5`` blocks at ``ngf * 8`` channels, then
    three blocks that reduce the width down to ``ngf``, and finally the
    outermost block that maps ``in_planes`` to ``out_planes``.

    Args:
        in_planes: Number of channels of the input image.
        out_planes: Number of channels of the generated image.
        ngf: Base number of generator filters.
        n_layers: Total number of nested blocks.
        norm_mode: Forwarded to every block (``'instance'`` or batch otherwise).
        dropout: Enable dropout in the ``ngf * 8`` intermediate blocks.
    """

    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
        super(UnetGenerator, self).__init__()
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=None,
                                             innermost=True, norm_mode=norm_mode)
        for _ in range(n_layers - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=unet_block,
                                                 norm_mode=norm_mode, dropout=dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=unet_block, norm_mode=norm_mode)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=unet_block, norm_mode=norm_mode)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=unet_block, norm_mode=norm_mode)
        # Exactly one outermost block; the block's keyword is `inplanes`.
        self.model = UnetSkipConnectionBlock(out_planes, ngf, inplanes=in_planes, submodule=unet_block,
                                             outermost=True, norm_mode=norm_mode)

    def construct(self, x):
        """Return the generator output for input ``x``."""
        return self.model(x)
#discriminator
class ConvNormRelu(nn.Cell):
    """Convolution followed by normalisation and an optional (Leaky)ReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        alpha: LeakyReLU negative slope; a value <= 0 selects plain ReLU.
        norm_mode: ``'instance'`` uses a non-affine norm layer and enables the
            convolution bias; any other value uses a default BatchNorm2d.
        pad_mode: ``'CONSTANT'`` pads inside the convolution; any other value
            is handed to a separate ``nn.Pad`` layer placed before it.
        use_relu: Append the activation when True.
        padding: Padding size; defaults to ``(kernel_size - 1) // 2``.
    """

    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=4,
                 stride=2,
                 alpha=0.2,
                 norm_mode='batch',
                 pad_mode='CONSTANT',
                 use_relu=True,
                 padding=None):
        super(ConvNormRelu, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        # Only fall back to the default when no padding was given, so an
        # explicit padding=0 is respected.
        if padding is None:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
                             has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            # Pad only the last two axes, then convolve without extra padding.
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        """Apply the layer stack to ``x``."""
        output = self.features(x)
        return output
class Discriminator(nn.Cell):
    """Convolutional discriminator that scores an (input, candidate) image pair.

    The two images are concatenated along axis 1 and passed through a stack
    of stride-2 convolutions, one stride-1 ``ConvNormRelu`` and a final
    single-channel convolution that produces the raw scores.

    Args:
        in_planes: Channel count of the concatenated pair.
        ndf: Base number of discriminator filters.
        n_layers: Number of stride-2 stages (including the first plain conv).
        alpha: LeakyReLU negative slope.
        norm_mode: Normalisation mode forwarded to ``ConvNormRelu``.
    """

    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [
            nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
            nn.LeakyReLU(alpha),
        ]
        nf_mult = ndf
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            # Double the width at every stage, capped at 8 * ndf.
            nf_mult = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
        self.features = nn.SequentialCell(layers)

    def construct(self, x, y):
        """Concatenate ``x`` and ``y`` on axis 1 and return the score map."""
        x_y = ops.concat((x, y), axis=1)
        output = self.features(x_y)
        return output
It is worth noting that U-Net uses skip connections, so pixel information from different levels is preserved:
the feature maps are concatenated together.

Next we instantiate the Pix2Pix generator and the discriminator.
# Generator hyper-parameters.
g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
# Discriminator hyper-parameters (d_in_planes covers the concatenated image pair).
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
# Weight initialisation settings shared by both networks.
init_gain = 0.02
init_type = 'normal'


def _init_weights(net, init_type, init_gain):
    """Initialise conv weights and BatchNorm gamma/beta of ``net`` in place.

    Args:
        net: The cell whose sub-cells are initialised.
        init_type: ``'normal'``, ``'xavier'`` or ``'constant'``.
        init_gain: Argument passed to the Normal / XavierUniform initializer.

    Raises:
        NotImplementedError: If ``init_type`` is not one of the supported names.
    """
    for _, cell in net.cells_and_names():
        if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
            if init_type == 'normal':
                weight_init = init.Normal(init_gain)
            elif init_type == 'xavier':
                weight_init = init.XavierUniform(init_gain)
            elif init_type == 'constant':
                weight_init = 0.001
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            cell.weight.set_data(init.initializer(weight_init, cell.weight.shape))
        elif isinstance(cell, nn.BatchNorm2d):
            cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
            cell.beta.set_data(init.initializer('zeros', cell.beta.shape))


net_generator = UnetGenerator(in_planes=g_in_planes, out_planes=g_out_planes, ngf=g_ngf, n_layers=g_layers)
_init_weights(net_generator, init_type, init_gain)

net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf, alpha=alpha, n_layers=d_layers)
_init_weights(net_discriminator, init_type, init_gain)
class Pix2Pix(nn.Cell):
    """Container cell holding the discriminator and the generator.

    Calling the cell runs only the generator on the input image.
    """

    def __init__(self, discriminator, generator):
        super().__init__(auto_prefix=True)
        # Assignment order is kept as-is: discriminator first, then generator.
        self.net_discriminator = discriminator
        self.net_generator = generator

    def construct(self, reala):
        """Return the image generated from ``reala``."""
        return self.net_generator(reala)
There is nothing more to say here.
# Training configuration.
epoch_num = 3            # epochs actually run by the training loop
ckpt_dir = 'results/ckpt'
dataset_size = 400       # steps per epoch assumed by the schedule
val_pic_size = 256
lr = 0.0002              # initial learning rate
n_epochs = 100           # epochs held at the initial learning rate
n_epochs_decay = 100     # epochs over which the rate decays linearly


def get_lr():
    """Build the per-step learning-rate schedule as a float32 Tensor.

    The schedule holds ``lr`` for ``n_epochs`` epochs, then lowers it linearly
    over ``n_epochs_decay`` epochs, one value per epoch repeated
    ``dataset_size`` times. If ``epoch_num`` exceeds both phases combined, the
    last decayed value is repeated for the remaining epochs.
    """
    schedule = [lr] * dataset_size * n_epochs
    decayed = [lr * (n_epochs_decay - step) / n_epochs_decay for step in range(n_epochs_decay)]
    for value in decayed:
        schedule.extend([value] * dataset_size)
    # With no decay epochs the tail value is 0.
    tail_value = decayed[-1] if decayed else 0
    extra_epochs = epoch_num - n_epochs_decay - n_epochs
    # A non-positive multiplier yields an empty list, i.e. no tail is added.
    schedule.extend([tail_value] * dataset_size * extra_epochs)
    return Tensor(np.array(schedule).astype(np.float32))
# Training data: paired input/target images read from the MindRecord file.
dataset = ds.MindDataset(
    './dataset/dataset_pix2pix/train.mindrecord',
    columns_list=['input_images', 'target_images'],
    shuffle=True,
    num_parallel_workers=1,
)
steps_per_epoch = dataset.get_dataset_size()

# Losses: BCE-with-logits for the adversarial terms, L1 for reconstruction.
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()
def forward_fn(reala, realb):
    """Return the discriminator loss for one (input, target) batch.

    The discriminator is scored on the generated pair first and on the real
    pair second; the real pair is pushed towards ones, the generated pair
    towards zeros, and the sum is scaled by 0.5.
    """
    generated = net_generator(reala)
    fake_logits = net_discriminator(reala, generated)
    real_logits = net_discriminator(reala, realb)
    real_term = loss_f(real_logits, ops.ones_like(real_logits))
    fake_term = loss_f(fake_logits, ops.zeros_like(fake_logits))
    return (real_term + fake_term) * 0.5
def forward_gan(reala, realb):
    """Return the generator loss for one (input, target) batch.

    The loss is the weighted sum of an adversarial term (the discriminator's
    score on the generated pair, pushed towards ones) and an L1 term between
    the generated image and the target.

    Args:
        reala: Batch of input images.
        realb: Batch of target images.
    """
    lambda_gan = 0.5
    lambda_l1 = 100
    fakeb = net_generator(reala)
    # Score the GENERATED image: scoring realb would make the adversarial
    # term independent of the generator and give it no gradient.
    pred0 = net_discriminator(reala, fakeb)
    loss_1 = loss_f(pred0, ops.ones_like(pred0))
    loss_2 = l1_loss(fakeb, realb)
    loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
    return loss_gan
# Adam optimizers for both networks, driven by the per-step schedule from
# get_lr(). train_step() calls d_opt / g_opt, so they must exist before training.
d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)

# Each gradient function returns (loss, grads w.r.t. that network's weights).
# The discriminator loss function in this file is named forward_fn.
grad_d = value_and_grad(forward_fn, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forward_gan, None, net_generator.trainable_params())
def train_step(reala, realb):
    """Run one optimisation step and return (discriminator loss, generator loss)."""
    # Both losses and gradients are computed before either optimizer updates
    # its network, so the two evaluations see the same weights.
    d_loss, d_gradients = grad_d(reala, realb)
    g_loss, g_gradients = grad_g(reala, realb)
    d_opt(d_gradients)
    g_opt(g_gradients)
    return d_loss, g_loss
# Make sure the checkpoint directory exists (no error if it already does).
os.makedirs(ckpt_dir, exist_ok=True)

g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)
for epoch in range(epoch_num):
    for i, data in enumerate(data_loader):
        start_time = datetime.datetime.now()
        input_image = Tensor(data['input_images'])
        target_image = Tensor(data['target_images'])
        dis_loss, gen_loss = train_step(input_image, target_image)
        end_time = datetime.datetime.now()
        # total_seconds() covers the whole duration; .microseconds alone
        # drops the seconds part for any step slower than one second.
        step_ms = (end_time - start_time).total_seconds() * 1000
        if i % 2 == 0:
            print("ms per step: {:.2f} epoch:{}/{} step:{}/{} Dloss: {:.4f} Gloss: {:.4f} ".format(
                step_ms, epoch + 1, epoch_num, i, steps_per_epoch, float(dis_loss), float(gen_loss)))
        d_losses.append(dis_loss.asnumpy())
        g_losses.append(gen_loss.asnumpy())
    if (epoch + 1) == epoch_num:
        # NOTE: plain concatenation, no path separator; the file is written
        # next to ckpt_dir as "<ckpt_dir>Generator.ckpt". Kept unchanged so it
        # matches the path used when the checkpoint is loaded.
        mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")
# Inference: restore the trained generator and compare inputs with outputs.
param_g = load_checkpoint(ckpt_dir + 'Generator.ckpt')
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset(
    './dataset/dataset_pix2pix/train.mindrecord',
    columns_list=['input_images', 'target_images'],
    shuffle=True,
)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter['input_images'])

plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):
    # Top row (slots 1-10): input images; bottom row (slots 11-20): generated images.
    for slot, batch in ((i + 1, data_iter['input_images']), (i + 11, predict_show)):
        plt.subplot(2, 10, slot)
        plt.imshow((batch[i].asnumpy().transpose(1, 2, 0) + 1) / 2)
        plt.axis('off')
        plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。
更多推荐


所有评论(0)