深度学习系列(2):前向传播和后向传播算法

前言

讲真,之前学吴恩达的机器学习课时,还手写实现过后向传播算法,但如今忘得也一干二净。总结两个原因:1. 理解不够透彻。2. 没有从问题的本质抓住后向传播的精髓。今天重温后向传播算法的推导,但重要的是比较前向传播和后向传播的优缺点,以及它们在神经网络中起到了什么不一般的作用,才让我们如此着迷。

反向传播的由来

反向传播由Hinton在1986年发明,该论文发表在nature上,高尚大的杂志啊。

Rumelhart, David E, G. E. Hinton, and R. J. Williams. “Learning representations by back-propagating errors. ” Nature 323.6088(1986):533-536.

简单说说吧,反向传播主要解决神经网络在训练模型时的参数更新问题。神经网络如下图:
alt text

反向传播算法需要解决每条边对应的权值如何更新,才能使得整个输出的【损失函数】最小。如果对神经网络还不了解,建议先学习了什么是神经网络,再阅读以下内容。

这里推荐几篇关于神经网络的文章,总体来说不错:

  1. 计算机的潜意识
  2. Machine Learning & Algorithm 神经网络基础

关于反向传播算法有种不太恰当的比方,对于每个输出结点,给定一个输入样例,会得到一个预测值,而这个预测值和真实值之间的差距我们当作误差(欠的钱),是谁影响了欠债的多少呢?很明显,在神经网络模型中,只有待求的参数 {w1,w2,...,wn} <script type="math/tex" id="MathJax-Element-1">\{w_1, w_2, ..., w_n\}</script>了。如何衡量每个参数对误差的影响,我们定义一个敏感度: 当参数 wi <script type="math/tex" id="MathJax-Element-2">w_i</script>在某个很小的范围内变动时,误差变动了多少,用数学表示即: ΔLΔwi <script type="math/tex" id="MathJax-Element-3">\frac{\Delta L}{\Delta w_i}</script>,在考虑极限情况下,即微分: Lwi <script type="math/tex" id="MathJax-Element-4">\frac{\partial L}{\partial w_i}</script>。

所以我们有了最基础的微分表达式,也是反向传播所有推导公式的源泉,那为什么这个敏感度就能更新权值呢?其实 ΔLΔwi <script type="math/tex" id="MathJax-Element-5">\frac{\Delta L}{\Delta w_i}</script>很有意思,因为不管最终 L(w) <script type="math/tex" id="MathJax-Element-6">L(w)</script>的形式是什么样子, ΔLΔwi= <script type="math/tex" id="MathJax-Element-7">\frac{\Delta L}{\Delta w_i} = 定值</script>,所以假设 Δwi>0 <script type="math/tex" id="MathJax-Element-8">\Delta w_i > 0</script>,那么该定值为负数的情况下, wi <script type="math/tex" id="MathJax-Element-9">w_i</script>增大的方向上 L(wi) <script type="math/tex" id="MathJax-Element-10">L(w_i)</script>将减小,而该定值为正数的情况时, wi <script type="math/tex" id="MathJax-Element-11">w_i</script>增大的方向上 L(wi) <script type="math/tex" id="MathJax-Element-12">L(w_i)</script>将增大。

所以梯度下降的更新算法有 w:=wηLwi <script type="math/tex" id="MathJax-Element-13">w := w - \eta \frac{\partial L}{\partial w_i}</script>,当然你也可以画图形象的理解下,不难。

那么 Lwi <script type="math/tex" id="MathJax-Element-14">\frac{\partial L}{\partial w_i}</script> 这玩意怎么计算呢?在简单的感知机模型中很容易计算得到,具体可以参考上一篇博文,这里不再赘述了。

反向传播的计算

我很讨厌一上来就来了一堆反向传播的公式以及各种推导。这样没错,简单直接,理解了觉得自己还很牛逼,结果过了一段时间怎么又忘了公式的推导,还得重新推一遍。而理解反向传播的精髓并非这些公式的推导,而是它弥补了前向算法的哪些不足,为啥它就被遗留下来作为神经网络的鼻祖呢?解决了什么问题,如何优雅的解决了该问题?从哪些角度能让我们构建出反向传播算法才是应该去学习和理解的。

我们先来建个简单的神经网络图吧,注意,这里只是帮助理解反向传播算法的构建过程,与真实的神经网络有一定的差距,但其中的分析过程是大同小异的。

此外这三篇文章写的不错,【推导】【本质】【实现】都有了:

  1. 【看看就行】机器学习:一步步教你理解反向传播方法
  2. 【后续内容基于此文,推荐】Calculus on Computational Graphs: Backpropagation
  3. 【python实现ANN,只要42行!】A Neural Network in 11 lines of Python (Part 1)

如图所示:
这里写图片描述

为了简化推导过程,输入层只使用了一个特征,同样输出层也只有一个结点,隐藏层使用了两个结点。注意在实际神经网络中,大多数文章把z1和h1当作一个结点来画图的,这里为了方便推导才把两者分开。

所以我们有:
z1=w1x <script type="math/tex" id="MathJax-Element-15">z_1 = w_1 x </script>
z2=w2x <script type="math/tex" id="MathJax-Element-16">z_2 = w_2 x </script>
h1=11+ez1 <script type="math/tex" id="MathJax-Element-17">h_1 = \frac{1}{1 + e^{-z_1}}</script>
h2=11+ez2 <script type="math/tex" id="MathJax-Element-18">h_2 = \frac{1}{1 + e^{-z_2}}</script>
z3=w3h1+w4h2 <script type="math/tex" id="MathJax-Element-19">z_3 = w_3 h_1 + w_4 h_2</script>
y=11+ez3 <script type="math/tex" id="MathJax-Element-20">y = \frac{1}{1 + e^{-z_3}}</script>

假定给了输入x,我们就能根据这一系列公式求得y,接下来我们需要定义损失函数了,使用平方误差函数(只针对一次输入):

L=12(yt)2
<script type="math/tex; mode=display" id="MathJax-Element-21">L = \frac{1}{2} (y - t)^2</script>

t <script type="math/tex" id="MathJax-Element-22">t</script>表示真实值,ok,根据第一节的内容,模型训练实际上是更新wi<script type="math/tex" id="MathJax-Element-23">w_i</script>,既然要更新 wi <script type="math/tex" id="MathJax-Element-24">w_i</script>,就需要求解 Lwi <script type="math/tex" id="MathJax-Element-25">\frac{\partial L}{\partial w_i}</script>,于是对于 wi <script type="math/tex" id="MathJax-Element-26">w_i</script>,根据链式法则,可以求得:

Lw1=Lyyz3z3h1h1z1z1w1
<script type="math/tex; mode=display" id="MathJax-Element-1539"> \frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial y}\frac{\partial y}{\partial z_3}\frac{\partial z_3}{\partial h_1}\frac{\partial h_1}{\partial z_1}\frac{\partial z_1}{\partial w_1} </script>

到这里你能看出什么?不着急,我们再求一个 w3 <script type="math/tex" id="MathJax-Element-1540">w_3</script>:

Lw3=Lyyz3z3w3
<script type="math/tex; mode=display" id="MathJax-Element-1541"> \frac{\partial L}{\partial w_3} = \frac{\partial L}{\partial y} \frac{\partial y}{\partial z_3} \frac{\partial z_3}{\partial w_3} </script>

从中,我们可以看到一些模式(规律),实际上 w1 <script type="math/tex" id="MathJax-Element-1542">w_1</script>的更新,在它相关的路径上,每条边的后继和前继结点对应的就是偏导的分子和分母。 w3 <script type="math/tex" id="MathJax-Element-1543">w_3</script>同样如此,它的相关边有三条(最后y指向L的关系边没有画出来),而对应的链式法则也恰好有三个偏导。

结论:每条关系边对应于一个偏导!!!什么是关系边?Okay,就是中间变量如 z1,h1 <script type="math/tex" id="MathJax-Element-1544">z_1, h_1</script>都与 w1 <script type="math/tex" id="MathJax-Element-1545">w_1</script>有关系,连接这些结点的边。

咱们继续细化上述公式,目前来看,这跟反向传播八竿子打不着。的确就这些性质不足以引出反向传播,不着急,继续往下看。

因为偏导数中的每个函数映射都是确定的,所以我们可以求出所有偏导数,于是有:

Lw1=(yt)y(1y)w3h1(1h1)x
<script type="math/tex; mode=display" id="MathJax-Element-1546"> \frac{\partial L}{\partial w_1} = (y - t)\cdot y \cdot (1 - y) \cdot w_3 \cdot h_1 \cdot (1 - h_1) \cdot x </script>

很有意思,式中 x,t <script type="math/tex" id="MathJax-Element-1547">x, t</script>是由样本给定,而那些 y,h1,w3 <script type="math/tex" id="MathJax-Element-1548">y, h_1, w_3</script>都在计算 y <script type="math/tex" id="MathJax-Element-1549">y</script>时,能够得到,这就意味着所有变量都是已知的,可以直接求出Lw1<script type="math/tex" id="MathJax-Element-1550">\frac{\partial L}{\partial w_1}</script>,那怎么就有了前向和后向【传播】之说呢?

宏观上,其实可以考虑一个非常大型的神经网络,它的参数 wi <script type="math/tex" id="MathJax-Element-1551">w_i</script>可能有成千上万个,难道对于每一个参数我们都要列出一个偏导公式么,显然不现实。因此,我们还需要进一步挖掘它们共通的模式。

继续看图:
alt text

假设我们加入第二个特征 x2 <script type="math/tex" id="MathJax-Element-1552">x_2</script>,那么对应的 w5 <script type="math/tex" id="MathJax-Element-1553">w_5</script>的更新,我们有如下公式:

Lw5=Lyyz3z3h1h1z1z1w5
<script type="math/tex; mode=display" id="MathJax-Element-1554"> \frac{\partial L}{\partial w_5} = \frac{\partial L}{\partial y}\frac{\partial y}{\partial z_3}\frac{\partial z_3}{\partial h_1}\frac{\partial h_1}{\partial z_1}\frac{\partial z_1}{\partial w_5} </script>

对比一波 w1 <script type="math/tex" id="MathJax-Element-1555">w_1</script>:

Lw1=Lyyz3z3h1h1z1z1w1
<script type="math/tex; mode=display" id="MathJax-Element-1556"> \frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial y}\frac{\partial y}{\partial z_3}\frac{\partial z_3}{\partial h_1}\frac{\partial h_1}{\partial z_1}\frac{\partial z_1}{\partial w_1} </script>

有什么不同么,眼神不好的还以为没有区别,实际上就最后一个偏导的分母发生了变化,而我们刚才也总结出了一个重要结论,每个偏导代表一条边,所以对于 w5 <script type="math/tex" id="MathJax-Element-1557">w_5</script>的更新,前面四个偏导值都需要重新在计算一遍,也就是红线指出的部分,为了算 w5 <script type="math/tex" id="MathJax-Element-1558">w_5</script>,需要重新再走过 w1 <script type="math/tex" id="MathJax-Element-1559">w_1</script>的部分路径。

所以即使我们用输入 (x1,x2) <script type="math/tex" id="MathJax-Element-1560">(x_1, x_2)</script>求出了每个结点,如 z1,h1,z2,h2,z3,y <script type="math/tex" id="MathJax-Element-1561">z_1, h_1, z_2, h_2, z_3, y</script>的值,为了求出每个 wi <script type="math/tex" id="MathJax-Element-1562">w_i</script>的偏导,需要多次代入这些变量,产生了大量的冗余,另外一点在上面也已经指出,每个 wi <script type="math/tex" id="MathJax-Element-1563">w_i</script>都需要手工求偏导么?庞大的神经网络太复杂了,之所以叫前向传播算法,是因为从输入 (x1,x2) <script type="math/tex" id="MathJax-Element-1564">(x_1, x_2)</script>出发,能够求出对应的所有结点的值,这个过程是正向的。

学过动态规划的可能一下子就能理解反向传播的精髓了,如果我们有个中间变量 δj=Lyyz3z3h1h1z1 <script type="math/tex" id="MathJax-Element-1565">\delta_j = \frac{\partial L}{\partial y}\frac{\partial y}{\partial z_3}\frac{\partial z_3}{\partial h_1}\frac{\partial h_1}{\partial z_1}</script>来存储,那么计算 w1 <script type="math/tex" id="MathJax-Element-1566">w_1</script>和 w5 <script type="math/tex" id="MathJax-Element-1567">w_5</script>时,只要对应的 δjz1w1 <script type="math/tex" id="MathJax-Element-1568">\delta_j \cdot \frac{\partial z_1}{\partial w_1}</script>和 δjz1w5 <script type="math/tex" id="MathJax-Element-1569">\delta_j \cdot \frac{\partial z_1}{\partial w_5}</script>即可。那么中间的子状态只需要计算一次即可,而不是指数型增长。

这和递归记忆化搜索(自顶向下)以及动态规划(自底向上)的两种对偶形式很像,为了解决重复子问题,我们可以反向传播,如果能够定义出合适的子状态,且得出递推式那么这件事就做成了。

Okay,再来对比下 w1 <script type="math/tex" id="MathJax-Element-1570">w_1</script>和 w3 <script type="math/tex" id="MathJax-Element-1571">w_3</script>的偏导,继续找找规律吧:

Lw1Lw3=Lyyz3z3h1h1z1z1w1=Lyyz3z3w3
<script type="math/tex; mode=display" id="MathJax-Element-1572"> \begin{align} \frac{\partial L}{\partial w_1} & = \frac{\partial L}{\partial y}\frac{\partial y}{\partial z_3}\frac{\partial z_3}{\partial h_1}\frac{\partial h_1}{\partial z_1}\frac{\partial z_1}{\partial w_1} \\ \frac{\partial L}{\partial w_3} & = \frac{\partial L}{\partial y} \frac{\partial y}{\partial z_3} \frac{\partial z_3}{\partial w_3} \end{align} </script>

两个式子,找找相同的,只有前两部分是一样的,所以可以令 δ1=Lyyz3 <script type="math/tex" id="MathJax-Element-1573">\delta^1 = \frac{\partial L}{\partial y} \frac{\partial y}{\partial z_3} </script>,这样的好处在于:

w3 <script type="math/tex" id="MathJax-Element-1574">w_3</script>时,可以有:

Lw3=δ1z3w3
<script type="math/tex; mode=display" id="MathJax-Element-1575"> \frac{\partial L}{\partial w_3} = \delta^1 \frac{\partial z_3}{\partial w_3} </script>

w5 <script type="math/tex" id="MathJax-Element-1576">w_5</script>时,可以有:

Lw1=δ1z3h1h1z1z1w1
<script type="math/tex; mode=display" id="MathJax-Element-1577"> \frac{\partial L}{\partial w_1} = \delta^1 \frac{\partial z_3}{\partial h_1}\frac{\partial h_1}{\partial z_1}\frac{\partial z_1}{\partial w_1} </script>

从图上来理解的话, δ1 <script type="math/tex" id="MathJax-Element-1578">\delta^1</script> 表示【聚集】在 z3 <script type="math/tex" id="MathJax-Element-1579">z_3</script>的误差,为啥到 z3 <script type="math/tex" id="MathJax-Element-1580">z_3</script>呢,因为在这里刚好可以求出 w3 <script type="math/tex" id="MathJax-Element-1581">w_3</script>的偏导,从公式上理解的话就是那公共部分(重复子问题)。

既然这么定义了,我们可以同样定义第二层的误差 δ21 <script type="math/tex" id="MathJax-Element-1582">\delta_1^2</script> 表示【聚集】在 z1 <script type="math/tex" id="MathJax-Element-1583">z_1</script>的误差。 δ22 <script type="math/tex" id="MathJax-Element-1584">\delta_2^2</script>表示【聚集】在 z2 <script type="math/tex" id="MathJax-Element-1585">z_2</script>的误差。所以有:

δ21=δ1z3h1h1z1=δ1w3h1z1
<script type="math/tex; mode=display" id="MathJax-Element-1586"> \begin{align} \delta_1^2 &= \delta^1 \frac{\partial z_3}{\partial h_1}\frac{\partial h_1}{\partial z_1} \\ & = \delta^1\cdot w_3 \cdot \frac{\partial h_1}{\partial z_1} \end{align} </script>

对应地 w1 <script type="math/tex" id="MathJax-Element-1587">w_1</script>的偏导公式就可以有 Lw1=δ21z1w1 <script type="math/tex" id="MathJax-Element-1588">\frac{\partial L}{\partial w_1} = \delta_1^2 \frac{\partial z_1}{\partial w_1} </script>

哈哈哈,对比一波 w1,w5,w3 <script type="math/tex" id="MathJax-Element-1589">w_1, w_5, w_3</script>,可以得到:

Lw1Lw5Lw3=δ21z1w1=δ21z1w5=δ1z3w3
<script type="math/tex; mode=display" id="MathJax-Element-1590"> \begin{align} \frac{\partial L}{\partial w_1} & = \delta_1^2 \frac{\partial z_1}{\partial w_1} \\ \frac{\partial L}{\partial w_5} & = \delta_1^2 \frac{\partial z_1}{\partial w_5} \\ \frac{\partial L}{\partial w_3} & = \delta^1 \frac{\partial z_3}{\partial w_3} \end{align} </script>

别迷糊了,它们都属于同一种形式,写算法就好些很多,而 δ2 <script type="math/tex" id="MathJax-Element-1591">\delta^2</script>是由 δ1 <script type="math/tex" id="MathJax-Element-1592">\delta^1</script>加上对应的 wi <script type="math/tex" id="MathJax-Element-1593">w_i</script>求得,形象了吧,所以我们首要的目标是求出最后一层的 δ1 <script type="math/tex" id="MathJax-Element-1594">\delta^1</script>,接着就能根据前一层的权值 wi <script type="math/tex" id="MathJax-Element-1595">w_i</script>求出前一层每个结点的 δ2 <script type="math/tex" id="MathJax-Element-1596">\delta^2</script>,更新公式都一样, δ2 <script type="math/tex" id="MathJax-Element-1597">\delta^2</script>乘以上一层的输出值而已,谁叫 y=h1w1+h2w2 <script type="math/tex" id="MathJax-Element-1598">y = h_1w_1 + h_2w_2</script>是线性的呢,求偏导 h1 <script type="math/tex" id="MathJax-Element-1599">h_1</script>得到 w1 <script type="math/tex" id="MathJax-Element-1600">w_1</script>,求偏导 w1 <script type="math/tex" id="MathJax-Element-1601">w_1</script>得 h1 <script type="math/tex" id="MathJax-Element-1602">h_1</script>,这实在太巧妙了。

这就对了吗?不,离真正的反向传播推导出的公式还差那么一点点,继续看图:
alt text

我们按照关系边的概念,可以知道 w5 <script type="math/tex" id="MathJax-Element-1603">w_5</script>的关系边应该由红色的边组成。所以 δ21 <script type="math/tex" id="MathJax-Element-1604">\delta_1^2</script>的更新不仅仅只跟 z3 <script type="math/tex" id="MathJax-Element-1605">z_3</script>有关系了,还和 z4 <script type="math/tex" id="MathJax-Element-1606">z_4</script>有关。为什么?此时损失函数由两部分组成,对应一个输入样例(x1, x2),有:

L=12(y1t1)2+12(y2t2)2
<script type="math/tex; mode=display" id="MathJax-Element-1607"> L = \frac{1}{2}(y_1 - t_1)^2 + \frac{1}{2}(y_2 - t_2)^2 </script>

所以对 L <script type="math/tex" id="MathJax-Element-1608">L</script>求偏导,由加法法则,可以得到Lw5=Ly1+Ly2<script type="math/tex" id="MathJax-Element-1609">\frac{\partial L}{\partial w_5} = \frac{\partial L}{\partial y_1} + \frac{\partial L}{\partial y_2}</script>,没错,多个结点指向同一个结点时,把它们的偏导值加起来即可(损失函数就这么定义)。所以 δ2j=hjzjwjiδ1i <script type="math/tex" id="MathJax-Element-1610">\delta_j^2 =\frac{\partial h_j}{\partial z_j} \sum w_{ji} \cdot \delta_i^1 </script>.

此时再看看完整的反向传播公式推导吧,或许就明白其中缘由了。参考链接:http://blog.csdn.net/u014313009/article/details/51039334

Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐