ResNet 经典网络详解:深度学习中的革命性突破(含代码)
ResNet:解决深度网络训练难题的核心技术
在深度学习的发展历程中,ResNet(残差网络)是一个重要的突破。它成功解决了深层神经网络训练中的梯度消失问题,使得深度模型能够达到更好的性能。本文将以 ResNet-18 为例,详细解析其网络结构、关键技术,并探讨其实际应用。
1. ResNet-18 的结构
ResNet-18 的整体结构由多个阶段组成,每个阶段包含若干个 残差块(ResBlock)。每个残差块内部有两次卷积操作,其中包括 3×3 卷积 和 批量归一化(BatchNorm),并在输出端通过 ReLU 激活函数 进行非线性变换。
下面ResNet-18 结构图
具体来说,ResNet-18 由以下 6 个主要阶段组成:
- Stage 1:7×7 卷积层 + 最大池化,初步提取特征。
- Stage 2:两个 实线 ResBlock,输入输出尺寸不变。
- Stage 3:一个 虚线 ResBlock(尺寸减半) + 一个 实线 ResBlock。
- Stage 4:一个 虚线 ResBlock(尺寸减半) + 一个 实线 ResBlock。
- Stage 5:一个 虚线 ResBlock(尺寸减半) + 一个 实线 ResBlock。
- Stage 6:全局平均池化(Global Average Pooling)+ 全连接层,用于分类任务。
2. 为什么需要 ResNet?
在神经网络的训练过程中,主要面临两个挑战:
- 梯度消失问题(Vanishing Gradient)
- 在传统的深层网络中,随着网络层数增加,梯度在反向传播时会逐渐减小,导致靠近输入层的参数更新变慢,影响模型的训练效果。
- 退化问题(Degradation Problem)
- 直觉上,我们希望增加网络深度可以提高模型性能。然而实验发现,当层数足够深时,训练误差反而增大,说明模型学习能力受限。
🌟 解决方案:残差连接(Skip Connection)
ResNet 提出了 残差学习(Residual Learning),其核心思想是 直接让信息跨层传输,避免梯度消失,并帮助网络学习更深层次的特征。
残差块(ResBlock)的数学表达式如下:
3. 关键技术解析:残差块(ResBlock)
ResNet 的核心在于 残差块(ResBlock),它由两个卷积层组成,并且引入了 Shortcut Connection(跳跃连接)。
🔹 残差块的 PyTorch 实现
import torch
from torch import nn
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
# 只有 stride=2 时,调整 shortcut 的尺寸
self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) if stride == 2 else nn.Identity()
def forward(self, x):
shortcut = self.shortcut(x)
x = torch.relu(self.bn1(self.conv1(x)))
x = self.bn2(self.conv2(x))
return torch.relu(x + shortcut) # 残差连接
- 实线残差块(stride=1):保持输入和输出尺寸不变。
- 虚线残差块(stride=2):在 shortcut 处使用 1×1 卷积来降低特征图尺寸,使得输入和输出可以相加。
4. ResNet-18 的完整 PyTorch 实现
class ResNet18(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.stage1 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
self.stage2 = nn.Sequential(ResBlock(64, 64, stride=1), ResBlock(64, 64, stride=1))
self.stage3 = nn.Sequential(ResBlock(64, 128, stride=2), ResBlock(128, 128, stride=1))
self.stage4 = nn.Sequential(ResBlock(128, 256, stride=2), ResBlock(256, 256, stride=1))
self.stage5 = nn.Sequential(ResBlock(256, 512, stride=2), ResBlock(512, 512, stride=1))
self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512, num_classes)
def forward(self, x):
x = self.stage1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.stage5(x)
x = self.global_pool(x)
x = torch.flatten(x, 1)
return self.fc(x)
# 测试 ResNet-18
model = ResNet18()
X = torch.randn(1, 3, 224, 224)
print(model(X).shape) # 输出 (1, 10)
5. ResNet 的应用场景
ResNet 不仅是学术研究中的一个里程碑,同时在 计算机视觉 领域有着广泛的应用,尤其适用于以下场景:
🔹 图像分类
ResNet 是 ImageNet 竞赛中的标杆模型,被广泛用于 人脸识别、医学影像分类 等任务。例如,在 CT/MRI 诊断 中,ResNet 被用于 肿瘤检测、肺炎识别,帮助医生提高诊断精度。
🔹 目标检测
在 Faster R-CNN、YOLO、Mask R-CNN 等目标检测算法中,ResNet 经常作为 骨干网络(Backbone),用于自动驾驶、安防监控等。例如,特斯拉的 自动驾驶系统 可能会使用类似的 CNN 结构进行目标检测。
🔹 语义分割
结合 FCN、DeepLab、U-Net,ResNet 在 遥感影像、医学影像 领域被用于精准的目标分割。例如,城市规划 中,可以利用 ResNet 进行 卫星图像分析,识别建筑、道路等结构。
🔹 超分辨率 & 图像风格迁移
ResNet 也被应用于 超分辨率重建(Super-Resolution) 和 图像风格迁移,如老照片修复、高清画质增强等。
6. 结论
ResNet-18 是深度学习领域中最经典的网络之一,尤其在图像分类任务上表现优异。它通过 残差连接 解决了梯度消失问题,使得更深的神经网络能够被成功训练。
如果你对 ResNet 还有更深入的兴趣,可以尝试 ResNet-50、ResNet-101 这些更深层次的变体,或者结合 Transformer、注意力机制等技术进行优化!