联邦学习基础
联邦学习基础
简介
联邦学习(Federated Learning)是一种分布式训练范式:多个客户端在本地使用各自数据训练模型,再把模型更新而不是原始数据上传到服务端聚合。它特别适合医疗、金融、终端设备、跨机构协作等无法集中原始数据的场景,但并不意味着天然安全或天然高效。
联邦学习的概念由 Google 在 2016 年首次提出(FedAvg 算法),最初用于解决移动端键盘输入预测的隐私问题。核心动机是:数据分散在大量终端设备上,集中传输既不现实(隐私法规、带宽限制)也不必要(只需要模型更新)。随后,联邦学习迅速扩展到医疗、金融等跨机构协作场景,成为隐私计算领域的重要技术方向。
从系统架构的角度看,联邦学习本质上是一种"数据不动模型动"的分布式计算范式。与传统分布式训练(数据分布在多个 worker 上,但模型参数集中管理)不同,联邦学习的每个客户端拥有独立的数据集,且这些数据集的分布可能存在显著差异(Non-IID 问题)。这使得联邦学习不仅是一个算法问题,更是一个涉及通信、安全和系统工程的综合性问题。
特点
联邦学习的三种范式
横向联邦学习(Horizontal FL):各客户端有相同的特征空间但不同的样本。例如多家医院使用相同的检查项目,但各自拥有不同的患者。适用于特征对齐但样本不同的跨机构场景。
纵向联邦学习(Vertical FL):各客户端有相同的样本主体但不同的特征。例如银行拥有用户的收入信息,电商平台拥有用户的购买记录。适用于样本重叠但特征互补的场景。
联邦迁移学习(Federated Transfer Learning):样本和特征都不同,通过迁移学习的方法桥接差异。适用于跨领域协作场景。
实现
FedAvg 基础流程
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(10, 32),
nn.ReLU(),
nn.Linear(32, 2)
)
def forward(self, x):
return self.net(x)
def local_train(model, dataloader, epochs=1, lr=0.01):
model = copy.deepcopy(model)
model.train()
optimizer = optim.SGD(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
for _ in range(epochs):
for x, y in dataloader:
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
return model.state_dict()def fedavg(state_dicts):
avg_state = copy.deepcopy(state_dicts[0])
for key in avg_state.keys():
avg_state[key] = sum(sd[key] for sd in state_dicts) / len(state_dicts)
return avg_state
# 构造 3 个客户端数据
clients = []
for seed in [1, 2, 3]:
torch.manual_seed(seed)
x = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))
dataset = TensorDataset(x, y)
clients.append(DataLoader(dataset, batch_size=16, shuffle=True))
server_model = SimpleNet()
for round_id in range(5):
local_updates = []
for client_loader in clients:
local_state = local_train(server_model, client_loader, epochs=1, lr=0.05)
local_updates.append(local_state)
server_model.load_state_dict(fedavg(local_updates))
print(f"round {round_id + 1} finished")FedAvg 基本流程:
1. 服务端下发全局模型
2. 客户端在本地训练若干轮
3. 上传模型参数/梯度
4. 服务端聚合得到新全局模型
5. 进入下一轮训练加权 FedAvg 与 Non-IID 处理
def weighted_fedavg(state_dicts, client_weights=None):
"""加权联邦平均
标准 FedAvg 对所有客户端等权平均。
加权 FedAvg 根据客户端的数据量或贡献分配权重。
权重设计策略:
1. 按数据量加权:数据量大的客户端贡献更大
2. 按损失加权:损失大的客户端(困难数据)权重更大
3. 按性能加权:性能好的客户端权重更大
4. 自适应权重:根据历史表现动态调整
"""
if client_weights is None:
client_weights = [1.0 / len(state_dicts)] * len(state_dicts)
else:
total = sum(client_weights)
client_weights = [w / total for w in client_weights]
avg_state = copy.deepcopy(state_dicts[0])
for key in avg_state.keys():
avg_state[key] = sum(
w * sd[key] for w, sd in zip(client_weights, state_dicts)
)
return avg_state
# 示例:按数据量加权
client_data_sizes = [100, 200, 50]
updates = [{"param": torch.randn(10, 10)} for _ in range(3)]
result = weighted_fedavg(updates, client_data_sizes)
print("加权联邦平均已完成")FedProx:处理异质性
def local_train_fedprox(model, dataloader, global_model_state, epochs=1, lr=0.01, mu=0.01):
"""FedProx:通过近端项约束本地更新
FedProx 在本地损失函数中添加了近端项(Proximal Term):
L_local = L_task + (mu/2) * ||w - w_global||^2
mu 控制本地模型与全局模型的偏离程度:
- mu = 0: 退化为标准 FedAvg
- mu 越大:本地更新越接近全局模型,对 Non-IID 更鲁棒
近端项的效果:
- 限制本地模型偏离全局模型太远
- 减轻 Client Drift 问题
- 在高度异构的数据分布下提高收敛性
"""
model = copy.deepcopy(model)
model.train()
optimizer = optim.SGD(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
for _ in range(epochs):
for x, y in dataloader:
optimizer.zero_grad()
loss = criterion(model(x), y)
# 添加近端项
proximal_term = 0.0
for name, param in model.named_parameters():
proximal_term += ((param - global_model_state[name]) ** 2).sum()
loss += (mu / 2) * proximal_term
loss.backward()
optimizer.step()
return model.state_dict()
print("FedProx 本地训练已定义")差分隐私与安全聚合思路
# 简化示意:对梯度裁剪并加噪声
import math
def clip_and_add_noise(tensor, max_norm=1.0, noise_std=0.05):
norm = tensor.norm()
if norm > max_norm:
tensor = tensor * (max_norm / (norm + 1e-8))
noise = torch.randn_like(tensor) * noise_std
return tensor + noise# 对本地模型更新做简单处理(示意)
def sanitize_state_dict(state_dict):
safe_state = {}
for key, value in state_dict.items():
safe_state[key] = clip_and_add_noise(value)
return safe_state
safe_updates = [sanitize_state_dict(local_train(server_model, client_loader)) for client_loader in clients]
server_model.load_state_dict(fedavg(safe_updates))注意:
- "数据不上传"不等于没有隐私风险
- 模型更新本身也可能泄漏信息
- 真正生产可用的联邦学习通常还要配合安全聚合、差分隐私、可信执行环境等机制差分隐私详解
def explain_differential_privacy():
"""差分隐私(Differential Privacy)详解
核心思想:在模型更新中添加精心设计的噪声,使得攻击者无法确定
任何单个数据点是否参与了训练。
数学定义:(epsilon, delta)-差分隐私
对于任意两个相邻数据集 D 和 D'(仅差一条记录),
机制 M 满足:
Pr[M(D) in S] <= exp(epsilon) * Pr[M(D' in S] + delta
epsilon (隐私预算):
- epsilon 越小,隐私保护越强,但噪声越大,模型精度越低
- 典型值:epsilon = 1-10(宽松),epsilon < 1(严格)
实现步骤:
1. 梯度裁剪:限制单个样本对梯度的影响
clip(g, C) = g * min(1, C / ||g||)
2. 噪声添加:添加高斯或拉普拉斯噪声
g_noisy = g_clipped + N(0, sigma^2)
3. 隐私预算跟踪:记录累计的 epsilon 消耗
噪声标度 sigma 与 epsilon 的关系(高斯机制):
sigma = C * sqrt(2 * ln(1.25 / delta)) / epsilon
隐私预算管理:
- 线性组合:epsilon_total = sum(epsilon_i)
- 高级组合定理:epsilon_total <= sqrt(T) * epsilon
- Rényi DP / zCDP:更精确的预算计算
"""
print("差分隐私关键参数:")
print(" epsilon (隐私预算): 越小越安全,典型值 1-10")
print(" delta (失败概率): 典型值 1e-5")
print(" clip_norm (裁剪阈值): 限制单样本影响")
print(" noise_multiplier (噪声倍数): sigma / clip_norm")
# 模拟不同隐私预算下的精度影响
print("\n隐私预算 vs 模型精度(示意):")
for eps in [0.5, 1.0, 3.0, 5.0, 10.0, float('inf')]:
noise = 1.0 / max(eps, 0.1) # 噪声与 epsilon 成反比
accuracy = max(0.5, 0.95 - noise * 0.02) # 示意性精度下降
print(f" epsilon={eps:>6.1f}: 噪声水平={noise:.3f}, 精度约={accuracy:.2%}")
explain_differential_privacy()安全聚合
def explain_secure_aggregation():
"""安全聚合(Secure Aggregation)
目标:服务端只能看到聚合后的模型参数,无法看到任何单个客户端的更新。
方法:
1. 基于秘密共享(Secret Sharing):
- 每个客户端将自己的更新拆分为多份
- 通过多轮通信重构聚合结果
- 服务端始终看不到单个客户端的更新
2. 基于同态加密(Homomorphic Encryption):
- 客户端用公钥加密更新
- 服务端在密文上执行聚合操作
- 只有授权方可以解密最终结果
3. 基于可信执行环境(TEE):
- 使用 SGX 等硬件安全区域
- 在安全区域内执行聚合计算
- 保证计算过程的机密性
注意:安全聚合只保护"传输中"的隐私,
聚合后的全局模型仍然可能被反演攻击推断出训练数据。
"""
print("安全聚合方案对比:")
print(" 秘密共享: 通信开销大,但无需可信第三方")
print(" 同态加密: 计算开销大,安全性高")
print(" TEE: 需要特殊硬件,性能较好")
print("\n推荐: 秘密共享 + 差分隐私的组合方案")
explain_secure_aggregation()横向联邦、纵向联邦与实际挑战
# 横向联邦:不同机构特征空间相同,样本不同
hospital_a = {"features": ["age", "bp", "glucose"], "samples": 10000}
hospital_b = {"features": ["age", "bp", "glucose"], "samples": 8000}
# 纵向联邦:样本主体相同,特征分布在不同机构
bank = {"features": ["income", "loan_count"], "user_ids": [1, 2, 3]}
ecommerce = {"features": ["order_count", "avg_amount"], "user_ids": [1, 2, 3]}
print(hospital_a, hospital_b, bank, ecommerce)联邦学习主要挑战:
1. Non-IID:各客户端数据分布差异大
2. Client Drift:本地更新方向不同,收敛困难
3. 通信瓶颈:模型参数大、同步轮次多
4. 客户端掉线:终端设备不稳定
5. 安全风险:投毒攻击、反演攻击、模型窃取# 实验建议记录的指标
experiment_log = {
"global_rounds": 20,
"client_count": 50,
"participation_rate": 0.2,
"avg_local_loss": 0.38,
"global_accuracy": 0.87,
"comm_cost_mb": 512,
"privacy_budget_eps": 3.0
}
print(experiment_log)通信效率优化
def communication_optimization_strategies():
"""联邦学习通信优化策略
通信是联邦学习的主要瓶颈:
- 模型参数通常为 MB 级别
- 训练需要数十到数百轮通信
- 客户端可能在弱网环境(移动端)
优化方法:
1. 梯度压缩:只传输重要信息
- Top-K:只传输最大的 K 个梯度值
- 量化:将浮点参数量化为低精度
- 稀疏化:只传输非零梯度
2. 模型压缩:减少传输量
- 知识蒸馏:用小模型代替大模型
- 剪枝:去除不重要的参数
- 低秩分解:用低秩矩阵近似权重矩阵
3. 通信调度:减少通信轮次
- FedBuff:异步聚合,减少等待
- Local Steps:增加本地训练步数
- Active Client Selection:选择贡献大的客户端
"""
print("通信优化策略:")
print(" 1. Top-K 稀疏化: 只传输最重要的 1-10% 梯度")
print(" 2. 量化: FP32 -> INT8,减少 4 倍通信量")
print(" 3. 增加本地步数: 从 1 epoch 增加到 5-10 epochs")
print(" 4. 异步聚合: FedBuff、HieraAsyn")
print(" 5. 模型压缩: 剪枝 + 蒸馏后传输")
communication_optimization_strategies()梯度压缩实现
import torch
def top_k_compress(tensor: torch.Tensor, k_ratio: float = 0.1):
"""Top-K 梯度压缩
只保留绝对值最大的 K% 的梯度,其余置零。
这可以大幅减少通信量,同时保留最重要的更新信息。
Args:
tensor: 模型梯度或参数更新
k_ratio: 保留比例(0.1 = 保留 10%)
Returns:
压缩后的张量和索引
"""
k = max(1, int(tensor.numel() * k_ratio))
values, indices = torch.topk(tensor.abs().flatten(), k)
compressed = torch.zeros_like(tensor).flatten()
compressed[indices] = tensor.flatten()[indices]
return compressed.reshape(tensor.shape), indices
def quantize_tensor(tensor: torch.Tensor, bits: int = 8):
"""量化压缩
将 FP32 张量量化为低精度表示。
常用量化位数:8-bit(减少 4 倍)、4-bit(减少 8 倍)。
Args:
tensor: 待量化的张量
bits: 量化位数
Returns:
量化后的张量和缩放参数
"""
min_val = tensor.min()
max_val = tensor.max()
scale = (max_val - min_val) / (2 ** bits - 1)
# 量化
quantized = ((tensor - min_val) / scale).round().to(torch.int8 if bits <= 8 else torch.int16)
return quantized, {'min': min_val, 'scale': scale}
# 示例
grad = torch.randn(1000, 1000) # 模拟梯度(约 4MB)
compressed, _ = top_k_compress(grad, 0.05) # 压缩到 5%
non_zero = (compressed != 0).sum().item()
print(f"原始: {grad.numel():,} 元素, 压缩后: {non_zero:,} 非零元素 ({100*non_zero/grad.numel():.1f}%)")客户端选择策略
import random
from typing import List, Callable
def client_selection_strategies():
"""客户端选择策略
每轮训练不需要所有客户端参与,选择策略影响训练效率和模型质量。
常用策略:
1. 随机选择:每轮随机选取 K% 的客户端
2. 基于损失选择:优先选择损失大的客户端(困难样本)
3. 基于数据量选择:优先选择数据量大的客户端
4. 基于更新幅度选择:优先选择梯度更新大的客户端
5. 资源感知选择:优先选择网络好、算力强的客户端
"""
class ClientSelector:
"""客户端选择器"""
def __init__(self, total_clients: int):
self.total = total_clients
self.history = {} # client_id -> 历史信息
def random_select(self, fraction: float = 0.2) -> List[int]:
"""随机选择"""
k = max(1, int(self.total * fraction))
return random.sample(range(self.total), k)
def loss_based_select(self, client_losses: dict, fraction: float = 0.2) -> List[int]:
"""基于损失选择 — 优先选择损失大的客户端"""
k = max(1, int(self.total * fraction))
sorted_clients = sorted(client_losses.items(), key=lambda x: x[1], reverse=True)
return [c for c, _ in sorted_clients[:k]]
def resource_aware_select(self, client_resources: dict, fraction: float = 0.2) -> List[int]:
"""资源感知选择 — 优先选择带宽好、延迟低的客户端"""
k = max(1, int(self.total * fraction))
# 按综合评分排序(带宽 / 延迟)
scored = {
cid: info['bandwidth'] / max(1, info['latency'])
for cid, info in client_resources.items()
}
sorted_clients = sorted(scored.items(), key=lambda x: x[1], reverse=True)
return [c for c, _ in sorted_clients[:k]]
selector = ClientSelector(100)
print(f"随机选择: {selector.random_select(0.1)}")联邦学习实验记录
class FLExperimentLogger:
"""联邦学习实验日志记录器
记录训练过程中的关键指标,用于分析和调优。
"""
def __init__(self, experiment_name: str):
self.name = experiment_name
self.rounds = []
def log_round(self, round_id: int, metrics: dict):
"""记录每轮训练指标"""
entry = {
'round': round_id,
'global_accuracy': metrics.get('accuracy'),
'global_loss': metrics.get('loss'),
'client_count': metrics.get('client_count'),
'participation_rate': metrics.get('participation_rate'),
'communication_mb': metrics.get('communication_mb'),
'round_time_seconds': metrics.get('round_time'),
}
self.rounds.append(entry)
# 打印进度
print(f"[Round {round_id}] "
f"Accuracy={entry['global_accuracy']:.4f}, "
f"Loss={entry['global_loss']:.4f}, "
f"Clients={entry['client_count']}, "
f"Time={entry['round_time_seconds']:.1f}s")
def summary(self):
"""输出实验摘要"""
if not self.rounds:
return
best = max(self.rounds, key=lambda x: x['global_accuracy'] or 0)
total_comm = sum(r['communication_mb'] or 0 for r in self.rounds)
total_time = sum(r['round_time_seconds'] or 0 for r in self.rounds)
print(f"\n===== 实验 {self.name} 摘要 =====")
print(f"总轮次: {len(self.rounds)}")
print(f"最佳准确率: {best['global_accuracy']:.4f} (Round {best['round']})")
print(f"总通信量: {total_comm:.1f} MB")
print(f"总训练时间: {total_time:.1f} 秒")
# 使用
logger = FLExperimentLogger("FedAvg_CIFAR10_50clients")
for i in range(1, 6):
logger.log_round(i, {
'accuracy': 0.6 + i * 0.05,
'loss': 1.2 - i * 0.15,
'client_count': 10,
'participation_rate': 0.2,
'communication_mb': 128,
'round_time': 45.0,
})
logger.summary()优点
缺点
总结
联邦学习解决的是"数据无法集中但又想联合训练"的现实问题,而不是隐私与安全的万能方案。真正落地时,必须同时考虑数据异质性、通信成本、聚合策略和安全边界,否则很容易出现效果不收敛、成本过高或安全性不足的问题。
在实际项目中,建议首先明确联邦学习的必要性:是否真的存在无法集中数据的硬约束?如果数据可以集中,集中式训练通常效果更好、成本更低。如果必须使用联邦学习,建议从小规模 PoC 开始,验证收敛性、通信成本和隐私方案,再逐步扩大规模。
关键知识点
- 数据不出域不代表完全没有隐私泄漏风险。
- FedAvg 是最经典的聚合方法,但不一定适合所有 Non-IID 场景。
- 横向联邦和纵向联邦适用的组织结构不同。
- 联邦学习是算法、系统、隐私工程三者共同作用的结果。
- 差分隐私的 epsilon 越小,隐私保护越强但模型精度越低。
- 通信优化(压缩、量化、异步)是大规模联邦学习的关键。
项目落地视角
- 医疗机构可在不共享患者原始数据下协作训练模型。
- 金融机构可在严格合规条件下探索联合风控建模。
- 终端侧输入法、推荐系统可做本地学习与参数上报。
- 多方协作前必须先统一特征口径、标签定义和模型版本。
联邦学习框架对比
def compare_fl_frameworks():
"""联邦学习框架对比"""
frameworks = {
"Flower (flwr)": {
"语言": "Python",
"特点": "框架无关,支持 PyTorch/TF/任意框架",
"适用": "研究和小规模部署",
"成熟度": "高",
},
"FedML": {
"语言": "Python",
"特点": "跨平台(手机/边缘/云),支持横向/纵向",
"适用": "大规模边缘部署",
"成熟度": "高",
},
"TensorFlow Federated": {
"语言": "Python",
"特点": "Google 官方,与 TF 生态集成",
"适用": "TF 生态用户",
"成熟度": "中",
},
"PySyft": {
"语言": "Python",
"特点": "隐私计算工具包,支持 SMPC",
"适用": "安全计算研究",
"成熟度": "中",
},
}
print("联邦学习框架对比:")
for name, info in frameworks.items():
print(f" {name}: {info['特点']}")
compare_fl_frameworks()常见误区
- 认为"数据不上传"就已经足够安全。
- 忽略客户端数据分布差异,直接套用集中式训练思路。
- 不评估通信成本,导致实验很难扩展到真实规模。
- 没有版本治理,客户端特征和标签定义长期漂移。
- 过度依赖差分隐私导致模型精度严重下降。
- 忽略投毒攻击(恶意客户端上传有害更新)的风险。
进阶路线
- 学习 FedProx、SCAFFOLD 等处理 Non-IID 的方法。
- 深入研究安全聚合、差分隐私和对抗攻击防护。
- 使用 Flower、FedML、TensorFlow Federated 等框架做工程落地。
- 将联邦学习与边缘部署、模型压缩结合起来。
- 学习联邦学习中的公平性问题和激励机制设计。
适用场景
- 医疗、金融、政务等强隐私行业。
- 多机构合作但原始数据不可直接共享的场景。
- 大量终端设备参与训练的边缘学习。
- 需要兼顾模型效果与数据合规的联合建模项目。
落地建议
- 先验证是否真的存在"无法集中数据但又必须联合训练"的约束。
- 在小规模 PoC 中先评估收敛速度、通信成本和隐私方案。
- 定义统一特征协议、标签规范和客户端版本管理机制。
- 为联邦训练建立观测指标:参与率、收敛速度、通信成本、安全事件。
排错清单
- 检查客户端是否使用一致的特征与标签定义。
- 检查全局模型不收敛是因为 Non-IID、采样率还是学习率问题。
- 检查通信失败、掉线客户端和聚合异常是否被正确处理。
- 检查隐私增强方案是否显著影响模型效果或训练稳定性。
复盘问题
- 你做联邦学习,是因为业务真的需要,还是因为概念看起来先进?
- 当前最大瓶颈是隐私、收敛、通信还是工程实现?
- 如果客户端规模扩大 10 倍,现有方案还能跑得动吗?
- 你的联邦训练结果是否真正优于各方各自单独训练?
