图像分割基础
大约 19 分钟约 5619 字
图像分割基础
简介
图像分割的目标,是为图像中的每个像素分配类别或实例归属,因此它比分类和目标检测更细粒度。分割常用于医学影像、工业缺陷检测、自动驾驶路面理解、遥感地物识别和抠图场景,对边界质量、类别不平衡和标注一致性都非常敏感。
图像分割的发展历程可以概括为以下几个阶段:早期方法(阈值分割、区域生长、分水岭算法)依赖手工特征和启发式规则,泛化能力有限;深度学习时代的开创性工作 FCN(2015)首次将分类网络改造为全卷积网络,实现端到端的像素级预测;U-Net(2015)引入编码器-解码器结构和跳跃连接,成为医学图像分割的标杆;DeepLab 系列(2015-2018)通过空洞空间金字塔池化(ASPP)解决多尺度问题;Mask R-CNN(2017)将实例分割与目标检测统一在一个框架中;近年来,Transformer 架构(SETR、SegFormer、Mask2Former)进一步推动了分割的性能边界。
从数学角度看,图像分割是一个密集预测问题:给定输入图像 I ∈ R^{H×W×3},输出是同尺寸的标注图 Y ∈ R^{H×W}(语义分割)或一组实例掩码 {M_1, M_2, ..., M_N}(实例分割)。这与分类的单一输出和检测的稀疏输出有本质区别,对计算资源的要求也更高。
特点
三种分割任务的对比
| 维度 | 语义分割 | 实例分割 | 全景分割 |
|---|---|---|---|
| 输出 | 每个像素的类别标签 | 每个实例的像素掩码 | 语义 + 实例的统一输出 |
| 同类区分 | 不区分同类实例 | 区分每个实例 | "stuff"类用语义,"thing"类用实例 |
| 典型模型 | U-Net, DeepLab, SegFormer | Mask R-CNN, SOLOv2 | Mask2Former, Panoptic-DeepLab |
| 标注成本 | 中等 | 高 | 最高 |
| 应用场景 | 自动驾驶路面理解 | 细胞计数、行人分割 | 机器人导航、AR |
实现
语义分割:U-Net 基础示例
import torch
import torch.nn as nn
class DoubleConv(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.block(x)
class SimpleUNet(nn.Module):
def __init__(self, in_ch=3, out_ch=2):
super().__init__()
self.enc1 = DoubleConv(in_ch, 32)
self.pool = nn.MaxPool2d(2)
self.enc2 = DoubleConv(32, 64)
self.up = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
self.dec1 = DoubleConv(64, 32)
self.head = nn.Conv2d(32, out_ch, kernel_size=1)
def forward(self, x):
x1 = self.enc1(x)
x2 = self.enc2(self.pool(x1))
x3 = self.up(x2)
x4 = torch.cat([x3, x1], dim=1)
x5 = self.dec1(x4)
return self.head(x5)
model = SimpleUNet()
image = torch.randn(2, 3, 256, 256)
logits = model(image)
print(logits.shape) # [B, C, H, W]# 语义分割标签通常是 [H, W] 的类别图
mask = torch.randint(0, 2, (2, 256, 256))
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, mask)
print(float(loss))# 预测类别图
pred_mask = torch.argmax(logits, dim=1)
print(pred_mask.shape)完整 U-Net 实现
import torch
import torch.nn as nn
class UNet(nn.Module):
"""完整 U-Net 实现
U-Net 的核心设计思想:
1. 编码器(下采样路径):逐步提取高语义特征,同时降低空间分辨率
2. 解码器(上采样路径):逐步恢复空间分辨率
3. 跳跃连接(Skip Connections):将编码器的浅层特征与解码器拼接
- 浅层特征包含丰富的空间细节(边缘、纹理)
- 深层特征包含高级语义信息(类别)
- 拼接两者可以同时获得好的语义和好的边界
U-Net 最初是为医学图像分割设计的,其对称结构和小样本表现优秀。
在实际应用中,编码器通常用预训练的分类网络(如 ResNet、EfficientNet)。
"""
def __init__(self, in_channels=3, num_classes=2, base_channels=64):
super().__init__()
self.in_channels = in_channels
self.num_classes = num_classes
# 编码器(下采样路径)
self.enc1 = self._double_conv(in_channels, base_channels)
self.enc2 = self._double_conv(base_channels, base_channels * 2)
self.enc3 = self._double_conv(base_channels * 2, base_channels * 4)
self.enc4 = self._double_conv(base_channels * 4, base_channels * 8)
self.pool = nn.MaxPool2d(2)
# 瓶颈层
self.bottleneck = self._double_conv(base_channels * 8, base_channels * 16)
# 解码器(上采样路径)
self.upconv4 = nn.ConvTranspose2d(base_channels * 16, base_channels * 8, kernel_size=2, stride=2)
self.dec4 = self._double_conv(base_channels * 16, base_channels * 8)
self.upconv3 = nn.ConvTranspose2d(base_channels * 8, base_channels * 4, kernel_size=2, stride=2)
self.dec3 = self._double_conv(base_channels * 8, base_channels * 4)
self.upconv2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
self.dec2 = self._double_conv(base_channels * 4, base_channels * 2)
self.upconv1 = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2)
self.dec1 = self._double_conv(base_channels * 2, base_channels)
# 分类头
self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1)
def _double_conv(self, in_ch, out_ch):
return nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
)
def forward(self, x):
# 编码器
e1 = self.enc1(x) # (B, 64, H, W)
e2 = self.enc2(self.pool(e1)) # (B, 128, H/2, W/2)
e3 = self.enc3(self.pool(e2)) # (B, 256, H/4, W/4)
e4 = self.enc4(self.pool(e3)) # (B, 512, H/8, W/8)
# 瓶颈
b = self.bottleneck(self.pool(e4)) # (B, 1024, H/16, W/16)
# 解码器 + 跳跃连接
d4 = self.upconv4(b)
d4 = torch.cat([d4, e4], dim=1)
d4 = self.dec4(d4) # (B, 512, H/8, W/8)
d3 = self.upconv3(d4)
d3 = torch.cat([d3, e3], dim=1)
d3 = self.dec3(d3) # (B, 256, H/4, W/4)
d2 = self.upconv2(d3)
d2 = torch.cat([d2, e2], dim=1)
d2 = self.dec2(d2) # (B, 128, H/2, W/2)
d1 = self.upconv1(d2)
d1 = torch.cat([d1, e1], dim=1)
d1 = self.dec1(d1) # (B, 64, H, W)
return self.head(d1) # (B, num_classes, H, W)
model = UNet(in_channels=3, num_classes=21, base_channels=32)
x = torch.randn(2, 3, 256, 256)
out = model(x)
print(f"U-Net 输出: {out.shape}")
print(f"参数量: {sum(p.numel() for p in model.parameters()):,}")DeepLab V3+ 的核心思想
import torch
import torch.nn as nn
class ASPP(nn.Module):
"""空洞空间金字塔池化(Atrous Spatial Pyramid Pooling)
ASPP 是 DeepLab 系列的核心模块,解决多尺度分割问题。
核心思想:在不同膨胀率下并行应用空洞卷积,捕获不同感受野的信息。
- 1x1 卷积:感受野 1(点级特征)
- 3x3 空洞卷积 rate=6:感受野 13
- 3x3 空洞卷积 rate=12:感受野 25
- 3x3 空洞卷积 rate=18:感受野 37
- 全局平均池化:感受野 = 整张图
为什么不用更大的卷积核或池化?
- 大卷积核参数量大、计算慢
- 池化会丢失空间信息
- 空洞卷积在不增加参数的情况下扩大感受野
"""
def __init__(self, in_channels, out_channels=256, rates=(6, 12, 18)):
super().__init__()
modules = []
# 1x1 卷积
modules.append(nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
))
# 不同膨胀率的空洞卷积
for rate in rates:
modules.append(nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=rate,
dilation=rate, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
))
# 全局平均池化
modules.append(nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
))
self.convs = nn.ModuleList(modules)
self.project = nn.Sequential(
nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Dropout(0.5)
)
def forward(self, x):
res = []
for conv in self.convs:
res.append(conv(x))
# 将全局池化结果上采样回原始尺寸
res[-1] = nn.functional.interpolate(res[-1], size=x.shape[2:], mode='bilinear', align_corners=False)
return self.project(torch.cat(res, dim=1))
aspp = ASPP(2048, 256)
x = torch.randn(1, 2048, 32, 32)
out = aspp(x)
print(f"ASPP 输入: {x.shape} -> 输出: {out.shape}")Dice Loss、IoU 与边界评估
import torch.nn.functional as F
def dice_loss(logits, targets, smooth=1.0):
probs = torch.softmax(logits, dim=1)[:, 1] # 取前景类概率
targets = targets.float()
intersection = (probs * targets).sum()
union = probs.sum() + targets.sum()
dice = (2 * intersection + smooth) / (union + smooth)
return 1 - dice
loss = dice_loss(logits, mask)
print(float(loss))def iou_score(pred, target):
pred = pred.bool()
target = target.bool()
intersection = (pred & target).sum().item()
union = (pred | target).sum().item()
return intersection / union if union > 0 else 1.0
score = iou_score(pred_mask[0] == 1, mask[0] == 1)
print("iou:", score)全面的损失函数集合
import torch
import torch.nn as nn
import torch.nn.functional as F
class SegmentationLosses:
"""分割任务的常用损失函数
不同损失函数适合不同场景:
1. CrossEntropyLoss: 多分类基线,类别均衡时首选
2. BCEWithLogitsLoss: 二分类分割,配合 sigmoid
3. Dice Loss: 前景面积小时效果显著(医学分割常用)
4. Focal Loss: 严重类别不平衡(背景 >> 前景)
5. Tversky Loss: Dice 的泛化版本,可调节 FP/FN 的权重
6. Boundary Loss: 专门优化边界质量
7. Lovasz Loss: 直接优化 IoU 指标
"""
@staticmethod
def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
"""Focal Loss: 关注困难样本
FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
gamma > 0 时,对容易分类的样本(p_t 接近 1)降低权重
当 gamma=2 时,p_t=0.9 的样本权重只有 p_t=0.5 的 1/81
alpha 用于平衡正负样本的比例
"""
ce_loss = F.cross_entropy(logits, targets, reduction='none')
p_t = torch.exp(-ce_loss)
focal_weight = alpha * (1 - p_t) ** gamma
return (focal_weight * ce_loss).mean()
@staticmethod
def tversky_loss(logits, targets, alpha=0.7, beta=0.3, smooth=1.0):
"""Tversky Loss: Dice 的泛化版本
Tversky Index = TP / (TP + alpha*FP + beta*FN)
alpha > beta: 更关注减少假阴性(FN)-> 适合医学分割
beta > alpha: 更关注减少假阳性(FP)-> 适合检测场景
alpha = beta = 0.5: 等价于 Dice Loss
"""
num_classes = logits.shape[1]
targets_one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
probs = torch.softmax(logits, dim=1)
tp = (probs * targets_one_hot).sum(dim=(2, 3))
fp = (probs * (1 - targets_one_hot)).sum(dim=(2, 3))
fn = ((1 - probs) * targets_one_hot).sum(dim=(2, 3))
tversky = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
return 1 - tversky.mean()
@staticmethod
def boundary_loss(logits, targets):
"""边界损失:专门优化边界质量
思路:提取标签和预测的边界,计算边界上的差异。
对医学分割特别重要,因为边界精度直接影响临床决策。
"""
# 使用 Sobel 算子提取边界
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
dtype=torch.float32, device=logits.device).view(1, 1, 3, 3)
sobel_y = sobel_x.transpose(-1, -2)
# 提取标签边界
targets_float = targets.float().unsqueeze(1)
boundary_x = F.conv2d(targets_float, sobel_x, padding=1)
boundary_y = F.conv2d(targets_float, sobel_y, padding=1)
boundary_gt = torch.sqrt(boundary_x**2 + boundary_y**2)
boundary_gt = (boundary_gt > 0).float()
# 提取预测边界
pred = torch.argmax(logits, dim=1).float().unsqueeze(1)
boundary_x_pred = F.conv2d(pred, sobel_x, padding=1)
boundary_y_pred = F.conv2d(pred, sobel_y, padding=1)
boundary_pred = torch.sqrt(boundary_x_pred**2 + boundary_y_pred**2)
boundary_pred = (boundary_pred > 0).float()
# 边界上的差异
boundary_diff = F.binary_cross_entropy(boundary_pred, boundary_gt, reduction='mean')
return boundary_diff
@staticmethod
def combined_loss(logits, targets, ce_weight=1.0, dice_weight=1.0):
"""组合损失:CE + Dice 的组合是分割任务的最常用方案
CE Loss 提供稳定的梯度信号,保证收敛
Dice Loss 关注区域重叠度,缓解类别不平衡
两者互补:
- 当 Dice 梯度小时(预测为空或全对),CE 仍然提供梯度
- 当 CE 梯度小时(大面积正确但边界不精确),Dice 提供更细的梯度
"""
ce = F.cross_entropy(logits, targets)
# 计算 Dice
num_classes = logits.shape[1]
targets_one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
probs = torch.softmax(logits, dim=1)
dice = 1 - (2 * (probs * targets_one_hot).sum() + 1) / (probs.sum() + targets_one_hot.sum() + 1)
return ce_weight * ce + dice_weight * dice
# 使用示例
losses = SegmentationLosses()
logits = torch.randn(2, 5, 64, 64)
targets = torch.randint(0, 5, (2, 64, 64))
print(f"Focal Loss: {losses.focal_loss(logits, targets):.4f}")
print(f"Tversky Loss: {losses.tversky_loss(logits, targets):.4f}")
print(f"Combined Loss: {losses.combined_loss(logits, targets):.4f}")完整的评估指标
import torch
import numpy as np
class SegmentationMetrics:
"""分割任务的全面评估指标
1. Pixel Accuracy (PA): 正确像素数 / 总像素数
- 问题:背景像素多时会虚高,不推荐单独使用
2. Mean IoU (mIoU): 各类别 IoU 的平均值
- 最常用的分割指标
- 对类别不平衡敏感
3. Dice Coefficient (F1): 2*TP / (2*TP + FP + FN)
- 医学分割的标准指标
- 与 IoU 单调相关:Dice = 2*IoU / (1+IoU)
4. Frequency Weighted IoU (FWIoU): 按类别频率加权的 IoU
- 考虑了类别频率
5. Boundary IoU: 只在边界区域计算 IoU
- 评估边界质量
"""
@staticmethod
def compute_miou(pred, target, num_classes):
"""计算 mIoU(逐类别 IoU 的平均值)"""
miou_per_class = []
for cls in range(num_classes):
pred_cls = (pred == cls)
target_cls = (target == cls)
intersection = (pred_cls & target_cls).sum().item()
union = (pred_cls | target_cls).sum().item()
iou = intersection / union if union > 0 else float('nan')
miou_per_class.append(iou)
return np.nanmean(miou_per_class), miou_per_class
@staticmethod
def compute_dice(pred, target, num_classes):
"""计算各类别 Dice 系数"""
dice_per_class = []
for cls in range(num_classes):
pred_cls = (pred == cls)
target_cls = (target == cls)
tp = (pred_cls & target_cls).sum().item()
fp_fn = (pred_cls ^ target_cls).sum().item()
dice = 2 * tp / (2 * tp + fp_fn) if (2 * tp + fp_fn) > 0 else float('nan')
dice_per_class.append(dice)
return np.nanmean(dice_per_class), dice_per_class
@staticmethod
def compute_boundary_iou(pred, target, dilation_radius=2):
"""在边界区域计算 IoU,评估边界质量"""
from scipy.ndimage import binary_dilation
pred_np = pred.cpu().numpy()
target_np = target.cpu().numpy()
# 提取边界
boundary_target = binary_dilation(target_np, iterations=dilation_radius) ^ target_np
boundary_pred = binary_dilation(pred_np, iterations=dilation_radius) ^ pred_np
intersection = (boundary_target & boundary_pred).sum()
union = (boundary_target | boundary_pred).sum()
return intersection / union if union > 0 else 0.0
@staticmethod
def confusion_matrix(pred, target, num_classes):
"""计算分割混淆矩阵"""
mask = (target >= 0) & (target < num_classes)
hist = np.bincount(
num_classes * target[mask].cpu().numpy().astype(int) + pred[mask].cpu().numpy().astype(int),
minlength=num_classes ** 2
).reshape(num_classes, num_classes)
return hist
# 使用示例
pred = torch.randint(0, 5, (256, 256))
target = torch.randint(0, 5, (256, 256))
metrics = SegmentationMetrics()
miou, per_class_iou = metrics.compute_miou(pred, target, num_classes=5)
dice, per_class_dice = metrics.compute_dice(pred, target, num_classes=5)
print(f"mIoU: {miou:.4f}")
print(f"各类别 IoU: {[f'{x:.4f}' for x in per_class_iou]}")
print(f"Mean Dice: {dice:.4f}")常见损失函数:
- CrossEntropyLoss:多分类基础选择
- Dice Loss:适合前景面积很小的场景
- Focal Loss:适合类别不平衡
- BCE + Dice:二分类分割常见组合实例分割与 Mask R-CNN 思路
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from PIL import Image
mask_model = maskrcnn_resnet50_fpn(weights="DEFAULT")
mask_model.eval()
img = Image.new("RGB", (512, 512), color="white")
img_tensor = F.to_tensor(img)
with torch.no_grad():
outputs = mask_model([img_tensor])[0]
print(outputs.keys())
print(outputs["boxes"][:2])
print(outputs["masks"][:2].shape)# 二值化实例 mask
if len(outputs["masks"]) > 0:
binary_mask = (outputs["masks"][0, 0] > 0.5).int()
print(binary_mask.shape)分割任务分类:
- 语义分割:每个像素只关心类别,不区分同类实例
- 实例分割:同类目标也要分开
- 全景分割:同时兼顾语义分割与实例分割数据增强策略
import torch
import numpy as np
class SegmentationAugmentation:
"""分割任务的数据增强
注意:增强操作必须同时作用于图像和掩码!
几何变换(翻转、旋转、缩放)对图像和掩码完全一致
颜色变换只作用于图像
常用增强:
1. 随机翻转(水平/垂直)
2. 随机旋转(±15°)
3. 随机缩放和裁剪
4. 弹性变形(医学分割常用)
5. 颜色抖动(亮度、对比度、饱和度)
6. MixUp / CutMix(需要特殊的标签处理)
"""
@staticmethod
def random_flip(image, mask, p=0.5):
"""随机水平翻转"""
if np.random.random() < p:
image = image.flip(-1)
mask = mask.flip(-1)
return image, mask
@staticmethod
def random_rotate(image, mask, max_degree=15):
"""随机旋转"""
import torchvision.transforms.functional as TF
angle = np.random.uniform(-max_degree, max_degree)
image = TF.rotate(image, angle)
mask = TF.rotate(mask.float(), angle).round().long()
return image, mask
@staticmethod
def random_scale_crop(image, mask, scale_range=(0.8, 1.2), crop_size=256):
"""随机缩放后裁剪到固定尺寸"""
import torchvision.transforms.functional as TF
scale = np.random.uniform(*scale_range)
h, w = image.shape[-2:]
new_h, new_w = int(h * scale), int(w * scale)
image = TF.resize(image, [new_h, new_w])
mask = TF.resize(mask.float().unsqueeze(1), [new_h, new_w]).squeeze(1).long()
# 随机裁剪
i = np.random.randint(0, max(0, new_h - crop_size) + 1)
j = np.random.randint(0, max(0, new_w - crop_size) + 1)
image = image[:, i:i+crop_size, j:j+crop_size]
mask = mask[i:i+crop_size, j:j+crop_size]
return image, mask
# 使用示例
image = torch.randn(3, 256, 256)
mask = torch.randint(0, 5, (256, 256))
aug = SegmentationAugmentation()
image_flipped, mask_flipped = aug.random_flip(image, mask)
print(f"翻转后: image={image_flipped.shape}, mask={mask_flipped.shape}")类别不平衡处理
import torch
import torch.nn as nn
import torch.nn.functional as F
class ClassBalancedLoss(nn.Module):
"""处理分割中的类别不平衡
类别不平衡在分割任务中非常普遍:
- 背景像素通常占 80-90%
- 小目标(如血管、细胞)可能只占不到 1%
处理方法:
1. 加权交叉熵:给少数类更大的权重
2. Focal Loss:降低容易分类样本的权重
3. OHEM(在线困难样本挖掘):只对困难样本计算损失
4. 过采样/欠采样:在 batch 层面平衡
"""
@staticmethod
def weighted_ce(logits, targets, class_weights=None):
"""加权交叉熵"""
if class_weights is None:
# 使用逆频率作为权重
num_classes = logits.shape[1]
class_counts = torch.bincount(targets.flatten(), minlength=num_classes).float()
total = class_counts.sum()
class_weights = total / (num_classes * class_counts + 1e-6)
print(f"类别权重: {class_weights}")
weight_tensor = torch.tensor(class_weights, dtype=logits.dtype, device=logits.device)
return F.cross_entropy(logits, targets, weight=weight_tensor)
@staticmethod
def ohem_loss(logits, targets, ohem_ratio=0.3):
"""在线困难样本挖掘(OHEM)
只对损失值最大的前 ohem_ratio 比例的像素计算损失。
这样可以让模型更关注难学的前景像素。
"""
ce_loss = F.cross_entropy(logits, targets, reduction='none')
flat_loss = ce_loss.flatten()
num_keep = int(len(flat_loss) * ohem_ratio)
topk_loss, _ = flat_loss.topk(num_keep)
return topk_loss.mean()
# 使用示例
logits = torch.randn(2, 5, 64, 64)
targets = torch.randint(0, 5, (2, 64, 64))
# 模拟不平衡:让类别 0 占大多数
targets[targets > 2] = 0
print(f"加权 CE Loss: {ClassBalancedLoss.weighted_ce(logits, targets):.4f}")
print(f"OHEM Loss: {ClassBalancedLoss.ohem_loss(logits, targets):.4f}")推理优化与后处理
import torch
import torch.nn.functional as F
class SegmentationInference:
"""分割推理的实用技巧
1. 滑动窗口推理:处理超高分辨率图像
2. 测试时增强(TTA):翻转、多尺度推理取平均
3. CRF 后处理:优化分割边界的连贯性
"""
@staticmethod
def sliding_window_inference(model, image, crop_size=512, overlap=64):
"""滑动窗口推理
当输入图像分辨率超过 GPU 显存时,使用滑动窗口逐步推理。
重叠区域取平均以消除拼接痕迹。
Args:
model: 分割模型
image: (1, C, H, W) 输入图像
crop_size: 每次推理的裁剪尺寸
overlap: 重叠区域大小
"""
_, _, h, w = image.shape
output = torch.zeros(1, model.num_classes if hasattr(model, 'num_classes') else 2, h, w)
count = torch.zeros(1, 1, h, w)
stride = crop_size - overlap
for y in range(0, h - crop_size + 1, stride):
for x in range(0, w - crop_size + 1, stride):
crop = image[:, :, y:y+crop_size, x:x+crop_size]
with torch.no_grad():
pred = model(crop)
output[:, :, y:y+crop_size, x:x+crop_size] += pred
count[:, :, y:y+crop_size, x:x+crop_size] += 1
# 处理边缘未覆盖区域
if h % stride != 0 or w % stride != 0:
for y in range(max(0, h - crop_size), h):
for x in range(max(0, w - crop_size), w):
# 简化处理:直接用最近的裁剪结果
pass
return output / count.clamp(min=1)
@staticmethod
def test_time_augmentation(model, image, scales=(0.75, 1.0, 1.25), flip=True):
"""测试时增强(TTA)
在推理时使用多种变换,对结果取平均。
通常能提升 0.5-1.5% mIoU。
缺点:推理时间翻倍
"""
preds = []
for scale in scales:
h, w = image.shape[-2:]
new_h, new_w = int(h * scale), int(w * scale)
scaled = F.interpolate(image, size=(new_h, new_w), mode='bilinear', align_corners=False)
with torch.no_grad():
pred = model(scaled)
pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=False)
preds.append(pred)
if flip:
with torch.no_grad():
pred_flip = model(scaled.flip(-1)).flip(-1)
pred_flip = F.interpolate(pred_flip, size=(h, w), mode='bilinear', align_corners=False)
preds.append(pred_flip)
# 对 logits 取平均
avg_pred = torch.stack(preds).mean(dim=0)
return avg_pred
print("分割推理技巧已定义")优点
缺点
总结
图像分割最难的地方,不只是模型结构,而是任务定义、标注质量和边界细节。很多项目的瓶颈并不是网络不够深,而是数据标签质量不足、类别分布失衡或者部署分辨率和实验分辨率不一致。
在实践中,建议先从简单的基线开始(如 U-Net + CE + Dice Loss),建立完整的训练-评估-推理流程,然后逐步优化数据增强、损失函数和模型架构。不要一开始就追求最复杂的方案。
关键知识点
- 分割输出是像素级别,标签质量影响极大。
- Dice / IoU 比简单 accuracy 更能反映真实效果。
- 语义分割与实例分割要先明确任务边界。
- 类别不平衡和边界样本通常是训练难点。
- U-Net 的跳跃连接是保留空间细节的关键设计。
- ASPP 通过多尺度空洞卷积捕获不同感受野的特征。
- 组合损失(CE + Dice)比单一损失通常效果更好。
项目落地视角
- 医疗影像分割更关注 Dice 和边界一致性。
- 工业缺陷分割常用于面积统计和缺陷轮廓定位。
- 自动驾驶语义分割需要兼顾速度和多类别场景理解。
- 抠图、背景替换、遥感地物分类都常用分割模型。
医学分割的特殊考虑
def medical_segmentation_considerations():
"""医学影像分割的特殊注意事项
1. 数据特点:
- 样本量小(几十到几百例)
- 类别极度不平衡(病灶区域极小)
- 标注需要专业知识,标注者间差异大
- 不同设备的成像参数不同(模态差异)
2. 评估重点:
- Dice > 0.8 通常被认为临床可用
- Hausdorff Distance (HD95): 评估最远边界点距离
- Surface Dice: 边界表面的重叠度
3. 特殊技巧:
- 交叉验证(5-fold CV)代替简单的 train/val split
- 预训练编码器 + 小学习率微调
- 弹性变形和随机仿射变换增强
- 半监督/自监督预训练
"""
print("医学分割的特殊考虑:")
print(" 1. 使用 5-fold 交叉验证评估")
print(" 2. 重点关注 Dice 系数和 HD95")
print(" 3. 使用弹性变形增强数据")
print(" 4. 考虑标注者间的一致性")
print(" 5. 不同模态(CT/MRI)的归一化方式不同")
medical_segmentation_considerations()常见误区
- 把检测任务与分割任务混为一谈,指标和标注都用错。
- 只看整体 mIoU,不分析小类和边界失败样本。
- 输入分辨率一变,模型效果就断崖下跌,却未单独验证。
- 数据标签不稳定时,一味堆模型结构。
- 使用像素精度(Pixel Accuracy)作为主要指标——会被背景主导。
- 忘记对掩码做数据增强——导致几何变换不一致。
- 训练时使用一种分辨率,推理时使用另一种分辨率。
进阶路线
- 学习 DeepLab、U-Net++、Mask R-CNN、SegFormer 等架构。
- 研究边界损失、类别重加权、伪标签和半监督分割。
- 将分割与检测、OCR、跟踪结合做复合视觉任务。
- 使用 MMSegmentation、Detectron2 等框架系统化训练。
- 探索 SAM(Segment Anything Model)的零样本分割能力。
- 学习分割大模型如何通过 prompt 和 fine-tuning 适应新任务。
适用场景
- 医疗影像病灶区域分割。
- 工业缺陷区域检测与面积统计。
- 自动驾驶路面、车道线、障碍物理解。
- 遥感图像分类、抠图、背景分离等任务。
落地建议
- 先定义清楚是语义分割、实例分割还是全景分割。
- 对小目标、细边界、少数类做专门数据分析与增强。
- 训练和部署保持尽量一致的输入分辨率与预处理。
- 为关键类别分别追踪 IoU / Dice,而不是只看总分。
- 使用预训练编码器加速收敛并提升效果。
排错清单
- 检查掩码标签是否存在错位、漏标或边界粗糙问题。
- 检查前景/背景比例是否严重失衡。
- 检查损失函数是否适合当前任务与类别分布。
- 检查推理阶段阈值、后处理和缩放方式是否与训练一致。
- 检查图像和掩码的增强操作是否一致。
- 检查模型在验证集上是否过拟合——考虑增加正则化或数据增强。
复盘问题
- 你要解决的是语义区域识别,还是实例分离?
- 当前效果差,主要是边界问题、少数类问题还是标签问题?
- 指标提升后,是否真的提升了业务所关心的区域质量?
- 如果部署到更高分辨率场景,模型还能稳定工作吗?
