Python 装饰器
大约 14 分钟约 4134 字
Python 装饰器
简介
装饰器是 Python 中用于在不修改原函数代码的前提下扩展函数行为的高级特性。本质上它是一个接受函数作为参数并返回新函数的高阶函数,广泛应用于日志记录、权限校验、缓存、重试等横切关注点场景。
装饰器的底层原理基于 Python 的"一切皆对象"哲学——函数是一等公民(first-class citizen),可以像普通变量一样被传递、赋值和作为参数。当一个函数被装饰器修饰时,实际上发生了这样的过程:@decorator 语法糖等价于 func = decorator(func),即原函数被传入装饰器,装饰器返回的新函数替代了原函数在当前命名空间中的绑定。这意味着原始函数对象仍然存在(通过闭包引用),但调用者看到的是包装后的版本。
特点
核心原理
装饰器的执行时机与顺序
import functools
# 理解装饰器的执行时机是避免 bug 的关键
print("1. 模块开始加载")
def decorator_a(func):
print(f"2. decorator_a 定义时执行(装饰 {func.__name__})")
@functools.wraps(func)
def wrapper(*args, **kwargs):
print(f"5. decorator_a wrapper 执行前({func.__name__})")
result = func(*args, **kwargs)
print(f"7. decorator_a wrapper 执行后({func.__name__})")
return result
return wrapper
def decorator_b(func):
print(f"3. decorator_b 定义时执行(装饰 {func.__name__})")
@functools.wraps(func)
def wrapper(*args, **kwargs):
print(f"4. decorator_b wrapper 执行前({func.__name__})")
result = func(*args, **kwargs)
print(f"6. decorator_b wrapper 执行后({func.__name__})")
return result
return wrapper
@decorator_a
@decorator_b
def my_function():
print(" my_function 函数体执行")
print("4. 模块加载完成")
print("\n--- 调用 my_function() ---")
my_function()
# 执行顺序说明:
# 装饰顺序(定义时):从下到上 — 先 decorator_b,后 decorator_a
# 执行顺序(调用时):从上到下 — 先 decorator_a,后 decorator_b
# 等价于: my_function = decorator_a(decorator_b(my_function))functools.wraps 详解
import functools
def without_wraps(func):
"""不使用 wraps 的装饰器 —— 会丢失原函数的元信息"""
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
def with_wraps(func):
"""使用 wraps 的装饰器 —— 保留原函数的元信息"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
@without_wraps
def func_a():
"""这是 func_a 的文档"""
pass
@with_wraps
def func_b():
"""这是 func_b 的文档"""
pass
print(f"不使用 wraps: __name__={func_a.__name__}, __doc__={func_a.__doc__}")
# 不使用 wraps: __name__=wrapper, __doc__=None
print(f"使用 wraps: __name__={func_b.__name__}, __doc__={func_b.__doc__}")
# 使用 wraps: __name__=func_b, __doc__=这是 func_b 的文档
# wraps 做了什么:
# 1. 复制 __name__、__qualname__、__module__、__doc__
# 2. 复制 __dict__(函数属性字典)
# 3. 更新 __wrapped__ 属性,指向原始函数
# 4. 使得 functools.signature() 能正确返回原函数的签名
import inspect
sig = inspect.signature(func_b)
print(f"签名: {sig}") # ()
# 自定义 wraps 行为(更新特定属性)
def custom_wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS,
updated=functools.WRAPPER_UPDATES):
"""自定义 wraps 的行为"""
def decorator(wrapper):
wrapper.__wrapped__ = wrapped
for attr in assigned:
if hasattr(wrapped, attr):
setattr(wrapper, attr, getattr(wrapped, attr))
for attr in updated:
getattr(wrapper, attr).update(getattr(wrapped, attr, {}))
return wrapper
return decorator实现
基础装饰器:日志与计时
import time
from functools import wraps
def timer(func):
"""计时装饰器:记录函数执行耗时"""
@wraps(func)
def wrapper(*args, **kwargs):
start = time.perf_counter()
result = func(*args, **kwargs)
elapsed = time.perf_counter() - start
print(f"[{func.__name__}] 耗时: {elapsed:.4f}s")
return result
return wrapper
def log_call(func):
"""日志装饰器:记录函数调用信息"""
@wraps(func)
def wrapper(*args, **kwargs):
args_repr = [repr(a) for a in args]
kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]
signature = ", ".join(args_repr + kwargs_repr)
print(f"调用 {func.__name__}({signature})")
result = func(*args, **kwargs)
print(f"{func.__name__} 返回: {result!r}")
return result
return wrapper
@timer
@log_call
def fetch_data(url: str, timeout: int = 30) -> dict:
"""模拟网络请求"""
time.sleep(0.5)
return {"url": url, "status": "ok"}
# 调用: fetch_data("https://api.example.com", timeout=10)
# 输出:
# 调用 fetch_data('https://api.example.com', timeout=10)
# fetch_data 返回: {'url': '...', 'status': 'ok'}
# [fetch_data] 耗时: 0.5012s增强版日志装饰器
import time
import logging
import traceback
from functools import wraps
logger = logging.getLogger(__name__)
def detailed_log(level: str = "info", log_args: bool = True,
log_result: bool = True, log_exception: bool = True):
"""增强版日志装饰器
Args:
level: 日志级别 (debug/info/warning/error)
log_args: 是否记录参数
log_result: 是否记录返回值
log_exception: 是否记录异常堆栈
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
log_fn = getattr(logger, level)
func_name = func.__qualname__
# 记录调用信息
if log_args:
args_str = ", ".join(repr(a) for a in args[1:]) # 跳过 self
kwargs_str = ", ".join(f"{k}={v!r}" for k, v in kwargs.items())
all_args = ", ".join(filter(None, [args_str, kwargs_str]))
log_fn(f"调用 {func_name}({all_args})")
# 执行并计时
start = time.perf_counter()
try:
result = func(*args, **kwargs)
elapsed = time.perf_counter() - start
if log_result:
result_repr = repr(result)
if len(result_repr) > 200:
result_repr = result_repr[:200] + "..."
log_fn(f"{func_name} 返回 ({elapsed:.3f}s): {result_repr}")
return result
except Exception as e:
elapsed = time.perf_counter() - start
if log_exception:
logger.error(
f"{func_name} 异常 ({elapsed:.3f}s): {e}\n"
f"{traceback.format_exc()}"
)
raise
return wrapper
return decorator
# 使用示例
@Detailed_log(level="debug")
def process_order(order_id: int, items: list) -> dict:
"""处理订单"""
return {"order_id": order_id, "status": "processed", "count": len(items)}带参数的装饰器:重试机制
from functools import wraps
import time
import random
def retry(max_attempts: int = 3, delay: float = 1.0, backoff: float = 2.0):
"""带参数的重试装饰器"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
current_delay = delay
last_exception = None
for attempt in range(1, max_attempts + 1):
try:
return func(*args, **kwargs)
except Exception as e:
last_exception = e
if attempt < max_attempts:
print(f"第 {attempt} 次失败: {e},{current_delay:.1f}s 后重试...")
time.sleep(current_delay)
current_delay *= backoff
else:
print(f"第 {attempt} 次失败,已达最大重试次数")
raise last_exception
return wrapper
return decorator
@retry(max_attempts=3, delay=0.5, backoff=2.0)
def call_external_api(endpoint: str) -> dict:
"""模拟不稳定的外部 API 调用"""
if random.random() < 0.7:
raise ConnectionError(f"连接 {endpoint} 失败")
return {"endpoint": endpoint, "data": "success"}
# 自动重试,最多 3 次,延迟递增增强版重试装饰器
import time
import logging
from functools import wraps
from typing import Type, Tuple
logger = logging.getLogger(__name__)
def smart_retry(
max_attempts: int = 3,
delay: float = 1.0,
backoff: float = 2.0,
max_delay: float = 60.0,
exceptions: Tuple[Type[Exception], ...] = (Exception,),
on_retry=None,
):
"""智能重试装饰器
Args:
max_attempts: 最大重试次数
delay: 初始延迟(秒)
backoff: 退避因子
max_delay: 最大延迟(秒)
exceptions: 需要重试的异常类型元组
on_retry: 每次重试时的回调函数
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
current_delay = delay
last_exception = None
for attempt in range(1, max_attempts + 1):
try:
return func(*args, **kwargs)
except exceptions as e:
last_exception = e
if attempt >= max_attempts:
logger.error(
f"{func.__name__} 重试 {max_attempts} 次后仍失败: {e}"
)
raise
logger.warning(
f"{func.__name__} 第 {attempt}/{max_attempts} 次失败: {e}, "
f"{current_delay:.1f}s 后重试..."
)
if on_retry:
on_retry(attempt, e)
time.sleep(min(current_delay, max_delay))
current_delay *= backoff
raise last_exception # 理论上不会执行到这里
return wrapper
return decorator
# 使用示例:只重试网络相关异常
@smart_retry(
max_attempts=5,
delay=0.5,
backoff=2.0,
exceptions=(ConnectionError, TimeoutError, OSError),
on_retry=lambda attempt, e: print(f"重试通知: 第 {attempt} 次"),
)
def fetch_from_api(url: str) -> dict:
"""从 API 获取数据"""
import urllib.request
with urllib.request.urlopen(url, timeout=10) as resp:
return json.loads(resp.read())类装饰器与权限校验
from functools import wraps
class RateLimiter:
"""基于类的速率限制装饰器"""
def __init__(self, max_calls: int, period: float):
self.max_calls = max_calls
self.period = period
self.calls = []
def __call__(self, func):
@wraps(func)
def wrapper(*args, **kwargs):
import time
now = time.time()
# 清理过期记录
self.calls = [t for t in self.calls if now - t < self.period]
if len(self.calls) >= self.max_calls:
raise RuntimeError(
f"速率限制: {func.__name__} 在 {self.period}s 内"
f"最多调用 {self.max_calls} 次"
)
self.calls.append(now)
return func(*args, **kwargs)
return wrapper
def require_permission(permission: str):
"""权限校验装饰器"""
def decorator(func):
@wraps(func)
def wrapper(user: dict, *args, **kwargs):
if permission not in user.get("permissions", []):
raise PermissionError(
f"用户 {user.get('name')} 缺少权限: {permission}"
)
return func(user, *args, **kwargs)
return wrapper
return decorator
@RateLimiter(max_calls=5, period=60.0)
@require_permission("admin")
def delete_user(user: dict, target_id: int):
print(f"用户 {user['name']} 删除了 ID={target_id} 的记录")
return True
admin_user = {"name": "张三", "permissions": ["admin", "write"]}
# delete_user(admin_user, 42) # 正常执行
# delete_user({"name": "李四", "permissions": ["read"]}, 42) # PermissionError缓存装饰器
from functools import wraps
import hashlib
import json
def cache_result(ttl: int = 300):
"""简单的内存缓存装饰器,支持 TTL"""
cache_store = {}
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
# 生成缓存键
key_parts = (func.__name__, args, tuple(sorted(kwargs.items())))
cache_key = hashlib.md5(
json.dumps(key_parts, default=str).encode()
).hexdigest()
import time
now = time.time()
# 检查缓存
if cache_key in cache_store:
cached_time, cached_result = cache_store[cache_key]
if now - cached_time < ttl:
print(f"[缓存命中] {func.__name__}")
return cached_result
# 执行函数并缓存结果
result = func(*args, **kwargs)
cache_store[cache_key] = (now, result)
return result
# 暴露缓存清除方法
wrapper.clear_cache = lambda: cache_store.clear()
return wrapper
return decorator
@cache_result(ttl=60)
def expensive_query(table: str, conditions: dict) -> list:
"""模拟耗时数据库查询"""
import time
time.sleep(2)
return [{"id": i, "table": table} for i in range(10)]
# 第一次调用耗时约 2s,后续 60s 内的调用直接返回缓存使用 functools.lru_cache
import functools
import time
# lru_cache 是 Python 标准库提供的线程安全缓存装饰器
# 适合纯函数(相同输入总是返回相同输出)
@functools.lru_cache(maxsize=128)
def fibonacci(n: int) -> int:
"""递归计算斐波那契数(带缓存)"""
if n < 2:
return n
return fibonacci(n - 1) + fibonacci(n - 2)
# 不带缓存的递归: O(2^n) 时间复杂度
# 带缓存的递归: O(n) 时间复杂度
print(f"fibonacci(100) = {fibonacci(100)}")
print(f"缓存信息: {fibonacci.cache_info()}")
# CacheInfo(hits=98, misses=101, maxsize=128, currsize=101)
# 清除缓存
fibonacci.cache_clear()
# lru_cache 的局限性:
# 1. 只支持可哈希的参数(不能传 list、dict 等)
# 2. 不支持 TTL(过期时间)
# 3. 不支持自定义缓存键生成逻辑
# Python 3.9+ 的 cache 装饰器(无大小限制的 LRU)
@functools.cache
def square(n: int) -> int:
return n * n单例模式装饰器
from functools import wraps
def singleton(cls):
"""单例模式装饰器 —— 确保类只有一个实例"""
instances = {}
@wraps(cls)
def get_instance(*args, **kwargs):
if cls not in instances:
instances[cls] = cls(*args, **kwargs)
return instances[cls]
return get_instance
@singleton
class DatabaseConnection:
def __init__(self, host: str = "localhost", port: int = 5432):
self.host = host
self.port = port
print(f"创建数据库连接: {host}:{port}")
# 无论实例化多少次,返回的都是同一个对象
db1 = DatabaseConnection()
db2 = DatabaseConnection()
print(f"是否同一实例: {db1 is db2}") # True验证与类型检查装饰器
from functools import wraps
from typing import Any, get_type_hints
def validate_types(func):
"""运行时类型检查装饰器"""
type_hints = get_type_hints(func)
@wraps(func)
def wrapper(*args, **kwargs):
import inspect
sig = inspect.signature(func)
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
for param_name, param_value in bound.arguments.items():
if param_name in type_hints:
expected_type = type_hints[param_name]
if not isinstance(param_value, expected_type):
raise TypeError(
f"{func.__name__}() 参数 '{param_name}' "
f"期望 {expected_type.__name__}, "
f"得到 {type(param_value).__name__}"
)
return func(*args, **kwargs)
return wrapper
@validate_types
def create_user(name: str, age: int, active: bool = True) -> dict:
return {"name": name, "age": age, "active": active}
# create_user("Alice", 30) # OK
# create_user(123, 30) # TypeError: 参数 'name' 期望 str
# create_user("Alice", "30") # TypeError: 参数 'age' 期望 int
def validate_range(**validators):
"""参数范围验证装饰器"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
import inspect
sig = inspect.signature(func)
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
for param_name, (min_val, max_val) in validators.items():
value = bound.arguments[param_name]
if not (min_val <= value <= max_val):
raise ValueError(
f"{func.__name__}() 参数 '{param_name}' "
f"值 {value} 不在范围 [{min_val}, {max_val}]"
)
return func(*args, **kwargs)
return wrapper
return decorator
@validate_range(age=(0, 150), score=(0, 100))
def add_student(name: str, age: int, score: float) -> dict:
return {"name": name, "age": age, "score": score}上下文管理器装饰器
from functools import wraps
import contextlib
def with_timeout(seconds: float):
"""超时控制装饰器"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
import signal
def timeout_handler(signum, frame):
raise TimeoutError(f"{func.__name__} 超时 ({seconds}s)")
# 设置信号处理器(仅 Unix 系统有效)
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(int(seconds))
try:
return func(*args, **kwargs)
finally:
signal.alarm(0) # 取消定时器
signal.signal(signal.SIGALRM, old_handler) # 恢复
return wrapper
return decorator
def suppress_exceptions(*exceptions):
"""抑制指定异常的装饰器"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except exceptions as e:
print(f"抑制异常: {type(e).__name__}: {e}")
return None
return wrapper
return decorator
@suppress_exceptions(ValueError, TypeError)
def safe_parse_int(value: str) -> int | None:
return int(value)
print(safe_parse_int("42")) # 42
print(safe_parse_int("abc")) # 抑制异常: ValueError ... -> None
print(safe_parse_int(None)) # 抑制异常: TypeError ... -> None装饰器注册模式
# 使用装饰器实现命令注册/处理器注册模式
class CommandRegistry:
"""命令注册器"""
def __init__(self):
self._commands = {}
def register(self, name: str):
"""注册命令装饰器"""
def decorator(func):
self._commands[name] = func
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
wrapper.command_name = name
return wrapper
return decorator
def execute(self, name: str, *args, **kwargs):
if name not in self._commands:
raise KeyError(f"未知命令: {name}")
return self._commands[name](*args, **kwargs)
def list_commands(self):
return list(self._commands.keys())
# 全局注册器
registry = CommandRegistry()
@registry.register("greet")
def greet(name: str) -> str:
return f"你好, {name}!"
@registry.register("calculate")
def calculate(a: float, op: str, b: float) -> float:
ops = {"+": lambda x, y: x + y, "-": lambda x, y: x - y,
"*": lambda x, y: x * y, "/": lambda x, y: x / y}
return ops[op](a, b)
@registry.register("status")
def status() -> dict:
return {"status": "running", "commands": len(registry.list_commands())}
# 使用
print(registry.execute("greet", "世界"))
print(registry.execute("calculate", 10, "+", 20))
print(registry.execute("status"))
print(f"可用命令: {registry.list_commands()}")优点
缺点
总结
装饰器是 Python 中实现横切关注点的核心机制,通过闭包和高阶函数实现函数行为的动态增强。掌握基础装饰器、带参数装饰器和类装饰器三种形式,配合 functools.wraps 保护元信息,就能覆盖绝大多数工程场景。关键是控制使用层级,避免装饰器嵌套过深导致调试困难。
关键知识点
- 装饰器本质是高阶函数:接受函数,返回增强后的新函数
- functools.wraps 必须使用,否则原函数的 name、doc 会丢失
- 带参数的装饰器是三层嵌套函数,最外层接收装饰器参数
- 类装饰器通过 call 方法实现,适合需要维护状态的场景
- 装饰器在模块加载时执行(定义时),包装函数在调用时执行(运行时)
- 多个装饰器叠加时,装饰顺序从下到上,执行顺序从上到下
项目落地视角
- 项目中统一使用装饰器处理日志、权限、缓存等横切逻辑,避免在每个函数里重复编写
- 建议将装饰器集中放在 utils/decorators.py 模块,并编写单元测试覆盖各种边界情况
- 对外提供的装饰器务必编写 docstring 说明参数含义、副作用和异常行为
- 装饰器叠加层数建议不超过 3 层,超过则考虑重构为中间件或 Pipeline 模式
- 使用 functools.lru_cache 替代自定义缓存装饰器(纯函数场景)
常见误区
- 忘记使用 @wraps(func) 导致调试时函数名显示为 wrapper
- 带参数装饰器忘记写三层嵌套,导致调用时参数传递混乱
- 在装饰器中吞掉异常不做处理,导致问题难以排查
- 装饰器内部引入可变全局状态,导致并发安全问题
- 装饰器改变了被装饰函数的签名(如添加了必选参数),破坏调用方代码
- 在类方法上使用装饰器时忘记处理 self 参数
进阶路线
- 学习使用 dataclasses 或类实现更复杂的装饰器状态管理
- 研究 FastAPI、Flask 等框架中装饰器在路由和依赖注入中的应用模式
- 了解 slots、描述符协议与装饰器的结合使用
- 探索 AST 转换和编译期装饰器(宏)在代码生成中的应用
适用场景
- 需要统一处理日志记录、性能监控、权限校验等横切关注点的项目
- Web 框架中路由注册、参数校验、请求预处理等场景
- 数据处理管道中需要添加缓存、重试、超时控制等能力的函数
落地建议
- 建立项目级装饰器库,统一日志格式、异常处理和缓存策略
- 每个装饰器编写独立测试,验证正常路径、异常路径和边界条件
- 在团队 Code Review 中明确装饰器使用规范,避免滥用
- 为复杂装饰器编写使用示例和注意事项文档
排错清单
- 检查是否使用了 @wraps(func) 保留原函数元信息
- 确认装饰器的参数传递是否正确,特别是带参数装饰器的三层嵌套
- 排查多个装饰器叠加时的执行顺序是否符合预期(从下到上执行)
- 检查装饰器中是否有可变默认参数导致的状态共享问题
- 确认 functools.signature 能正确获取被装饰函数的签名
复盘问题
- 你的项目中哪些重复代码可以用装饰器统一抽取?抽取后调试体验是否变差?
- 装饰器引入的状态(如缓存、计数器)在多线程/多进程环境下是否安全?
- 团队成员能否快速理解装饰器的执行流程?是否需要补充文档或示例?
