神经网络防"失忆"秘籍:弹性权重固化如何让AI学会"温故知新"
“就像学霸给重点笔记贴荧光标签,EWC给重要神经网络参数上锁”
一、核心公式对比表
公式名称 | 数学表达式 | 通俗解释 | 类比场景 | 文献 |
---|---|---|---|---|
EWC主公式 | L t o t a l = L n e w + λ 2 ∑ i F i ( θ i − θ o l d , i ) 2 L_{total} = L_{new} + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_{old,i})^2 Ltotal=Lnew+2λ∑iFi(θi−θold,i)2 | 给重要知识上锁 | 给重点笔记贴荧光标签 | |
贝叶斯推导式 | log p ( θ ∣ D ) ∝ log p ( D B ∣ θ ) + log p ( θ ∣ D A ) \log p(\theta|D) \propto \log p(D_B|\theta) + \log p(\theta|D_A) logp(θ∣D)∝logp(DB∣θ)+logp(θ∣DA) | 新旧知识平衡法则 | 考试前既复习新题也温习旧题 | |
费舍尔信息矩阵 | F i = E [ ∇ θ 2 log p ( y ∣ x , θ ) ] F_i = \mathbb{E}[\nabla_\theta^2 \log p(y|x,\theta)] Fi=E[∇θ2logp(y∣x,θ)] | 知识重要度评分卡 | 根据笔记划重点的频率标记 |
二、公式详解与类比解释
公式1:EWC核心保护机制
L t o t a l = L n e w ( θ ) ⏟ 新任务 + λ 2 ∑ i F i ( θ i − θ o l d , i ) 2 ⏟ 旧知识防护罩 L_{total} = \underbrace{L_{new}(\theta)}_{\text{新任务}} + \underbrace{\frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_{old,i})^2}_{\text{旧知识防护罩}} Ltotal=新任务 Lnew(θ)+旧知识防护罩 2λi∑Fi(θi−θold,i)2
参数 | 数学符号 | 类比解释 | 作用原理 |
---|---|---|---|
旧任务参数 | θ o l d \theta_{old} θold | 学霸的旧笔记本 | 知识基准锚点 |
重要性系数 | F i F_i Fi | 荧光标签密度 | 参数重要性量化 |
约束强度 | λ \lambda λ | 胶水粘性系数 | 平衡新旧知识权重 |
案例:在图像分类任务中,给识别"猫耳朵"的关键神经元增加3倍保护权重
公式2:费舍尔信息矩阵计算
F
i
=
1
N
∑
x
,
y
(
∂
log
p
(
y
∥
x
,
θ
)
∂
θ
i
)
2
F_i = \frac{1}{N} \sum_{x,y} \left( \frac{\partial \log p(y\|x,\theta)}{\partial \theta_i} \right)^2
Fi=N1x,y∑(∂θi∂logp(y∥x,θ))2
变量解读:
- x x x:输入数据(学生的练习题)
- y y y:标签答案(标准答案)
- log p ( y ∥ x , θ ) \log p(y\|x,\theta) logp(y∥x,θ):答案正确率评分
类比解释:
如同统计学生复习时翻看某页笔记的次数,翻看越频繁的页面(参数)获得越多荧光标签(高F值)
三、公式体系演进(关键推导步骤)
1. 贝叶斯推导路径
- 初始目标:
max θ log p ( θ ∥ D A , D B ) \max_\theta \log p(\theta\|D_A, D_B) θmaxlogp(θ∥DA,DB) - 任务分解:
∝ log p ( D B ∥ θ ) + log p ( θ ∥ D A ) \propto \log p(D_B\|\theta) + \log p(\theta\|D_A) ∝logp(DB∥θ)+logp(θ∥DA) - 拉普拉斯近似:
log p ( θ ∥ D A ) ≈ − 1 2 ∑ i F i ( θ i − θ o l d , i ) 2 \log p(\theta\|D_A) \approx -\frac{1}{2} \sum_i F_i (\theta_i - \theta_{old,i})^2 logp(θ∥DA)≈−21i∑Fi(θi−θold,i)2
2. 方法对比
方法 | 核心公式 | 优势 | 局限 |
---|---|---|---|
L2正则 | L = L n e w + λ ∣ θ − θ o l d ∣ 2 L = L_{new} + \lambda |\theta - \theta_{old}|^2 L=Lnew+λ∣θ−θold∣2 | 简单易实现 | 无差别保护所有参数 |
EWC | L = L n e w + λ 2 ∑ F i ( θ i − θ o l d , i ) 2 L = L_{new} + \frac{\lambda}{2} \sum F_i(\theta_i - \theta_{old,i})^2 L=Lnew+2λ∑Fi(θi−θold,i)2 | 智能参数保护 | 需计算二阶导数 |
LwF | L = L n e w + α D K L ( p o l d ∣ p n e w ) L = L_{new} + \alpha D_{KL}(p_{old}|p_{new}) L=Lnew+αDKL(pold∣pnew) | 保持输出分布稳定 | 依赖旧模型推理 |
四、代码实战:MNIST/FashionMNIST增量学习
import torch
import torch.nn as nn
class EWC_CNN(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(1,32,3),
nn.ReLU(),
nn.MaxPool2d(2))
self.fc = nn.Linear(32*13*13,10)
def forward(self, x):
x = self.conv(x)
return self.fc(x.view(x.size(0),-1))
# EWC核心实现
class EWC_Regularizer:
def __init__(self, model, dataloader, device):
self.model = model
self.params = {n:p.detach().clone() for n,p in model.named_parameters()} # 旧参数快照
self.fisher = {}
# 计算Fisher信息矩阵
for batch in dataloader:
inputs, labels = batch
outputs = model(inputs.to(device))
loss = nn.CrossEntropyLoss()(outputs, labels.to(device))
loss.backward()
for n,p in model.named_parameters():
if p.grad is not None:
self.fisher[n] = p.grad.data.pow(2).mean() # 梯度平方均值
def penalty(self, current_params):
loss = 0
for n,p in current_params.items():
loss += (self.fisher[n] * (p - self.params[n]).pow(2)).sum()
return loss
# 训练流程示例
device = torch.device('cuda')
model = EWC_CNN().to(device)
old_task_loader = ... # 旧任务数据加载器
new_task_loader = ... # 新任务数据加载器
# 第一阶段:训练旧任务
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(10):
for batch in old_task_loader:
# 常规训练流程...
# 第二阶段:计算EWC约束
ewc = EWC_Regularizer(model, old_task_loader, device)
# 第三阶段:增量学习新任务
for epoch in range(10):
for batch in new_task_loader:
inputs, labels = batch
outputs = model(inputs.to(device))
ce_loss = nn.CrossEntropyLoss()(outputs, labels.to(device))
ewc_loss = ewc.penalty(dict(model.named_parameters()))
total_loss = ce_loss + 1000 * ewc_loss # λ=1000
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
五、可视化解析
1. 参数空间分布
import matplotlib.pyplot as plt
import seaborn as sns
# 生成模拟数据
theta_old = np.random.normal(0,1,1000) # 旧任务参数分布
theta_new = np.random.normal(3,1,1000) # 新任务参数分布
theta_ewc = 0.3*theta_old + 0.7*theta_new # EWC约束参数
# 可视化
plt.figure(figsize=(12,6))
sns.kdeplot(theta_old, label="Old Task", color='grey', linewidth=3)
sns.kdeplot(theta_new, label="New Task", color='gold', linewidth=3)
sns.kdeplot(theta_ewc, label="EWC Compromise", color='red', linestyle='--')
plt.title("Parameter Space Distribution")
plt.xlabel("Parameter Value"), plt.ylabel("Density")
plt.legend()
plt.show()
六、技术演进路线
阶段 | 代表方法 | 关键突破 | 局限 |
---|---|---|---|
1.0 | 参数冻结 | 物理隔离旧知识 | 丧失模型扩展能力 |
2.0 | L2正则 | 简单约束参数漂移 | 无差别保护所有参数 |
3.0 | EWC | 智能参数重要性加权 | 计算二阶导数开销大 |
4.0 | 动态网络 | 独立适配器模块 | 模型体积膨胀 |