| 雷峰网
您正在使用IE低版浏览器,为了您的雷峰网账号安全和更好的产品体验,强烈建议使用更快更安全的浏览器
此为临时链接,仅用于文章预览,将在时失效
人工智能开发者 正文
发私信给AI研习社
发送

0

从零教你写一个完整的GAN

本文作者: AI研习社 编辑:贾智龙 2017-04-27 16:24
导语:GAN这么火,不想自己复现一下么?

导言

啦啦啦,现今 GAN 算法可以算作 ML 领域下比较热门的一个方向。事实上,GAN 已经作为一种思想来渗透在 ML 的其余领域,从而做出了很多很 Amazing 的东西。比如结合卷积神经网络,可以用于生成图片。或者结合 NLP,可以生成特定风格的短句子。(比如川普风格的 twitter......)

可惜的是,网络上很多老司机开 GAN 的车最后都翻了,大多只是翻译了一篇论文,一旦涉及算法实现部分就直接放开源的实现地址,而那些开源的东东,缺少了必要的引导,实在对于新手来说很是懵逼。所以兔子哥哥带着开好车,开稳车的心态,特定来带一下各位想入门 GAN 的其他小兔兔们来飞一会。

GAN 的介绍与训练

先来阐述一下 GAN 的基本做法,这里不摆公式,因为你听完后,该怎么搭建和怎么训练你心里应该有数了。

首先,GAN 全称为 Generative Adversarial Nets(生成对抗网络), 其构成分为两部份:

  • Generator(生成器),下文简称 G

  • Discriminator(辨别器), 下文简称 D。

在本文,为了方便小兔兔理解,使用一个较为简单,也是 GAN 论文提及到的例子,训练 G 生成符合指定均值和标准差的数据,在这里,我们指定 MEAN=4,STD=1.5 的高斯分布(正态分布)。

这货的样子大概长这样

从零教你写一个完整的GAN

下面是数据生成的代码:

def sample_data(size, length=100):
   """    随机mean=4 std=1.5的数据    :param size:    :param length:    :return:    """
   data = []
   for _ in range(size):
       data.append(sorted(np.random.normal(4, 1.5, length)))
   return np.array(data)

在生成高斯分布的数据后,我们还对数据进行了排序,这时因为排序后的训练会相对平滑。具体原因看这个 [Generative Adversarial Nets in TensorFlow (Part I)]

这一段是生成噪音的代码,既然是噪音,那么我们只需要随机产生 0~1 的数据就好。

def random_data(size, length=100):
   """    随机生成数据    :param size:    :param length:    :return:    """
   data = []
   for _ in range(size):
       x = np.random.random(length)
       data.append(x)
   return np.array(data)

随机分布的数据长这样

从零教你写一个完整的GAN

接下来就是开撸 GAN 了。

首先的一点就是,我们需要确定 G, 和 D 的具体结构,这里因为本文的例子太过于入门级,并不需要使用到复杂的神经网络结构,比如卷积层和递归层,这里 G 和 D 只需全连接的神经网络就好。全连接层的神经网络本质就是矩阵的花式相乘。为神马说是花式相乘呢,因为大多数时候,我们在矩阵相乘的结果后面会添加不同的激活函数。

从零教你写一个完整的GAN

G 和 D 分别为三层的全链接的神经网络,其中 G 的激活函数分别为,relu,sigmoid,liner,这里前两层只是因为考虑到数据的非线性转换,并没有什么特别选择这两个激活函数的原因。其次 D 的三层分别为 relu,sigmoid,sigmoid。

接下来就引出 GAN 的训练问题。GAN 的思想源于博弈论,直白一点就是套路与反套路。

先从 D 开始分析,D 作为辨别器,它的职责就是区分于真实的高斯分布和 G 生成的” 假” 高斯分布。所以很显然,针对 D 来说,其需要解决的就是传统的二分类问题。

在二分类问题中,我们习惯用交叉熵来衡量分类效果。

从零教你写一个完整的GAN

从公式中不难看出,在全部分类正确时,交叉熵会接近于 0,因此,我们的目标就是通过拟合 D 的参数来最小化交叉熵的值。

D 既然是传统的二分类问题,那么 D 的训练过程也很容易得出了

  • 即先把真实数据标识为‘1’(真实分布),由生成器生成的数据标识为’0‘(生成分布),反复迭代训练 D   ------------ (1)

说 G 的训练之前先来打个比方,假如一男一女在一起了,现在两人性格出现矛盾了,女生并不愿意改变,但两个人都想继续在一起,这时,唯一的方法就是男生改变了。先忽略现实生活的问题,但从举例的角度说,显然久而久之男生就会变得更加 fit 这个女生。

G 的训练也是如此:

  • 先将 G 拼接在 D 的上方,即 G 的输出作为 D 的输入(男生女生在一起),而同时固定 D 的参数(女生不愿意改变),并将进入 G 的噪音样本标签全部改成'1'(目标是两个人继续在一起,没有其他选择),为了最小化损失函数,此时就只能改变 G 的每一层权重,反复迭代后 G 的生成能力因此得以改进。(男生更适合女生)  ------------ (2)

反复迭代(1)(2),最终 G 就会得到较好的生成能力。

补充一点,在训练 D 的时候,我曾把数据直接放进去,这样的后果是最后生成的数据,能学习到高斯分布的轮廓,但标准差和均值则和真实样本相差很大。因此,这里我建议直接使用平均值和标准差作为 D 的输入。

这使得 D 在训练前需要对数据进行预处理。

def preprocess_data(x):
   """    计算每一组数据平均值和方差    :param x:    :return:    """
   return [[np.mean(data), np.std(data)] for data in x]

G 和 D 的连接之间也需要做出处理。

# 先求出G_output3的各行平均值和方差 

MEAN = tf.reduce_mean(G_output3, 1)  # 平均值,但是是1D向量

MEAN_T = tf.transpose(tf.expand_dims(MEAN, 0))  # 转置

STD = tf.sqrt(tf.reduce_mean(tf.square(G_output3 - MEAN_T), 1))

DATA = tf.concat(1, [MEAN_T,
                    tf.transpose(tf.expand_dims(STD, 0))] # 拼接起来

以下是损失函数变化图:

  • 蓝色是 D 单独作二分类问题处理时的变化

  • 绿色是拼接 G 在 D 的上方后损失函数的变化

从零教你写一个完整的GAN

不难看出,两者在经历反复震荡 (互相博弈而导致),最后稳定于 0.5 附近,这时我们可以认为,G 的生成能力已经达到以假乱真,D 再也不能分出真假。

接下来的这个就是 D-G 博弈 200 次后的结果:

  • 绿色是真实分布

  • 蓝色是噪音原本的分布

  • 红色是生成分布

从零教你写一个完整的GAN

后话

兔子哥哥的车这次就开到这里了。作为一个大三且数学能力较为一般的学生, 从比较感性的角度来描述了一次 GAN 的基本过程,有说得不对地方请各位见谅和指点。

如果各位读者需要更严格的数学公式和证明,可以阅读 GAN 的开山之作([1406.2661] Generative Adversarial Networks) , 另外本文提及的代码都可在这里找到(MashiMaroLjc/learn-GAN),有需要的童鞋也可以私信交流。

这就是全部内容了,下次心情好的话怼 DCGAN,看看能不能生成点好玩的图片,嗯~ 睡觉去~

雷锋网按:本文原作者兔子老大,原文来自他的知乎专栏

雷峰网版权文章,未经授权禁止转载。详情见转载须知

从零教你写一个完整的GAN

分享:
相关文章

编辑

聚焦数据科学,连接 AI 开发者。更多精彩内容,请访问:yanxishe.com
当月热门文章
最新文章
请填写申请人资料
姓名
电话
邮箱
微信号
作品链接
个人简介
为了您的账户安全,请验证邮箱
您的邮箱还未验证,完成可获20积分哟!
请验证您的邮箱
立即验证
完善账号信息
您的账号已经绑定,现在您可以设置密码以方便用邮箱登录
立即设置 以后再说
Baidu
map