
GAN in Practice (Basic GAN and DCGAN)

1. Basic GAN

1.1 Parameters

(1) Input: training images are resized to 64×64 and normalized to [-1, 1] (a quick check of the normalization is sketched below)
(2) Output: the generator produces 64×64 images
(3) Dataset: the face images loaded from ./data/yellow/*.png in the code below
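The generator ends with Tanh, so its output lies in [-1, 1]; Normalize((0.5,), (0.5,)) puts the real training images on the same scale. A minimal sanity check of that mapping (a sketch, not part of the training script; the sample file name is hypothetical):

from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),                 # scales pixels to [0, 1], layout (C, H, W)
    transforms.Normalize((0.5,), (0.5,))   # (x - 0.5) / 0.5  ->  roughly [-1, 1]
])

img = transform(Image.open('./data/yellow/sample.png'))   # hypothetical sample file
print(img.shape, img.min().item(), img.max().item())      # values lie in [-1, 1]

# The plotting helpers in the script below undo this mapping with (x + 1) / 2 before imshow.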

1.2 Implementation

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import time
from torch.utils import data
from PIL import Image
import glob

# The generator outputs values in [-1, 1], so real images are normalized to the same range
transform = transforms.Compose([
    # transforms.Grayscale(num_output_channels=1),
    transforms.Resize(64),
    transforms.ToTensor(),                # scales to [0, 1] and rearranges to (channels, height, width)
    transforms.Normalize((0.5,), (0.5,))
])

class FaceDataset(data.Dataset):
    def __init__(self, images_path):
        self.images_path = images_path

    def __getitem__(self, index):
        image_path = self.images_path[index]
        pil_img = Image.open(image_path)
        pil_img = transform(pil_img)
        return pil_img

    def __len__(self):
        return len(self.images_path)

images_path = glob.glob('./data/yellow/*.png')
BATCH_SIZE = 16
dataset = FaceDataset(images_path)
dataLoader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Generator definition
# Input: a noise vector of length 100 (standard normal random numbers)
# Output: a (3, 64, 64) image
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 64*64*3),
            nn.Tanh()
        )

    def forward(self, x):
        img = self.main(x)
        img = img.view(-1, 3, 64, 64)
        return img

# Discriminator definition
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(64*64*3, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(-1, 64*64*3)
        x = self.main(x)
        return x

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optimizer = torch.optim.Adam(dis.parameters(), lr=0.00001)
g_optimizer = torch.optim.Adam(gen.parameters(), lr=0.0001)

# Loss function
loss_fn = torch.nn.BCELoss()

# Plotting helper
def gen_img_plot(model, test_input):
    prediction = np.squeeze(model(test_input).permute(0, 2, 3, 1).detach().cpu().numpy())
    fig = plt.figure(figsize=(20, 160))
    for i in range(8):
        plt.subplot(1, 8, i+1)
        plt.imshow((prediction[i] + 1) / 2)
        plt.axis('off')
    plt.show()

# Per-step plotting helper
def gen_img_plot_step(img_data, step):
    predictions = img_data.permute(0, 2, 3, 1).detach().cpu().numpy()
    print("step:", step)
    fig = plt.figure(figsize=(3, 3))
    for i in range(1):
        plt.imshow((predictions[i] + 1) / 2)
    plt.show()

test_input = torch.randn(8, 100, device=device)

# GAN training
D_loss = []
G_loss = []

# Training loop
for epoch in range(500):
    time_start = time.time()
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataLoader)  # number of batches
    for step, img in enumerate(dataLoader):
        img = img.to(device)
        size = img.size(0)
        random_noise = torch.randn(size, 100, device=device)

        # Freeze the generator, train the discriminator
        d_optimizer.zero_grad()
        real_output = dis(img)                                              # discriminator output on real images
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))    # discriminator loss on real images
        d_real_loss.backward()

        gen_img = gen(random_noise)
        # gen_img_plot_step(gen_img, step)
        fake_output = dis(gen_img.detach())                                 # discriminator output on generated images
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))   # discriminator loss on generated images
        d_fake_loss.backward()

        d_loss = d_real_loss + d_fake_loss
        d_optimizer.step()

        # Generator loss and optimization
        g_optimizer.zero_grad()
        fake_output = dis(gen_img)
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output))         # generator loss
        g_loss.backward()
        g_optimizer.step()

        with torch.no_grad():
            d_epoch_loss += d_loss.item()
            g_epoch_loss += g_loss.item()

    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print("Epoch:", epoch)
        gen_img_plot(gen, test_input)
        time_end = time.time()
        print("Epoch {} took {}s, d_loss: {}, g_loss: {}".format(epoch, time_end - time_start, d_loss, g_loss))
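After training, the recorded per-epoch losses can be plotted and the trained generator sampled directly. A minimal sketch, assuming gen, device, D_loss and G_loss from the script above are still in scope:

# Loss curves over epochs
plt.plot(range(1, len(D_loss) + 1), D_loss, label='D_loss')
plt.plot(range(1, len(G_loss) + 1), G_loss, label='G_loss')
plt.xlabel('epoch')
plt.legend()
plt.show()

# Sample a fresh batch of images from the trained generator
with torch.no_grad():
    noise = torch.randn(8, 100, device=device)
    samples = gen(noise).permute(0, 2, 3, 1).cpu().numpy()

plt.figure(figsize=(16, 2))
for i in range(8):
    plt.subplot(1, 8, i + 1)
    plt.imshow((samples[i] + 1) / 2)   # Tanh output in [-1, 1] -> [0, 1] for display
    plt.axis('off')
plt.show()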

1.3 Results

[Generated sample grids at epochs 0, 20, 40, 60, 80, 100, 120, 140, and 150.]

2. DCGAN

2.1 Parameters

(1) Input: training images are resized to 64×64 and normalized to [-1, 1]
(2) Output: the generator produces 64×64 images (the feature-map size formulas are sketched below)
(3) Dataset: the face images loaded from ./data/xinggan_face/*.jpg in the code below
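The DCGAN generator upsamples a 100-dimensional noise vector to a 3×64×64 image with transposed convolutions, and the discriminator downsamples it back to a single probability. The comments in the code below record the expected feature-map sizes; a quick way to verify them (a sketch using the standard Conv2d / ConvTranspose2d output-size formulas, not part of the original script):

def conv_out(size, kernel, stride=1, padding=0):
    # Conv2d output size: floor((in + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel, stride=1, padding=0):
    # ConvTranspose2d output size: (in - 1) * stride - 2*padding + kernel
    return (size - 1) * stride - 2 * padding + kernel

# Generator: 256*16*16 -> 128*16*16 -> 64*32*32 -> 3*64*64
print(deconv_out(16, kernel=3, padding=1))            # 16
print(deconv_out(16, kernel=4, stride=2, padding=1))  # 32
print(deconv_out(32, kernel=4, stride=2, padding=1))  # 64

# Discriminator: 3*64*64 -> 64*31*31 -> 128*15*15
print(conv_out(64, kernel=3, stride=2))               # 31
print(conv_out(31, kernel=3, stride=2))               # 15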

2.2 Implementation

import glob
import torch
from PIL import Image
from torch import nn
from torch.utils import data
from torchvision import transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import os

log_dir = "./model/dcgan.pth"

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

class FaceDataset(data.Dataset):
    def __init__(self, images_path):
        self.images_path = images_path

    def __getitem__(self, index):
        image_path = self.images_path[index]
        pil_img = Image.open(image_path)
        pil_img = transform(pil_img)
        return pil_img

    def __len__(self):
        return len(self.images_path)

images_path = glob.glob('./data/xinggan_face/*.jpg')
BATCH_SIZE = 32
dataset = FaceDataset(images_path)
data_loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
image_batch = next(iter(data_loader))

# Generator definition
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.linear1 = nn.Linear(100, 256*16*16)
        self.bn1 = nn.BatchNorm1d(256*16*16)
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)            # output: 128*16*16
        self.bn2 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)   # output: 64*32*32
        self.bn3 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)     # output: 3*64*64

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.bn1(x)
        x = x.view(-1, 256, 16, 16)
        x = F.relu(self.deconv1(x))
        x = self.bn2(x)
        x = F.relu(self.deconv2(x))
        x = self.bn3(x)
        x = torch.tanh(self.deconv3(x))
        return x

# Discriminator definition
class Discrimination(nn.Module):
    def __init__(self):
        super(Discrimination, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2)    # 64*31*31
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2)  # 128*15*15
        self.bn1 = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128*15*15, 1)

    def forward(self, x):
        x = F.dropout(F.leaky_relu(self.conv1(x)), p=0.3)
        x = F.dropout(F.leaky_relu(self.conv2(x)), p=0.3)
        x = self.bn1(x)
        x = x.view(-1, 128*15*15)
        x = torch.sigmoid(self.fc(x))
        return x

# Visualization helper
def generate_and_save_images(model, epoch, test_noise_):
    predictions = model(test_noise_).permute(0, 2, 3, 1).detach().cpu().numpy()
    fig = plt.figure(figsize=(20, 160))
    for i in range(predictions.shape[0]):
        plt.subplot(1, 8, i+1)
        plt.imshow((predictions[i] + 1) / 2)
        # plt.axis('off')
    plt.show()

# Training function
def train(gen, dis, loss_fn, gen_opti, dis_opti, start_epoch):
    print("Training started")
    test_noise = torch.randn(8, 100, device=device)
    writer = SummaryWriter(r'D:\\Project\\PythonProject\\Ttest\\run')
    writer.add_graph(gen, test_noise)

    D_loss = []
    G_loss = []

    # Training loop
    for epoch in range(start_epoch, 500):
        D_epoch_loss = 0
        G_epoch_loss = 0
        batch_count = len(data_loader)  # number of batches
        for step, img in enumerate(data_loader):
            img = img.to(device)
            size = img.shape[0]
            random_noise = torch.randn(size, 100, device=device)  # random generator input

            # Freeze the generator, train the discriminator
            dis_opti.zero_grad()
            real_output = dis(img)
            d_real_loss = loss_fn(real_output, torch.ones_like(real_output, device=device))
            d_real_loss.backward()

            generated_img = gen(random_noise)
            # print(generated_img)
            fake_output = dis(generated_img.detach())
            d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output, device=device))
            d_fake_loss.backward()

            dis_loss = d_real_loss + d_fake_loss
            dis_opti.step()

            # Freeze the discriminator, train the generator
            gen_opti.zero_grad()
            fake_output = dis(generated_img)
            gen_loss = loss_fn(fake_output, torch.ones_like(fake_output, device=device))
            gen_loss.backward()
            gen_opti.step()

            with torch.no_grad():
                D_epoch_loss += dis_loss.item()
                G_epoch_loss += gen_loss.item()

        with torch.no_grad():
            D_epoch_loss /= batch_count
            G_epoch_loss /= batch_count
            D_loss.append(D_epoch_loss)
            G_loss.append(G_epoch_loss)
            writer.add_scalar("loss/dis_loss", D_epoch_loss, epoch+1)
            writer.add_scalar("loss/gen_loss", G_epoch_loss, epoch+1)
            print("Epoch: {}, discriminator loss: {}, generator loss: {}.".format(epoch, dis_loss, gen_loss))
            generate_and_save_images(gen, epoch, test_noise)

        # Save a checkpoint after every epoch
        state = {
            "gen": gen.state_dict(),
            "dis": dis.state_dict(),
            "gen_opti": gen_opti.state_dict(),
            "dis_opti": dis_opti.state_dict(),
            "epoch": epoch
        }
        torch.save(state, log_dir)

    plt.plot(range(1, len(D_loss)+1), D_loss, label="D_loss")
    plt.plot(range(1, len(D_loss)+1), G_loss, label="G_loss")
    plt.xlabel('epoch')
    plt.legend()
    plt.show()


if __name__ == '__main__':
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    gen = Generator().to(device)
    dis = Discrimination().to(device)
    loss_fn = torch.nn.BCELoss()
    gen_opti = torch.optim.Adam(gen.parameters(), lr=0.0001)
    dis_opti = torch.optim.Adam(dis.parameters(), lr=0.00001)

    start_epoch = 0
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        gen.load_state_dict(checkpoint["gen"])
        dis.load_state_dict(checkpoint["dis"])
        gen_opti.load_state_dict(checkpoint["gen_opti"])
        dis_opti.load_state_dict(checkpoint["dis_opti"])
        start_epoch = checkpoint["epoch"]
        print("Checkpoint loaded; resuming training from epoch {}".format(start_epoch))

    train(gen, dis, loss_fn, gen_opti, dis_opti, start_epoch)
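Because every epoch writes a full checkpoint to ./model/dcgan.pth, the trained generator can also be reloaded later for inference only. A minimal sketch, assuming the Generator class and the checkpoint keys from the script above:

import torch
import matplotlib.pyplot as plt

device = "cuda:0" if torch.cuda.is_available() else "cpu"
gen = Generator().to(device)

checkpoint = torch.load("./model/dcgan.pth", map_location=device)
gen.load_state_dict(checkpoint["gen"])
gen.eval()   # switch BatchNorm layers to evaluation mode

with torch.no_grad():
    noise = torch.randn(8, 100, device=device)
    images = gen(noise).permute(0, 2, 3, 1).cpu().numpy()

plt.figure(figsize=(16, 2))
for i in range(8):
    plt.subplot(1, 8, i + 1)
    plt.imshow((images[i] + 1) / 2)   # map Tanh output back to [0, 1] for display
    plt.axis('off')
plt.show()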

2.3 Results

Training started
Epoch: 0, discriminator loss: 1.6549043655395508, generator loss: 0.7864767909049988.
Epoch: 20, discriminator loss: 1.3690211772918701, generator loss: 0.6662370562553406.
Epoch: 40, discriminator loss: 1.413375735282898, generator loss: 0.7497923970222473.
Epoch: 60, discriminator loss: 1.2889504432678223, generator loss: 0.8668195009231567.
Epoch: 80, discriminator loss: 1.2824485301971436, generator loss: 0.805076003074646.
Epoch: 100, discriminator loss: 1.3278448581695557, generator loss: 0.7859240770339966.
Epoch: 120, discriminator loss: 1.39650297164917, generator loss: 0.7616179585456848.
Epoch: 140, discriminator loss: 1.3387322425842285, generator loss: 0.811163067817688.
Epoch: 160, discriminator loss: 1.1281094551086426, generator loss: 0.7557946443557739.
Epoch: 180, discriminator loss: 1.369300365447998, generator loss: 0.5207887887954712.
[Generated sample grids were shown after each logged epoch.]