AI 模型版本管理与实验追踪
简介
AI 模型版本管理与实验追踪是 MLOps 的核心环节。在机器学习项目中,数据科学家每天可能进行数十次实验,每次实验涉及不同的超参数、数据版本、模型架构。如果没有系统化的管理手段,很快就会陷入"哪个模型最好"、"这个结果是怎么来的"的混乱中。
模型版本管理解决的核心问题包括:实验可复现性(Reproducibility)、模型血缘追踪(Lineage)、版本对比(Comparison)、部署管控(Governance)。这四个维度构成了完整的模型生命周期管理能力。
传统软件开发的版本控制(Git)无法满足 ML 的需求,因为 ML 的"代码"不仅包括程序代码,还包括数据、模型权重、超参数、环境配置等多种制品(Artifact)。因此需要专门的 ML 实验追踪工具和模型注册中心。
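举例来说,一次可复现的训练运行至少需要同时固化下面这些制品信息(示意结构,字段名为本文假设):
# 一次训练运行需要固化的制品清单(示意)
run_manifest = {
    "code_version": "a1b2c3d",                      # Git commit hash
    "data_version": "train_v3",                     # 数据集版本或内容哈希
    "hyperparameters": {"n_estimators": 100, "max_depth": 10},
    "environment": "requirements.txt 的锁定哈希",   # 依赖环境
    "model_weights": "runs/123/model.pkl",          # 模型权重制品
    "metrics": {"accuracy": 0.92, "f1_score": 0.90},
}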
特点
MLflow 实验追踪
基础实验记录
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import numpy as np
# 设置追踪服务器
mlflow.set_tracking_uri("http://localhost:5000")
mlflow.set_experiment("客户流失预测")
def train_and_log_model(
n_estimators: int = 100,
max_depth: int = 10,
min_samples_split: int = 2,
):
"""训练模型并记录实验"""
# 生成模拟数据
np.random.seed(42)
X = np.random.rand(1000, 20)
y = (X[:, 0] + X[:, 1] > 1.0).astype(int)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# 开始 MLflow 运行
with mlflow.start_run(run_name=f"rf_n{n_estimators}_d{max_depth}"):
# 记录超参数
mlflow.log_params({
"n_estimators": n_estimators,
"max_depth": max_depth,
"min_samples_split": min_samples_split,
"model_type": "RandomForest",
})
# 训练模型
model = RandomForestClassifier(
n_estimators=n_estimators,
max_depth=max_depth,
min_samples_split=min_samples_split,
random_state=42,
)
model.fit(X_train, y_train)
# 预测和评估
y_pred = model.predict(X_test)
metrics = {
"accuracy": accuracy_score(y_test, y_pred),
"f1_score": f1_score(y_test, y_pred),
"precision": precision_score(y_test, y_pred),
"recall": recall_score(y_test, y_pred),
}
# 记录指标
mlflow.log_metrics(metrics)
# 设置标签
mlflow.set_tags({
"team": "data-science",
"version": "1.0.0",
"purpose": "binary_classification",
})
# 记录模型
mlflow.sklearn.log_model(
model,
"model",
registered_model_name="customer_churn_model",
)
print(f"Run ID: {mlflow.active_run().info.run_id}")
print(f"Accuracy: {metrics['accuracy']:.4f}")
print(f"F1 Score: {metrics['f1_score']:.4f}")
# 运行多组实验
experiments = [
{"n_estimators": 50, "max_depth": 5, "min_samples_split": 2},
{"n_estimators": 100, "max_depth": 10, "min_samples_split": 2},
{"n_estimators": 200, "max_depth": 15, "min_samples_split": 5},
]
for params in experiments:
    train_and_log_model(**params)

实验对比与查询
import mlflow
from mlflow.tracking import MlflowClient
client = MlflowClient()
def compare_experiments(experiment_name: str, metric_key: str = "accuracy"):
"""对比实验结果"""
experiment = client.get_experiment_by_name(experiment_name)
if not experiment:
print(f"实验 '{experiment_name}' 不存在")
return
# 搜索所有运行
runs = client.search_runs(
experiment_ids=[experiment.experiment_id],
order_by=[f"metrics.{metric_key} DESC"],
max_results=20,
)
print(f"{'Run Name':<30} {'Accuracy':<12} {'F1':<12} {'Status':<10}")
print("-" * 70)
for run in runs:
name = run.data.tags.get("mlflow.runName", run.info.run_id[:8])
accuracy = run.data.metrics.get("accuracy", 0)
f1 = run.data.metrics.get("f1_score", 0)
status = run.info.status
print(f"{name:<30} {accuracy:<12.4f} {f1:<12.4f} {status:<10}")
def get_best_run(experiment_name: str, metric_key: str = "f1_score"):
"""获取最优实验"""
experiment = client.get_experiment_by_name(experiment_name)
runs = client.search_runs(
experiment_ids=[experiment.experiment_id],
order_by=[f"metrics.{metric_key} DESC"],
max_results=1,
)
if runs:
best = runs[0]
print(f"最佳实验: {best.data.tags.get('mlflow.runName', 'N/A')}")
print(f"Run ID: {best.info.run_id}")
print(f"指标: {best.data.metrics}")
print(f"参数: {best.data.params}")
return best
return None
# 使用示例
compare_experiments("客户流失预测", "f1_score")
best = get_best_run("客户流失预测", "f1_score")

MLflow 模型注册中心
模型版本管理
from mlflow.tracking import MlflowClient
client = MlflowClient()
class ModelRegistry:
"""模型注册中心管理"""
def __init__(self):
self.client = MlflowClient()
def list_model_versions(self, model_name: str):
"""列出模型所有版本"""
versions = self.client.search_model_versions(
f"name='{model_name}'"
)
print(f"模型 '{model_name}' 的所有版本:")
print(f"{'版本':<8} {'状态':<15} {'阶段':<15} {'创建时间':<20}")
print("-" * 60)
for v in versions:
print(
f"v{v.version:<6} {v.status:<15} "
f"{v.current_stage:<15} {v.creation_timestamp:<20}"
)
def transition_stage(self, model_name: str, version: str, stage: str):
"""转换模型阶段
Stages: None -> Staging -> Production -> Archived
"""
result = self.client.transition_model_version_stage(
name=model_name,
version=version,
stage=stage,
)
print(f"模型 {model_name} v{version} 已转换到 {stage} 阶段")
return result
def compare_versions(
self, model_name: str, version_a: str, version_b: str
):
"""对比两个模型版本"""
va = self.client.get_model_version(model_name, version_a)
vb = self.client.get_model_version(model_name, version_b)
# 获取对应的 Run 信息
run_a = self.client.get_run(va.run_id)
run_b = self.client.get_run(vb.run_id)
metrics_a = run_a.data.metrics
metrics_b = run_b.data.metrics
print(f"{'指标':<20} {'版本 ' + version_a:<15} {'版本 ' + version_b:<15} {'差异':<10}")
print("-" * 60)
all_metrics = set(list(metrics_a.keys()) + list(metrics_b.keys()))
for metric in sorted(all_metrics):
val_a = metrics_a.get(metric, 0)
val_b = metrics_b.get(metric, 0)
diff = val_b - val_a
sign = "+" if diff > 0 else ""
print(f"{metric:<20} {val_a:<15.4f} {val_b:<15.4f} {sign}{diff:<10.4f}")
def add_model_description(
self, model_name: str, version: str, description: str
):
"""添加模型描述和变更记录"""
self.client.update_model_version(
name=model_name,
version=version,
description=description,
)
def promote_to_production(
self,
model_name: str,
version: str,
approval_note: str = "",
):
"""将模型提升到生产环境
包含审批流程:
1. 验证 Staging 阶段的指标
2. 确认 A/B 测试结果
3. 转换到 Production
"""
model_version = self.client.get_model_version(model_name, version)
# 检查是否在 Staging
if model_version.current_stage != "Staging":
print(f"警告: 模型当前在 {model_version.current_stage} 阶段,建议先进入 Staging")
return
# 转换到 Production
self.transition_stage(model_name, version, "Production")
# 添加描述
self.add_model_description(
model_name, version,
f"提升到生产环境。审批备注: {approval_note}"
)
print(f"模型 {model_name} v{version} 已成功提升到生产环境")
# 使用示例
registry = ModelRegistry()
registry.list_model_versions("customer_churn_model")
registry.compare_versions("customer_churn_model", "1", "2")
registry.promote_to_production("customer_churn_model", "2", "F1提升5%,A/B测试通过")

模型血缘追踪
import mlflow
from dataclasses import dataclass
from typing import Optional
import json
@dataclass
class DataLineage:
"""数据血缘记录"""
dataset_name: str
dataset_version: str
source_path: str
row_count: int
feature_columns: list
preprocessing_steps: list
hash: str
@dataclass
class ModelLineage:
"""模型血缘追踪"""
model_name: str
version: str
run_id: str
data_lineage: DataLineage
code_version: str # Git commit hash
hyperparameters: dict
training_duration_sec: float
metrics: dict
parent_model: Optional[str] = None # 蒸馏/微调的父模型
class LineageTracker:
"""血缘追踪器"""
def __init__(self):
self.lineages = {}
def track_full_lineage(
self,
model_name: str,
run_id: str,
data_info: dict,
code_version: str,
):
"""追踪完整的模型血缘"""
data_lineage = DataLineage(
dataset_name=data_info["name"],
dataset_version=data_info["version"],
source_path=data_info["path"],
row_count=data_info["row_count"],
feature_columns=data_info["features"],
preprocessing_steps=data_info.get("steps", []),
hash=data_info.get("hash", ""),
)
# 从 MLflow 获取 Run 信息
client = mlflow.tracking.MlflowClient()
run = client.get_run(run_id)
lineage = ModelLineage(
model_name=model_name,
version="1",
run_id=run_id,
data_lineage=data_lineage,
code_version=code_version,
hyperparameters=run.data.params,
training_duration_sec=(
run.info.end_time - run.info.start_time
) / 1000 if run.info.end_time else 0,
metrics=run.data.metrics,
)
self.lineages[f"{model_name}:{run_id}"] = lineage
return lineage
def get_lineage(self, model_name: str, run_id: str) -> dict:
"""获取模型血缘"""
key = f"{model_name}:{run_id}"
lineage = self.lineages.get(key)
if not lineage:
return {"error": "血缘记录不存在"}
return {
"model": {
"name": lineage.model_name,
"version": lineage.version,
"run_id": lineage.run_id,
},
"data": {
"dataset": lineage.data_lineage.dataset_name,
"version": lineage.data_lineage.dataset_version,
"source": lineage.data_lineage.source_path,
"rows": lineage.data_lineage.row_count,
"features": lineage.data_lineage.feature_columns,
"preprocessing": lineage.data_lineage.preprocessing_steps,
},
"code": {
"git_commit": lineage.code_version,
},
"metrics": lineage.metrics,
"parent_model": lineage.parent_model,
}
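    # 补充示意:code_version 通常直接取当前 Git commit
    # (subprocess 为标准库;此辅助方法为本文新增的示意)
    @staticmethod
    def current_git_commit() -> str:
        """获取当前代码的 commit hash;不在 Git 仓库中时返回 unknown"""
        import subprocess
        try:
            return subprocess.check_output(
                ["git", "rev-parse", "HEAD"], text=True
            ).strip()
        except Exception:
            return "unknown"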
def print_lineage_tree(self, model_name: str):
"""打印血缘树"""
print(f"模型血缘树: {model_name}")
print("=" * 50)
for key, lineage in self.lineages.items():
if lineage.model_name == model_name:
print(f" [{lineage.model_name} v{lineage.version}]")
print(f" 数据: {lineage.data_lineage.dataset_name} "
f"v{lineage.data_lineage.dataset_version}")
print(f" 代码: {lineage.code_version[:8]}")
print(f" 指标: {lineage.metrics}")
if lineage.parent_model:
print(f" 父模型: {lineage.parent_model}")
                print()

A/B 测试模型
import random
import time
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class ABTestConfig:
"""A/B 测试配置"""
test_name: str
model_a_name: str
model_a_version: str
model_b_name: str
model_b_version: str
traffic_split: float = 0.5 # B 组流量比例
min_samples: int = 1000 # 最少样本数
significance_level: float = 0.05
metrics_to_compare: list = field(
default_factory=lambda: ["accuracy", "latency_ms", "user_satisfaction"]
)
@dataclass
class PredictionResult:
model_name: str
model_version: str
prediction: str
confidence: float
latency_ms: float
ground_truth: Optional[str] = None
user_feedback: Optional[int] = None # 1-5 评分
class ModelABTester:
"""模型 A/B 测试框架"""
def __init__(self):
self.tests: dict[str, ABTestConfig] = {}
        self.results: dict[str, dict[str, list[PredictionResult]]] = {}
def create_test(self, config: ABTestConfig):
"""创建 A/B 测试"""
self.tests[config.test_name] = config
self.results[config.test_name] = {"A": [], "B": []}
print(f"A/B 测试 '{config.test_name}' 已创建")
print(f" 模型 A: {config.model_a_name} v{config.model_a_version}")
print(f" 模型 B: {config.model_b_name} v{config.model_b_version}")
print(f" 流量分配: {1-config.traffic_split:.0%} / {config.traffic_split:.0%}")
def route_request(self, test_name: str) -> str:
"""路由请求到对应的模型组"""
config = self.tests[test_name]
if random.random() < config.traffic_split:
return "B"
return "A"
def record_result(
self, test_name: str, group: str, result: PredictionResult
):
"""记录预测结果"""
self.results[test_name][group].append(result)
def analyze_test(self, test_name: str) -> dict:
"""分析 A/B 测试结果"""
config = self.tests[test_name]
results_a = self.results[test_name]["A"]
results_b = self.results[test_name]["B"]
        if len(results_a) < config.min_samples or len(results_b) < config.min_samples:
            return {"status": "样本不足", "samples_a": len(results_a), "samples_b": len(results_b)}
# 计算各组指标
def calc_metrics(results: list):
if not results:
return {}
latencies = [r.latency_ms for r in results]
confidences = [r.confidence for r in results]
accuracy = 0
if any(r.ground_truth for r in results):
correct = sum(
1 for r in results
if r.ground_truth and r.prediction == r.ground_truth
)
total = sum(1 for r in results if r.ground_truth)
accuracy = correct / total if total > 0 else 0
avg_satisfaction = 0
if any(r.user_feedback for r in results):
feedbacks = [r.user_feedback for r in results if r.user_feedback]
avg_satisfaction = sum(feedbacks) / len(feedbacks) if feedbacks else 0
return {
"count": len(results),
"avg_latency_ms": sum(latencies) / len(latencies),
"p95_latency_ms": sorted(latencies)[int(len(latencies) * 0.95)],
"avg_confidence": sum(confidences) / len(confidences),
"accuracy": accuracy,
"avg_satisfaction": avg_satisfaction,
}
metrics_a = calc_metrics(results_a)
metrics_b = calc_metrics(results_b)
return {
"test_name": test_name,
"model_a": f"{config.model_a_name} v{config.model_a_version}",
"model_b": f"{config.model_b_name} v{config.model_b_version}",
"group_a": metrics_a,
"group_b": metrics_b,
"recommendation": self._recommend(metrics_a, metrics_b, config),
}
def _recommend(self, metrics_a: dict, metrics_b: dict, config: ABTestConfig) -> str:
"""给出推荐结论"""
if not metrics_a or not metrics_b:
return "数据不足,继续测试"
# 综合评分
        score_a = (
            metrics_a.get("accuracy", 0) * 0.4
            + metrics_a.get("avg_satisfaction", 0) / 5 * 0.4
            + (1 - metrics_a.get("avg_latency_ms", 1000) / 5000) * 0.2
        )
        score_b = (
            metrics_b.get("accuracy", 0) * 0.4
            + metrics_b.get("avg_satisfaction", 0) / 5 * 0.4
            + (1 - metrics_b.get("avg_latency_ms", 1000) / 5000) * 0.2
        )
if score_b > score_a * 1.02:
return f"推荐模型 B (综合分 {score_b:.3f} > {score_a:.3f})"
elif score_a > score_b * 1.02:
return f"保持模型 A (综合分 {score_a:.3f} > {score_b:.3f})"
else:
return "两个模型表现接近,建议延长测试周期"
# 使用示例
tester = ModelABTester()
config = ABTestConfig(
test_name="churn_model_v2_test",
model_a_name="customer_churn_model",
model_a_version="1",
model_b_name="customer_churn_model",
model_b_version="2",
traffic_split=0.3,
)
)
tester.create_test(config)

金丝雀部署
import random
import time
from dataclasses import dataclass
@dataclass
class CanaryConfig:
"""金丝雀部署配置"""
model_name: str
stable_version: str
canary_version: str
canary_percentage: float = 5.0 # 初始金丝雀流量百分比
max_percentage: float = 50.0 # 最大金丝雀流量百分比
step_percentage: float = 5.0 # 每步增加的百分比
error_threshold: float = 0.05 # 错误率阈值
latency_threshold_ms: float = 2000 # 延迟阈值
evaluation_interval_sec: int = 300 # 评估间隔(秒)
class CanaryDeployer:
"""模型金丝雀部署"""
def __init__(self):
self.deployments: dict[str, CanaryConfig] = {}
self.metrics: dict[str, dict] = {}
def start_canary(self, config: CanaryConfig):
"""启动金丝雀部署"""
self.deployments[config.model_name] = config
self.metrics[config.model_name] = {
"stable_errors": 0,
"stable_total": 0,
"canary_errors": 0,
"canary_total": 0,
"current_percentage": config.canary_percentage,
"status": "running",
}
print(f"金丝雀部署已启动: {config.model_name}")
print(f" 稳定版本: v{config.stable_version}")
print(f" 金丝雀版本: v{config.canary_version}")
print(f" 初始流量: {config.canary_percentage}%")
def route(self, model_name: str) -> str:
"""路由请求"""
config = self.deployments.get(model_name)
if not config:
return "stable"
metrics = self.metrics[model_name]
if metrics["status"] != "running":
return "stable"
        if random.random() * 100 < metrics["current_percentage"]:
return "canary"
return "stable"
def record_prediction(
self, model_name: str, version: str, success: bool, latency_ms: float
):
"""记录预测结果"""
metrics = self.metrics.get(model_name)
if not metrics:
return
config = self.deployments[model_name]
if version == config.canary_version:
metrics["canary_total"] += 1
if not success:
metrics["canary_errors"] += 1
else:
metrics["stable_total"] += 1
if not success:
metrics["stable_errors"] += 1
def evaluate(self, model_name: str) -> dict:
"""评估金丝雀部署"""
config = self.deployments.get(model_name)
metrics = self.metrics.get(model_name)
if not config or not metrics:
return {"error": "部署不存在"}
canary_error_rate = (
metrics["canary_errors"] / max(metrics["canary_total"], 1)
)
stable_error_rate = (
metrics["stable_errors"] / max(metrics["stable_total"], 1)
)
result = {
"model": model_name,
"canary_error_rate": f"{canary_error_rate:.2%}",
"stable_error_rate": f"{stable_error_rate:.2%}",
"canary_samples": metrics["canary_total"],
"current_percentage": f"{metrics['current_percentage']:.1f}%",
}
# 决策:是否继续推广
if canary_error_rate > config.error_threshold:
result["action"] = "rollback"
result["reason"] = f"金丝雀错误率 {canary_error_rate:.2%} 超过阈值 {config.error_threshold:.2%}"
metrics["status"] = "rolled_back"
elif metrics["canary_total"] >= 100:
if canary_error_rate <= stable_error_rate:
new_pct = min(
metrics["current_percentage"] + config.step_percentage,
config.max_percentage,
)
metrics["current_percentage"] = new_pct
result["action"] = "promote"
result["new_percentage"] = f"{new_pct:.1f}%"
if new_pct >= config.max_percentage:
result["action"] = "complete"
metrics["status"] = "completed"
else:
result["action"] = "hold"
result["reason"] = "金丝雀错误率高于稳定版本,暂停推广"
return result
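    # 补充示意:用 evaluation_interval_sec 驱动的自动评估循环
    # (演示用;生产中通常由调度器或监控系统触发)
    def run_evaluation_loop(self, model_name: str, max_rounds: int = 10):
        """周期性评估金丝雀,直到完成、回滚或达到轮数上限"""
        config = self.deployments[model_name]
        for _ in range(max_rounds):
            result = self.evaluate(model_name)
            print(result)
            if result.get("action") in ("complete", "rollback"):
                break
            time.sleep(config.evaluation_interval_sec)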
def rollback(self, model_name: str, reason: str = ""):
"""回滚金丝雀部署"""
metrics = self.metrics.get(model_name)
if metrics:
metrics["status"] = "rolled_back"
metrics["current_percentage"] = 0
print(f"金丝雀部署已回滚: {model_name}. 原因: {reason}")DVC 数据版本管理
import json
import hashlib
import os
from pathlib import Path
class SimpleDVC:
"""简化版数据版本管理
DVC (Data Version Control) 是 Git 的数据版本管理扩展。
核心功能:
1. 大文件版本管理(数据集、模型文件)
2. 数据流水线定义
3. 与 Git 集成的数据追踪
"""
def __init__(self, repo_root: str = "."):
self.repo_root = Path(repo_root)
self.dvc_dir = self.repo_root / ".dvc"
self.meta_file = self.dvc_dir / "data_registry.json"
def init(self):
"""初始化 DVC 仓库"""
self.dvc_dir.mkdir(exist_ok=True)
if not self.meta_file.exists():
self._save_registry({})
print("DVC 仓库已初始化")
def add(self, file_path: str, description: str = ""):
"""添加数据文件到版本管理"""
full_path = self.repo_root / file_path
if not full_path.exists():
print(f"文件不存在: {file_path}")
return
# 计算文件哈希
file_hash = self._compute_hash(full_path)
file_size = full_path.stat().st_size
registry = self._load_registry()
registry[file_path] = {
"hash": file_hash,
"size": file_size,
"description": description,
"added_at": str(os.path.getmtime(full_path)),
}
self._save_registry(registry)
print(f"已添加: {file_path} (hash: {file_hash[:12]}..., size: {file_size} bytes)")
def status(self):
"""查看数据文件状态"""
registry = self._load_registry()
print(f"{'文件路径':<40} {'大小':<15} {'哈希':<20}")
print("-" * 75)
for path, info in registry.items():
full_path = self.repo_root / path
current_hash = self._compute_hash(full_path) if full_path.exists() else "MISSING"
status = "OK" if current_hash == info["hash"] else "MODIFIED"
print(f"{path:<40} {info['size']:<15} {info['hash'][:12]}... [{status}]")
def _compute_hash(self, file_path: Path) -> str:
"""计算文件 MD5 哈希"""
hasher = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
hasher.update(chunk)
return hasher.hexdigest()
def _load_registry(self) -> dict:
if self.meta_file.exists():
with open(self.meta_file, "r") as f:
return json.load(f)
return {}
def _save_registry(self, registry: dict):
with open(self.meta_file, "w") as f:
json.dump(registry, f, indent=2, ensure_ascii=False)
# DVC 命令行使用指南
DVC_COMMANDS = """
# 初始化 DVC
dvc init
# 添加数据文件
dvc add data/train.csv
# Git 追踪 .dvc 文件
git add data/train.csv.dvc .gitignore
git commit -m "添加训练数据集 v1"
# 修改数据后重新添加
dvc add data/train.csv
git add data/train.csv.dvc
git commit -m "更新训练数据集 v2"
# 切换到历史版本的数据
git checkout v1.0
dvc checkout # 恢复对应版本的数据文件
# 定义数据处理流水线
dvc run -n preprocess \\
-d data/train.csv \\
-o data/processed.csv \\
python preprocess.py
# 可视化流水线
dvc dag
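# 参考:dvc add 生成的 .dvc 元文件内容大致如下(示意),该小文件交给 Git 追踪,
# 真实数据进入 DVC 缓存:
#   outs:
#   - md5: 1a2b3c4d5e6f...
#     size: 10485760
#     path: data/train.csv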
"""CI/CD for ML
class MLCICDPipeline:
"""ML 模型 CI/CD 流水线
典型的 ML CI/CD 流程:
1. 数据变更触发 -> 数据验证
2. 代码变更触发 -> 训练流水线
3. 模型注册触发 -> 模型验证
4. 模型审批触发 -> 部署流水线
"""
@staticmethod
def generate_github_actions_workflow() -> str:
"""生成 GitHub Actions 工作流"""
return """
name: ML Training Pipeline
on:
push:
paths:
- 'src/**'
- 'data/**'
- 'configs/**'
workflow_dispatch:
inputs:
experiment_name:
description: '实验名称'
required: true
jobs:
data-validation:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: 验证数据完整性
run: |
python -m src.data_validation \\
--input data/train.csv \\
--schema schemas/train_schema.json
- name: 数据质量检查
run: |
python -m src.data_quality \\
--input data/train.csv \\
--min-rows 1000 \\
--max-null-ratio 0.05
model-training:
needs: data-validation
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: 配置 MLflow
run: |
          echo "MLFLOW_TRACKING_URI=${{ secrets.MLFLOW_TRACKING_URI }}" >> $GITHUB_ENV
          echo "MLFLOW_EXPERIMENT=${{ github.event.inputs.experiment_name }}" >> $GITHUB_ENV
- name: 训练模型
run: |
python -m src.train \\
--config configs/training_config.yaml \\
--experiment $MLFLOW_EXPERIMENT
model-validation:
needs: model-training
runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
- name: 模型性能验证
run: |
python -m src.validate_model \\
--min-accuracy 0.85 \\
--max-latency-ms 500 \\
--fairness-threshold 0.9
- name: 注册模型
run: |
python -m src.register_model \\
--model-name ${{ github.event.repository.name }} \\
--stage Staging
model-deployment:
needs: model-validation
runs-on: ubuntu-latest
if: github.ref == 'refs/heads/main'
    steps:
      - uses: actions/checkout@v4
- name: 部署到 Staging
run: |
python -m src.deploy \\
--environment staging \\
--canary-percentage 5
- name: 运行集成测试
run: |
python -m tests.integration_test \\
--environment staging
- name: 提升到 Production
run: |
python -m src.promote_model \\
--environment production
"""
@staticmethod
def model_validation_checklist() -> list:
"""模型上线验证清单"""
return [
"准确率 >= 基线模型",
"F1 Score >= 基线模型",
"推理延迟 <= SLA 要求",
"公平性指标满足要求",
"在测试集上无显著过拟合",
"模型大小在可接受范围内",
"A/B 测试结果通过显著性检验",
"回滚方案已准备",
"监控告警已配置",
"审批人已确认",
        ]

模型回滚
import time

class ModelRollback:
"""模型回滚管理"""
def __init__(self):
self.rollback_history = []
def rollback_to_version(
self,
model_name: str,
target_version: str,
reason: str,
):
"""回滚到指定版本"""
rollback_record = {
"model": model_name,
"target_version": target_version,
"reason": reason,
"timestamp": time.time(),
}
self.rollback_history.append(rollback_record)
print(f"正在回滚 {model_name} 到 v{target_version}...")
print(f"回滚原因: {reason}")
# 实际回滚操作
# client.transition_model_version_stage(
# name=model_name,
# version=target_version,
# stage="Production",
# archive_existing_versions=True,
# )
def auto_rollback_on_degradation(
self,
model_name: str,
current_metrics: dict,
baseline_metrics: dict,
degradation_threshold: float = 0.05,
):
"""检测到性能退化时自动回滚"""
for metric_name, baseline_value in baseline_metrics.items():
current_value = current_metrics.get(metric_name, 0)
degradation = baseline_value - current_value
if degradation > degradation_threshold:
print(f"检测到退化: {metric_name} "
f"从 {baseline_value:.4f} 降到 {current_value:.4f}")
print(f"退化幅度: {degradation:.4f} ({degradation/baseline_value:.1%})")
return True
return False
def get_rollback_history(self, model_name: str) -> list:
"""获取回滚历史"""
return [
r for r in self.rollback_history
if r["model"] == model_name
        ]
性能注意事项
- MLflow Server 性能:大量实验时查询会变慢,建议使用 PostgreSQL 作为后端存储(启动示例见本列表之后)
- 模型存储:使用 S3/OSS 等对象存储存放模型文件,MLflow 只存元数据
- 并行实验:多个实验同时运行时注意数据库连接池配置
- DVC 存储:大文件拉取耗时,建议使用共享缓存
- A/B 测试样本量:统计显著性检验需要足够的样本量,至少 1000 次
- 金丝雀步进:步进速度不宜过快,每步至少观察 30 分钟
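例如,面向生产的 MLflow Server 可以这样启动(示意命令,数据库连接串与桶名均为假设值):
mlflow server \
  --backend-store-uri postgresql://mlflow:password@db-host:5432/mlflow \
  --default-artifact-root s3://my-mlflow-artifacts/ \
  --host 0.0.0.0 \
  --port 5000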
总结
模型版本管理和实验追踪是 AI 项目工程化的基石。MLflow 提供了完整的实验追踪和模型注册能力,DVC 补充了数据版本管理,两者结合 CI/CD 实现了从训练到部署的自动化流水线。核心原则是一切皆有记录,一切皆可追溯,一切皆可回滚。
关键知识点
- MLflow 四大组件 — Tracking、Projects、Models、Registry
- 模型阶段 — None -> Staging -> Production -> Archived
- 实验追踪要素 — 参数、指标、制品、标签
- 血缘追踪 — 数据 -> 代码 -> 模型 -> 部署的完整链路
- 金丝雀部署 — 逐步增加流量,发现问题及时回滚
- DVC 核心概念 — .dvc 文件、数据缓存、流水线定义
- CI/CD for ML — 数据验证、训练、验证、部署的自动化
常见误区
- 不记录实验:凭记忆管理实验,无法复现最佳结果
- 忽略数据版本:只管代码和模型版本,数据变更导致结果不一致
- 直接上线:跳过 Staging 和 A/B 测试直接部署到生产
- 回滚无预案:出问题才想怎么回滚,导致长时间停机
- 过度自动化:初期就搞复杂流水线,反而增加维护成本
- 模型版本过多:注册太多无用的中间版本,增加管理负担
进阶路线
- 入门:使用 MLflow 记录实验、对比结果
- 进阶:配置模型注册中心、实现 A/B 测试
- 高级:建立 CI/CD 流水线、金丝雀部署、自动回滚
- 专家:完整 MLOps 平台、特征商店、模型监控
适用场景
- 多人协作的 AI 项目
- 频繁迭代实验的模型开发
- 需要审计追溯的金融/医疗 AI
- 多模型版本的 A/B 测试
- 大规模模型部署和运维
落地建议
- 第一步:部署 MLflow Server,团队统一实验记录
- 第二步:建立模型注册中心,规范版本命名
- 第三步:引入 DVC 管理数据集版本
- 第四步:搭建 CI/CD 流水线自动化训练和部署
- 第五步:实现金丝雀部署和自动回滚
- 持续:定期审查实验记录,清理过期模型版本
复盘问题
- 本月进行了多少次实验?有多少产生了有效结果?
- 模型从训练到上线的平均周期是多长?
- 发生过几次回滚?根因是什么?
- 实验追踪的使用率如何?团队成员是否都在使用?
- 数据版本管理和代码版本管理是否同步?
- 下月需要优化哪些流程?
