本文是一个入门级的GAN代码示例,使用的是MNIST数据集。本文由于是面向新手,所以对代码进行了详细的解释。而且最后给出了完整的代码,可以直接运行看看效果。虽然经过了50个epoch的训练,但是训练的效果并不是很好,这也是为了后面的魔改版GAN做一个铺垫。后面的代码都是在这个版本的基础上进行修改的。所以理解这个基础版代码是非常重要的。

数据准备

# 数据集使用的是Fashion-Minst数据集
# 对数据集做归一化(-1,1)
# 在GAN中,建议将数据归一化到(-1,1)之间,因为生成的图片使用tan激活,tan函数在-1~1之间
transform = transforms.Compose([
    transforms.ToTensor(), #归一化,浮点化,通道重排
    transforms.Normalize(0.5, 0.5)  # 从(0,1)到(-1,1)  
])

这段代码是对Fashion-MNIST数据集进行预处理的代码。它使用了PyTorch中的transforms.Compose函数来定义一系列的数据转换操作。

  1. transforms.ToTensor()将数据转换为Tensor对象,并将像素值从整数(0-255)映射到浮点数(0.0-1.0)范围内。这一步是为了将数据归一化到0到1之间。
  2. transforms.Normalize(0.5, 0.5)将数据进行标准化操作,使其均值为0.5,标准差为0.5。这一步将数据的范围从(0.0-1.0)缩放到(-1.0, 1.0)范围内。标准化操作是为了使数据在训练过程中更易于收敛,并且可以使生成的图片使用tanh激活函数。

综合起来,这段代码的作用是将Fashion-MNIST数据集的图像转换为归一化的张量,并将像素值标准化到(-1, 1)的范围内,以便在GAN中使用。

定义生成器

# 生成器的输入是噪声,长度为100的随机数,正态分布
# 输出为(1,28,28)的图片
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
                                  nn.Linear(100, 256),
                                  nn.ReLU(),
                                  nn.Linear(256, 512),
                                  nn.ReLU(),
                                  nn.Linear(512, 28*28),
                                  nn.Tanh()                     # -1, 1之间
        )
    def forward(self, x):              # x 表示长度为100 的noise输入
        img = self.main(x)
        img = img.view(-1, 28, 28)
        return img
  • Generator 类继承自 nn.Module,是一个 PyTorch 的模型类。
  • 在构造函数 __init__() 中,通过 nn.Sequential() 定义了一个包含多个层的顺序网络结构 self.main,其中每个层按照顺序连接。
  • 第一层是全连接层 nn.Linear(100, 256),它将输入的噪声向量的维度从100转换为256。这一层的激活函数是ReLU。
  • 第二层是全连接层 nn.Linear(256, 512),将维度从256转换为512,同样使用ReLU作为激活函数。
  • 第三层是全连接层 nn.Linear(512, 28*28),将维度从512转换为28*28,即图像的像素数。这一层没有激活函数。
  • 最后一层是 nn.Tanh(),它将生成的图像的像素值范围映射到(-1, 1)之间。
  • forward() 方法定义了前向传播的过程,即输入 x 经过网络层的计算得到生成的图像。
  • 输入 x 经过 self.main,经过多个线性变换和激活函数的计算后,得到生成图像的张量 img
  • 接着,通过 img.view(-1, 28, 28)img 进行形状变换,将其从展平的向量形式转换为大小为(1, 28, 28)的图像。
  • 最后,返回生成的图像张量 img

定义判别器

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
                                  nn.Linear(28*28, 512),
                                  nn.LeakyReLU(),
                                  nn.Linear(512, 256),
                                  nn.LeakyReLU(),
                                  nn.Linear(256, 1),
                                  nn.Sigmoid()
        )
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.main(x)
        return x

定义了一个判别器模型,用于对生成的图像进行分类,判断其是否为真实图像。

  • Discriminator 类继承自 nn.Module,是一个 PyTorch 的模型类。
  • 在构造函数 __init__() 中,通过 nn.Sequential() 定义了一个包含多个层的顺序网络结构 self.main,其中每个层按照顺序连接。
  • 第一层是全连接层 nn.Linear(28*28, 512),它将输入的图像张量展平成长度为2828的向量,然后将其维度从2828转换为512。
  • 第二层是激活函数 nn.LeakyReLU(),它采用了带泄露的线性整流激活函数,用于引入一定的非线性特征。
  • 第三层是全连接层 nn.Linear(512, 256),将维度从512转换为256。
  • 第四层是激活函数 nn.LeakyReLU(),同样采用了带泄露的线性整流激活函数。
  • 第五层是全连接层 nn.Linear(256, 1),将维度从256转换为1,用于输出一个标量值,表示图像的真假。
  • 最后一层是 nn.Sigmoid(),它将输出的标量值通过 sigmoid 函数映射到(0, 1)的范围,用于表示图像的概率。
  • forward() 方法定义了前向传播的过程,即输入 x 经过网络层的计算得到判别结果。
  • 首先,通过 x.view(-1, 28*28) 将输入的图像张量展平成长度为28*28的向量。
  • 然后,将展平后的向量经过 self.main,经过多个线性变换、激活函数的计算后,得到判别结果的张量 x
  • 最后,返回判别结果张量 x

初始化模型、优化器以及损失计算函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
g_optim = torch.optim.Adam(gen.parameters(),lr=0.0001)
loss_fn = torch.nn.BCELoss()

这段代码的作用是选择设备、初始化生成器和判别器模型,并定义了判别器和生成器的优化器以及判别器的损失函数,为后续的训练过程做准备。

GAN训练

# #############GAN训练###################
D_loss = []
G_loss = []

# 训练循环
for epoch in range(50):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader)
    for step, (img, _) in enumerate(dataloader):
        img = img.to(device)
        size = img.size(0)
        random_noise = torch.randn(size, 100, device=device)
        
        d_optim.zero_grad()
        
        real_output = dis(img)      # 判别器输入真实的图片,real_output对真实图片的预测结果 
        d_real_loss = loss_fn(real_output, 
                              torch.ones_like(real_output))      # 得到判别器在真实图像上的损失
        d_real_loss.backward()
        
        gen_img = gen(random_noise)
        # 判别器输入生成的图片,fake_output对生成图片的预测
        fake_output = dis(gen_img.detach()) 
        d_fake_loss = loss_fn(fake_output, 
                              torch.zeros_like(fake_output))      # 得到判别器在生成图像上的损失
        d_fake_loss.backward()
        
        d_loss = d_real_loss + d_fake_loss
        d_optim.step()
        
        g_optim.zero_grad()
        fake_output = dis(gen_img)
        g_loss = loss_fn(fake_output, 
                         torch.ones_like(fake_output))      # 生成器的损失
        g_loss.backward()
        g_optim.step()
        
        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss
            
    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss.item())
        G_loss.append(g_epoch_loss.item())
        print('Epoch:', epoch)
        gen_img_plot(gen, test_input)
  • 创建了两个空列表 D_lossG_loss,用于记录每个epoch的判别器和生成器的损失。、
  • d_epoch_lossg_epoch_loss 初始化为0,用于记录每个epoch的判别器和生成器的累计损失。
  • count 记录了数据加载器中的批次数量,即图像的总批次数。
  • 在每个epoch中,通过数据加载器 dataloader 迭代处理图像数据。其中 img 是一个批次的图像数据,_ 是对应的标签(在这里没有使用)。
  • 将图像数据 img 移动到设备上。
  • 通过 torch.randn(size, 100, device=device) 生成与图像批次大小相匹配的随机噪声,作为生成器的输入。
  • d_optim.zero_grad() 清零判别器的梯度。
  • 通过判别器 dis 对真实图像 img 进行前向传播,得到判别器在真实图像上的输出 real_output
  • 计算判别器在真实图像上的损失 d_real_loss,使用二元交叉熵损失函数,将判别器的输出与全为1的目标标签进行比较。
  • 对判别器的真实图像损失进行反向传播 d_real_loss.backward()
  • 使用生成器 gen 对随机噪声 random_noise 进行前向传播,生成一批假图像 gen_img
  • 将生成的图像输入判别器 dis,得到判别器在生成图像上的输出 fake_output。注意,使用 detach() 将生成图像的梯度断开,以便只更新判别器的参数。
  • 计算判别器在生成图像上的损失 d_fake_loss,使用二元交叉熵损失函数,将判别器的输出与全为0的目标标签进行比较。

完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision 
from torchvision import transforms

# #############数据准备###################
# 数据集使用的是Fashion-Minst数据集
# 对数据集做归一化(-1,1)
# 在GAN中,建议将数据归一化到(-1,1)之间,因为生成的图片使用tan激活,tan函数在-1~1之间
transform = transforms.Compose([
    transforms.ToTensor(), #归一化,浮点化,通道重排
    transforms.Normalize(0.5, 0.5)  # 从(0,1)到(-1,1)  
])

# 训练数据集
# 只需要训练集的图片就可以了,不需要测试集
train_ds = torchvision.datasets.MNIST('data',
                                      train = True,
                                      transform = transform,
                                      download = True
                                     )

dataloader = torch.utils.data.DataLoader(train_ds,batch_size=64, shuffle=True)

# #############定义生成器###################

# 生成器的输入是噪声,长度为100的随机数,正态分布
# 输出为(1,28,28)的图片
# linear1 : 100-->256
# linear2 : 256-->512
# linear3 : 512-->28*28
# reshape :  28*28 -->(1,28,28)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
                                  nn.Linear(100, 256),
                                  nn.ReLU(),
                                  nn.Linear(256, 512),
                                  nn.ReLU(),
                                  nn.Linear(512, 28*28),
                                  nn.Tanh()                     # -1, 1之间
        )
    def forward(self, x):              # x 表示长度为100 的noise输入
        img = self.main(x)
        img = img.view(-1, 28, 28)
        return img

# #############定义判别器###################
# 输入为(1,28,28)的图片,输出为二分类的概率值,使用sigmoid激活0-1
# BCEloss交叉熵损失
# 在判别器中一般推荐使用 LeakyReLU
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
                                  nn.Linear(28*28, 512),
                                  nn.LeakyReLU(),
                                  nn.Linear(512, 256),
                                  nn.LeakyReLU(),
                                  nn.Linear(256, 1),
                                  nn.Sigmoid()
        )
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.main(x)
        return x
    
# #############初始化模型、优化器以及损失计算函数###################
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
g_optim = torch.optim.Adam(gen.parameters(),lr=0.0001)
loss_fn = torch.nn.BCELoss()

# #############绘图###################
# 方便将每一个批次训练的结果展示出来,看看生成的图像是什么样子
def gen_img_plot(model, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i+1)
        plt.imshow((prediction[i] + 1)/2)
        plt.axis('off')
    plt.show()

test_input = torch.randn(16, 100, device=device)

# #############GAN训练###################
D_loss = []
G_loss = []

# 训练循环
for epoch in range(50):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader)
    for step, (img, _) in enumerate(dataloader):
        img = img.to(device)
        size = img.size(0)
        random_noise = torch.randn(size, 100, device=device)
        
        d_optim.zero_grad()
        
        real_output = dis(img)      # 判别器输入真实的图片,real_output对真实图片的预测结果 
        d_real_loss = loss_fn(real_output, 
                              torch.ones_like(real_output))      # 得到判别器在真实图像上的损失
        d_real_loss.backward()
        
        gen_img = gen(random_noise)
        # 判别器输入生成的图片,fake_output对生成图片的预测
        fake_output = dis(gen_img.detach()) 
        d_fake_loss = loss_fn(fake_output, 
                              torch.zeros_like(fake_output))      # 得到判别器在生成图像上的损失
        d_fake_loss.backward()
        
        d_loss = d_real_loss + d_fake_loss
        d_optim.step()
        
        g_optim.zero_grad()
        fake_output = dis(gen_img)
        g_loss = loss_fn(fake_output, 
                         torch.ones_like(fake_output))      # 生成器的损失
        g_loss.backward()
        g_optim.step()
        
        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss
            
    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss.item())
        G_loss.append(g_epoch_loss.item())
        print('Epoch:', epoch)
        gen_img_plot(gen, test_input)
本文标签: 深度学习GAN

版权声明 ▶ 本网站名称:陶小桃Blog
▶ 本文链接:https://www.52txr.cn/2023/ganminst.html
▶ 本网站的文章部分内容可能来源于网络,仅供大家学习与参考,如有侵权,请联系站长进行核实删除。
▶ 转载本站文章需要遵守:商业转载请联系站长,非商业转载请注明出处并附带原文链接!!
▶ 站长邮箱 [email protected][email protected] ,如不方便留言可邮件联系。

小陶的个人微信公众号

学累了就来张美女照片养养眼吧,身体是革命的本钱,要多多休息哦~ 随机美女图片

最后修改:2023 年 10 月 12 日
如果觉得我的文章对你有用,请随意赞赏!