Python SQLAlchemy ORM
大约 10 分钟约 2904 字
Python SQLAlchemy ORM
简介
SQLAlchemy 是 Python 生态中最成熟的 ORM 框架,提供从底层 SQL 表达到高层 ORM 映射的完整数据库抽象层。SQLAlchemy 2.0 引入了全面支持 async 和类型注解的新 API,是 FastAPI 等现代框架的首选数据库方案。
特点
实现
模型定义与 Session 管理
from datetime import datetime
from typing import AsyncIterator, Iterator, List, Optional

from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
class Base(DeclarativeBase):
    """Declarative base class that every ORM model in this module inherits from."""
class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = "users"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(100))
    email: Mapped[str] = mapped_column(String(255), unique=True, index=True)
    hashed_password: Mapped[str] = mapped_column(String(255))
    # Python-side default, applied when the ORM inserts the row.
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    # Server-side default: the database supplies the timestamp on INSERT.
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        server_default=func.now(),
    )
    # Nullable with no default; ``onupdate`` sets it when a row is updated.
    updated_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime,
        onupdate=func.now(),
        nullable=True,
    )

    # Relationship: posts written by this user, paired with ``Post.author``.
    posts: Mapped[List["Post"]] = relationship(
        back_populates="author",
        cascade="all, delete-orphan",
    )

    def __repr__(self):
        return f"User(id={self.id}, name='{self.name}', email='{self.email}')"
class Post(Base):
    """ORM model mapped to the ``posts`` table."""

    __tablename__ = "posts"

    id: Mapped[int] = mapped_column(primary_key=True)
    title: Mapped[str] = mapped_column(String(200))
    content: Mapped[Optional[str]] = mapped_column(
        String(5000),
        nullable=True,
    )
    # Python-side default, applied when the ORM inserts the row.
    published: Mapped[bool] = mapped_column(Boolean, default=False)
    # Foreign key to the owning user's primary key.
    author_id: Mapped[int] = mapped_column(ForeignKey("users.id"))

    # Relationship: the owning user, paired with ``User.posts``.
    author: Mapped["User"] = relationship(back_populates="posts")
# Session 管理
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
# Synchronous engine; echo=False keeps emitted SQL out of the log output.
engine = create_engine(
    "postgresql://user:pass@localhost/mydb",
    echo=False,
)
# Session factory bound to the engine above.
SessionLocal = sessionmaker(
    bind=engine,
    autocommit=False,
    autoflush=False,
)
def get_db() -> Iterator[Session]:
    """Yield a database session and close it once the consumer is done.

    This is a generator function, so its return type is an iterator of
    ``Session`` rather than ``Session`` itself. It is intended to be used
    as a dependency-injection provider.

    Yields:
        A new session created by ``SessionLocal``.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        # Runs whether the consumer finished normally or raised.
        db.close()
# 创建表
Base.metadata.create_all(bind=engine)
CRUD 操作与查询
from sqlalchemy import select, update, delete
from sqlalchemy.orm import Session, joinedload
class UserRepository:
    """Data-access layer for ``User`` rows.

    Every write method commits immediately. If the commit fails, the
    session is rolled back before the error is re-raised so the session
    stays usable for the caller.
    """

    def __init__(self, db: Session):
        self.db = db

    def _commit(self) -> None:
        """Commit the session; roll back and re-raise if the commit fails."""
        try:
            self.db.commit()
        except Exception:
            # A failed flush/commit leaves the session unusable until
            # rollback() is called, so reset it before surfacing the error.
            self.db.rollback()
            raise

    def _update_fields(self, user_id: int, **values) -> bool:
        """Apply ``values`` to the user with ``user_id``; True if a row matched."""
        stmt = update(User).where(User.id == user_id).values(**values)
        result = self.db.execute(stmt)
        self._commit()
        return result.rowcount > 0

    def create(self, name: str, email: str, hashed_password: str) -> User:
        """Insert a new user and return it with database-generated fields loaded."""
        user = User(name=name, email=email, hashed_password=hashed_password)
        self.db.add(user)
        self._commit()
        # Reload so server-generated values (id, created_at) are populated.
        self.db.refresh(user)
        return user

    def get_by_id(self, user_id: int) -> Optional[User]:
        """Return the user with ``user_id``, or None if there is none."""
        stmt = select(User).where(User.id == user_id)
        return self.db.execute(stmt).scalar_one_or_none()

    def get_by_email(self, email: str) -> Optional[User]:
        """Return the user with ``email``, or None if there is none."""
        stmt = select(User).where(User.email == email)
        return self.db.execute(stmt).scalar_one_or_none()

    def list_active(self, offset: int = 0, limit: int = 20) -> list[User]:
        """Return one page of active users, newest first."""
        stmt = (
            select(User)
            .where(User.is_active.is_(True))
            .order_by(User.created_at.desc())
            .offset(offset)
            .limit(limit)
        )
        return list(self.db.execute(stmt).scalars().all())

    def count_active(self) -> int:
        """Return the number of active users."""
        stmt = (
            select(func.count())
            .select_from(User)
            .where(User.is_active.is_(True))
        )
        return self.db.execute(stmt).scalar_one()

    def update_name(self, user_id: int, new_name: str) -> bool:
        """Rename a user; return True if a row was updated."""
        return self._update_fields(user_id, name=new_name)

    def deactivate(self, user_id: int) -> bool:
        """Mark a user inactive; return True if a row was updated."""
        return self._update_fields(user_id, is_active=False)

    def delete_by_id(self, user_id: int) -> bool:
        """Delete a user together with their posts; return True if one existed.

        The user is loaded and removed through the unit of work instead of
        a bulk ``delete()`` statement: bulk deletes skip the ORM-level
        ``cascade="all, delete-orphan"`` declared on ``User.posts``, so a
        user who still has posts could not be deleted that way.
        """
        user = self.db.get(User, user_id)
        if user is None:
            return False
        self.db.delete(user)
        self._commit()
        return True
# 关联查询
class PostRepository:
def __init__(self, db: Session):
self.db = db
def get_with_author(self, post_id: int) -> Optional[Post]:
stmt = (
select(Post)
.options(joinedload(Post.author))
.where(Post.id == post_id)
)
return self.db.execute(stmt).unique().scalar_one_or_none()
def search(self, keyword: str, offset: int = 0, limit: int = 20) -> list[Post]:
stmt = (
select(Post)
.where(Post.title.ilike(f"%{keyword}%"))
.options(joinedload(Post.author))
.order_by(Post.id.desc())
.offset(offset)
.limit(limit)
)
return list(self.db.execute(stmt).unique().scalars().all())异步 ORM 操作
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy import select, func
# Asynchronous engine using the asyncpg driver.
async_engine = create_async_engine(
    "postgresql+asyncpg://user:pass@localhost/mydb",
    echo=False,
    pool_size=10,
    max_overflow=20,
)
# Async session factory; expire_on_commit=False keeps loaded attributes
# readable after the commit.
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
class AsyncUserRepository:
    """Asynchronous data-access layer for ``User`` rows.

    Each method opens its own session from ``AsyncSessionLocal``.
    """

    async def create(self, name: str, email: str, hashed_password: str) -> User:
        """Insert a new user inside a ``session.begin()`` block and return it."""
        async with AsyncSessionLocal() as session:
            async with session.begin():
                user = User(name=name, email=email, hashed_password=hashed_password)
                session.add(user)
                # Emit the INSERT now, before the transaction block exits.
                await session.flush()
            return user

    async def get_by_id(self, user_id: int) -> Optional[User]:
        """Return the user with ``user_id``, or None if there is none."""
        async with AsyncSessionLocal() as session:
            stmt = select(User).where(User.id == user_id)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def list_paginated(self, page: int = 1, page_size: int = 20) -> dict:
        """Return one page of users, newest first, with paging metadata.

        Args:
            page: 1-based page number.
            page_size: number of rows per page.

        Returns:
            A dict with the keys ``items``, ``total``, ``page``,
            ``page_size`` and ``pages``.

        Raises:
            ValueError: if ``page`` or ``page_size`` is smaller than 1.
        """
        # Reject values that would produce a negative OFFSET or a
        # division by zero when computing the page count.
        if page < 1:
            raise ValueError("page must be >= 1")
        if page_size < 1:
            raise ValueError("page_size must be >= 1")

        async with AsyncSessionLocal() as session:
            offset = (page - 1) * page_size

            # Total row count.
            count_stmt = select(func.count()).select_from(User)
            total = (await session.execute(count_stmt)).scalar_one()

            # Page of rows; id is a tie-breaker so rows sharing a
            # created_at value keep a stable order across pages.
            stmt = (
                select(User)
                .order_by(User.created_at.desc(), User.id.desc())
                .offset(offset)
                .limit(page_size)
            )
            items = list((await session.execute(stmt)).scalars().all())

            return {
                "items": items,
                "total": total,
                "page": page,
                "page_size": page_size,
                # Ceiling division.
                "pages": (total + page_size - 1) // page_size,
            }
# FastAPI 中使用
from fastapi import Depends
async def get_async_db() -> AsyncSession:
async with AsyncSessionLocal() as session:
yield sessionAlembic 数据库迁移
# Initialize the migration environment
alembic init alembic
# Autogenerate a migration script
alembic revision --autogenerate -m "add users and posts tables"
# Apply migrations
alembic upgrade head
# Roll back one revision
alembic downgrade -1
# Show the current revision
alembic current
# Show the migration history
alembic history
# alembic/env.py key configuration
from models import Base # 导入你的模型基类
target_metadata = Base.metadata
# 迁移脚本示例: alembic/versions/xxx_add_users_table.py
def upgrade() -> None:
    """Create the ``users`` table and the unique index on ``email``.

    The columns mirror the ``User`` model, including ``hashed_password``
    and ``updated_at``.
    """
    op.create_table(
        "users",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("name", sa.String(100), nullable=False),
        sa.Column("email", sa.String(255), nullable=False),
        sa.Column("hashed_password", sa.String(255), nullable=False),
        # ``default=`` is a Python-side value and never reaches the DDL;
        # a server default is what actually backfills rows at the database.
        sa.Column(
            "is_active",
            sa.Boolean(),
            nullable=False,
            server_default=sa.true(),
        ),
        sa.Column(
            "created_at",
            sa.DateTime(),
            nullable=False,
            server_default=sa.func.now(),
        ),
        sa.Column("updated_at", sa.DateTime(), nullable=True),
    )
    # One unique index enforces uniqueness and serves lookups by email,
    # replacing the separate unique constraint plus non-unique index.
    op.create_index("ix_users_email", "users", ["email"], unique=True)
def downgrade() -> None:
    """Drop the ``email`` index and then the ``users`` table."""
    # Name the owning table explicitly: some backends cannot drop an
    # index without it.
    op.drop_index("ix_users_email", table_name="users")
    op.drop_table("users")

复杂查询与高级特性
from sqlalchemy import select, func, and_, or_, not_, case, extract
from sqlalchemy.orm import Session, aliased, contains_eager
from datetime import datetime, timedelta
class AdvancedUserRepository:
"""高级查询示例"""
def __init__(self, db: Session):
self.db = db
def search_with_filters(self, keyword: str, status: str = None,
min_age: int = None, page: int = 1,
page_size: int = 20) -> dict:
"""多条件搜索"""
stmt = select(User)
conditions = []
if keyword:
conditions.append(
or_(
User.name.ilike(f"%{keyword}%"),
User.email.ilike(f"%{keyword}%"),
)
)
if status is not None:
conditions.append(User.is_active == (status == "active"))
if min_age is not None:
conditions.append(User.created_at <= datetime.now() - timedelta(days=min_age * 365))
if conditions:
stmt = stmt.where(and_(*conditions))
# 总数查询
count_stmt = select(func.count()).select_from(stmt.subquery())
total = self.db.execute(count_stmt).scalar_one()
# 分页查询
stmt = stmt.order_by(User.created_at.desc()).offset((page - 1) * page_size).limit(page_size)
items = list(self.db.execute(stmt).scalars().all())
return {"items": items, "total": total, "page": page, "page_size": page_size}
def get_user_stats(self) -> dict:
"""聚合统计"""
stmt = select(
func.count(User.id).label("total_users"),
func.sum(case((User.is_active == True, 1), else_=0)).label("active_users"),
func.min(User.created_at).label("earliest_join"),
func.max(User.created_at).label("latest_join"),
func.count(func.distinct(extract("year", User.created_at))).label("join_years"),
)
row = self.db.execute(stmt).one()
return {
"total": row.total_users,
"active": row.active_users,
"earliest_join": row.earliest_join,
"latest_join": row.latest_join,
}
def batch_update_status(self, user_ids: list[int], active: bool) -> int:
"""批量更新"""
stmt = (
update(User)
.where(User.id.in_(user_ids))
.values(is_active=active, updated_at=func.now())
)
result = self.db.execute(stmt)
self.db.commit()
return result.rowcount
def upsert_user(self, user_id: int, name: str, email: str) -> User:
"""存在则更新,不存在则创建"""
from sqlalchemy.dialects.postgresql import insert as pg_insert
stmt = pg_insert(User).values(
id=user_id, name=name, email=email
).on_conflict_do_update(
index_elements=["id"],
set_={"name": name, "email": email, "updated_at": func.now()}
).returning(User)
user = self.db.execute(stmt).scalar_one()
self.db.commit()
return user关系加载策略与 N+1 优化
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload, selectinload, subqueryload
class PostRepositoryOptimized:
    """Repository that loads relationships eagerly to avoid N+1 queries."""

    def __init__(self, db: Session):
        self.db = db

    def list_with_author_joined(self, limit: int = 20) -> list[Post]:
        """Return the newest posts with their authors loaded via ``joinedload``."""
        query = select(Post).options(joinedload(Post.author))
        query = query.order_by(Post.id.desc()).limit(limit)
        rows = self.db.execute(query).unique()
        return list(rows.scalars().all())

    def list_with_author_selectin(self, limit: int = 20) -> list[Post]:
        """Return the newest posts with their authors loaded via ``selectinload``."""
        query = select(Post).options(selectinload(Post.author))
        query = query.order_by(Post.id.desc()).limit(limit)
        rows = self.db.execute(query)
        return list(rows.scalars().all())

    def get_user_with_posts(self, user_id: int):
        """Return the user with ``user_id`` and all their posts, or None."""
        query = (
            select(User)
            .where(User.id == user_id)
            .options(selectinload(User.posts))
        )
        return self.db.execute(query).unique().scalar_one_or_none()
# N+1 问题诊断
def diagnose_n_plus_one(db: Session):
    """Turn on engine logging, then run the N+1 pattern and its eager-loading fix.

    With the ``sqlalchemy.engine`` logger at INFO level, the emitted
    statements appear in the log, so the two approaches can be compared.
    """
    import logging

    logging.basicConfig()
    logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)

    # Anti-pattern: N+1 queries.
    lazy_stmt = select(Post).limit(10)
    posts = db.execute(lazy_stmt).scalars().all()
    for post in posts:
        _ = post.author.name  # each post triggers one extra query

    # Correct approach: eager loading.
    eager_stmt = select(Post).options(joinedload(Post.author)).limit(10)
    posts = db.execute(eager_stmt).unique().scalars().all()
    for post in posts:
        _ = post.author.name  # no extra query is triggered

事务管理与并发控制
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError, OperationalError
from contextlib import contextmanager
from typing import Callable, Any
import logging
logger = logging.getLogger(__name__)
@contextmanager
def transaction_scope(db: Session):
    """Commit on success; on any exception roll back, log it and re-raise.

    The log message names the failure category: integrity error,
    operational error, or any other exception.
    """
    try:
        yield db
        db.commit()
    except Exception as e:
        # Roll back first, then log with the message matching the error type.
        db.rollback()
        if isinstance(e, IntegrityError):
            logger.error(f"数据完整性错误: {e}")
        elif isinstance(e, OperationalError):
            logger.error(f"数据库操作错误: {e}")
        else:
            logger.error(f"事务失败: {e}")
        raise
def transfer_balance(db: Session, from_id: int, to_id: int, amount: float):
    """Move ``amount`` from one user's balance to another's in one transaction.

    Both rows are locked with ``SELECT ... FOR UPDATE``, i.e. pessimistic
    row locks (not optimistic locking). The locks are taken in ascending
    id order so that two concurrent transfers in opposite directions
    cannot deadlock on each other.

    NOTE(review): this reads and writes ``User.balance``, which the
    ``User`` model shown in this file does not declare; confirm the
    column exists before using this example.

    Args:
        db: session the transfer runs in.
        from_id: id of the user being debited.
        to_id: id of the user being credited.
        amount: amount to move; must be greater than zero.

    Returns:
        True once the transfer has been committed.

    Raises:
        ValueError: if ``amount`` is not positive or the source balance
            is insufficient.
    """
    # A zero or negative amount would silently reverse the transfer.
    if amount <= 0:
        raise ValueError("转账金额必须大于 0")

    with transaction_scope(db):
        # Lock rows in a deterministic order (ascending id).
        locked = {}
        for user_id in sorted({from_id, to_id}):
            locked[user_id] = db.execute(
                select(User).where(User.id == user_id).with_for_update()
            ).scalar_one()
        from_user = locked[from_id]
        to_user = locked[to_id]

        if from_user.balance < amount:
            raise ValueError("余额不足")
        from_user.balance -= amount
        to_user.balance += amount
    return True
# 使用嵌套事务(savepoint)
def nested_transaction_example(db: Session):
    """Show a savepoint whose failure does not abort the outer transaction.

    The insert runs inside ``db.begin_nested()``. If it fails, the error
    is logged with its traceback and the outer transaction is still
    committed.
    """
    try:
        with db.begin_nested():  # creates a savepoint
            user = User(name="测试用户", email="test@example.com", hashed_password="xxx")
            db.add(user)
            # If anything in this block fails, only the savepoint is
            # rolled back; the outer transaction is unaffected.
    except Exception:
        # Deliberately best-effort, but keep the cause in the log so the
        # failure can be diagnosed.
        logger.warning("嵌套事务回滚", exc_info=True)
    # The outer transaction continues.
    db.commit()

连接池与性能优化
from sqlalchemy import create_engine
from sqlalchemy.pool import QueuePool
from sqlalchemy.orm import sessionmaker
from sqlalchemy import event, text
import time
# 连接池配置
engine = create_engine(
    "postgresql://user:pass@localhost/mydb",
    poolclass=QueuePool,
    pool_size=10,        # connections kept open in the pool
    max_overflow=20,     # extra connections allowed beyond pool_size
    pool_timeout=30,     # seconds to wait for a connection
    pool_recycle=3600,   # seconds after which a connection is recycled
    pool_pre_ping=True,  # test the connection before handing it out
    echo_pool=False,     # log pool activity
    echo=False,          # log SQL statements
)
# 连接池监控
@event.listens_for(engine, "checkout")
def on_checkout(dbapi_connection, connection_record, connection_proxy):
    """Stamp the connection record with the time it was checked out."""
    record_info = connection_record.info
    record_info['checkout_time'] = time.time()
@event.listens_for(engine, "checkin")
def on_checkin(dbapi_connection, connection_record, connection_proxy=None):
    """Warn when a connection was held for more than five seconds.

    The ``checkin`` pool event passes only ``dbapi_connection`` and
    ``connection_record``. ``connection_proxy`` is therefore optional and
    unused; it is kept only so existing three-argument calls still work.
    """
    checkout_time = connection_record.info.pop('checkout_time', None)
    if checkout_time is None:
        # No checkout stamp recorded; there is no duration to measure.
        return
    duration = time.time() - checkout_time
    if duration > 5.0:
        logger.warning(f"连接使用时间过长: {duration:.2f}s")
# Session 配置
# Session factory bound to the pooled engine.
# expire_on_commit=False keeps already-loaded attributes valid after commit.
SessionLocal = sessionmaker(
    bind=engine, autocommit=False, autoflush=False, expire_on_commit=False
)
def get_pool_status():
    """Print the engine pool's size, checked-out count, overflow and total."""
    pool = engine.pool
    report = (
        f"连接池大小: {pool.size()}",
        f"当前已借出: {pool.checkedout()}",
        f"当前溢出: {pool.overflow()}",
        f"总连接数: {pool.size() + pool.overflow()}",
    )
    for line in report:
        print(line)

优点
缺点
总结
SQLAlchemy 是 Python 数据库操作的工业级标准,2.0 版本的 async 和类型注解支持使其成为现代 Python 项目的首选 ORM。掌握模型定义、Session 管理、查询构建和 Alembic 迁移四大核心能力,即可覆盖大多数项目需求。注意理解 Session 生命周期和关系加载策略,避免 N+1 查询等性能问题。
关键知识点
- DeclarativeBase 是 2.0 推荐的声明式模型基类,Mapped[type] 提供类型注解支持
- Session 的生命周期应与 HTTP 请求保持一致:每个请求创建一个 Session,请求结束时关闭
- joinedload 通过 JOIN 一次性预加载关联数据,适用于多对一/一对一关联;selectinload 通过额外一条 IN 查询加载,适用于集合类型关联;两者都能避免 N+1 查询
- Alembic 管理 schema 迁移,autogenerate 自动检测模型变更
项目落地视角
- 使用 Repository 模式封装数据库操作,业务层不直接操作 Session
- 异步项目使用 asyncpg 驱动 + AsyncSession,连接池大小按并发量配置
- 所有数据库变更通过 Alembic 迁移管理,禁止手动修改数据库 schema
- 慢查询日志集成到监控系统,定期优化 N+1 和全表扫描问题
常见误区
- 在循环中逐条查询关联数据导致 N+1 问题,应使用 joinedload 或 selectinload 预加载
- 忘记 commit 或 rollback 导致 Session 状态不一致
- 过度使用 ORM 处理复杂统计查询,应考虑原生 SQL
- Session 跨请求共享导致线程安全问题
进阶路线
- 学习 SQLAlchemy 事件系统(event listen)实现审计日志和自动填充
- 研究只读副本、分库分表在 SQLAlchemy 中的实现方式
- 了解 SQLAlchemy Utils 提供的扩展类型(ArrowType、ChoiceType 等)
- 探索在微服务中使用 SQLAlchemy 管理 CQRS 读写分离
适用场景
- 需要 ORM 抽象的 Web 服务后端(FastAPI、Flask)
- 复杂的数据模型需要关系映射和迁移管理
- 需要同时支持多种数据库的项目
落地建议
- 统一使用 Repository 模式隔离数据库操作,便于测试和替换
- 配置 SQLAlchemy echo=True 在开发环境打印 SQL,便于调试
- 为每个模型编写完整的迁移脚本和回滚脚本
排错清单
- 检查 Session 是否正确关闭,是否有连接泄漏
- 确认关系加载策略是否导致 N+1 查询
- 排查懒加载在异步上下文中报错,应使用 awaitable_attrs 或预加载
复盘问题
- 你的项目中是否存在 N+1 查询问题?如何发现和解决?
- 数据库迁移流程是否规范?回滚是否经过验证?
- 连接池配置是否合理?高峰期是否出现连接超时?
