循环神经网络(RNN)的损失函数与传播过程详解

1. 损失函数

1.1 任务类型

RNN的损失函数取决于具体任务:

  • 分类任务:交叉熵损失
  • 回归任务:均方误差(MSE)损失

1.2 数学表示

对于分类任务,损失函数通常为交叉熵损失:
L=−∑t=1T∑i=1Cyt,ilog⁡(y^t,i) L = -\sum_{t=1}^T \sum_{i=1}^C y_{t,i} \log(\hat{y}_{t,i}) L=t=1Ti=1Cyt,ilog(y^t,i)
其中:

  • TTT为序列长度
  • CCC为类别数
  • yt,iy_{t,i}yt,i为真实标签的one-hot表示
  • y^t,i\hat{y}_{t,i}y^t,i为模型预测的概率分布

2. 前向传播

2.1 基本步骤

  1. 初始化隐藏状态h0h_0h0(通常为零向量)。
  2. 对每个时间步t=1t=1t=1TTT
    • 计算隐藏状态:
      ht=σ(Whhht−1+Wxhxt+bh) h_t = \sigma(W_{hh} h_{t-1} + W_{xh} x_t + b_h) ht=σ(Whhht1+Wxhxt+bh)
    • 计算输出:
      ot=Whyht+by o_t = W_{hy} h_t + b_y ot=Whyht+by
    • 计算预测值:
      y^t=softmax(ot) \hat{y}_t = \text{softmax}(o_t) y^t=softmax(ot)

2.2 数学表示

  • 隐藏状态更新
    ht=tanh⁡(Whhht−1+Wxhxt+bh) h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h) ht=tanh(Whhht1+Wxhxt+bh)
  • 输出计算
    ot=Whyht+by o_t = W_{hy} h_t + b_y ot=Whyht+by
  • 预测值
    y^t=softmax(ot) \hat{y}_t = \text{softmax}(o_t) y^t=softmax(ot)

3. 反向传播

3.1 基本步骤

  1. 计算损失函数LLL对输出oto_tot的梯度:
    ∂L∂ot=y^t−yt \frac{\partial L}{\partial o_t} = \hat{y}_t - y_t otL=y^tyt
  2. 计算损失函数LLL对隐藏状态hth_tht的梯度:
    ∂L∂ht=∂L∂ot∂ot∂ht+∂L∂ht+1∂ht+1∂ht \frac{\partial L}{\partial h_t} = \frac{\partial L}{\partial o_t} \frac{\partial o_t}{\partial h_t} + \frac{\partial L}{\partial h_{t+1}} \frac{\partial h_{t+1}}{\partial h_t} htL=otLhtot+ht+1Lhtht+1
  3. 计算损失函数LLL对参数Whh,Wxh,Why,bh,byW_{hh}, W_{xh}, W_{hy}, b_h, b_yWhh,Wxh,Why,bh,by的梯度。

3.2 数学表示

  • 输出层梯度
    权重矩阵梯度
    ∂L∂Why=∑t=1T∂L∂otht⊤ \frac{\partial L}{\partial W_{hy}} = \sum_{t=1}^T \frac{\partial L}{\partial o_t} h_t^\top WhyL=t=1TotLht
    偏置梯度
    ∂L∂by=∑t=1T∂L∂ot\frac{\partial L}{\partial b_y} = \sum_{t=1}^{T} \frac{\partial L}{\partial o_t}byL=t=1TotL
  • 隐藏层梯度
    ∂L∂Whh=∑t=1T∂L∂ht∂ht∂Whh \frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^T \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial W_{hh}} WhhL=t=1ThtLWhhht
    ∂L∂Wxh=∑t=1T∂L∂ht∂ht∂Wxh \frac{\partial L}{\partial W_{xh}} = \sum_{t=1}^T \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial W_{xh}} WxhL=t=1ThtLWxhht
    ∂L∂bh=∑t=1T∂L∂ht∂ht∂bh\frac{\partial L}{\partial b_h} = \sum_{t=1}^T \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial b_h}bhL=t=1ThtLbhht

4. 反向传播链式法则示例

4.1 问题描述

考虑一个简单的RNN,序列长度T=2T=2T=2,隐藏状态维度dh=2d_h=2dh=2,输入维度dx=1d_x=1dx=1,输出维度dy=1d_y=1dy=1

4.2 前向传播

  1. 时间步t=1t=1t=1
    h1=tanh⁡(Whhh0+Wxhx1+bh) h_1 = \tanh(W_{hh} h_0 + W_{xh} x_1 + b_h) h1=tanh(Whhh0+Wxhx1+bh)
    o1=Whyh1+by o_1 = W_{hy} h_1 + b_y o1=Whyh1+by
    y^1=softmax(o1) \hat{y}_1 = \text{softmax}(o_1) y^1=softmax(o1)
  2. 时间步t=2t=2t=2
    h2=tanh⁡(Whhh1+Wxhx2+bh) h_2 = \tanh(W_{hh} h_1 + W_{xh} x_2 + b_h) h2=tanh(Whhh1+Wxhx2+bh)
    o2=Whyh2+by o_2 = W_{hy} h_2 + b_y o2=Whyh2+by
    y^2=softmax(o2) \hat{y}_2 = \text{softmax}(o_2) y^2=softmax(o2)

4.3 反向传播

  1. 计算损失函数LLLo2o_2o2的梯度
    ∂L∂o2=y^2−y2 \frac{\partial L}{\partial o_2} = \hat{y}_2 - y_2 o2L=y^2y2
  2. 计算损失函数LLLh2h_2h2的梯度
    ∂L∂h2=∂L∂o2∂o2∂h2 \frac{\partial L}{\partial h_2} = \frac{\partial L}{\partial o_2} \frac{\partial o_2}{\partial h_2} h2L=o2Lh2o2
  3. 计算损失函数LLLh1h_1h1的梯度
    ∂L∂h1=∂L∂h2∂h2∂h1+∂L∂o1∂o1∂h1 \frac{\partial L}{\partial h_1} = \frac{\partial L}{\partial h_2} \frac{\partial h_2}{\partial h_1} + \frac{\partial L}{\partial o_1} \frac{\partial o_1}{\partial h_1} h1L=h2Lh1h2+o1Lh1o1
  4. 计算损失函数LLL对参数Whh,Wxh,Why,bh,byW_{hh}, W_{xh}, W_{hy}, b_h, b_yWhh,Wxh,Why,bh,by的梯度

5. 数学附录

5.1 梯度计算

对于隐藏状态hth_tht的梯度:
∂L∂ht=∂L∂ot∂ot∂ht+∂L∂ht+1∂ht+1∂ht \frac{\partial L}{\partial h_t} = \frac{\partial L}{\partial o_t} \frac{\partial o_t}{\partial h_t} + \frac{\partial L}{\partial h_{t+1}} \frac{\partial h_{t+1}}{\partial h_t} htL=otLhtot+ht+1Lhtht+1
其中:
∂ht+1∂ht=Whh⊤diag(1−ht2) \frac{\partial h_{t+1}}{\partial h_t} = W_{hh}^\top \text{diag}(1 - h_t^2) htht+1=Whhdiag(1ht2)

5.2 参数更新

使用梯度下降法更新参数:
W=W−η∂L∂W W = W - \eta \frac{\partial L}{\partial W} W=WηWL
其中η\etaη为学习率。


6. 总结

  • RNN通过前向传播计算隐藏状态和输出,通过反向传播计算梯度并更新参数。
  • 链式法则是反向传播的核心,用于计算损失函数对参数的梯度。
  • 梯度消失和梯度爆炸是RNN训练中的常见问题,可通过LSTM、GRU等改进模型缓解。
Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐