Python 生成器
Python 生成器
简介
生成器是 Python 中一种特殊的迭代器,通过 yield 关键字实现惰性求值。它不需要一次性将所有数据加载到内存,而是按需逐个产出值,特别适合处理大规模数据流、文件读取和无限序列等场景。
生成器的本质是一个函数,当解释器执行到 yield 语句时,函数的执行会被"冻结"(暂停),并将 yield 后面的表达式值返回给调用者。当下一次调用 __next__() 时,函数会从上次暂停的位置恢复执行,继续运行直到遇到下一个 yield 或 return。这种机制使得生成器能够在多次调用之间保持自身的局部变量状态,而不需要借助类和实例属性来实现。
从底层来看,Python 生成器基于"帧对象"(frame object)实现。每次调用生成器函数时,并不会立即执行函数体,而是返回一个生成器对象(generator object)。这个对象持有一个帧对象,其中保存了代码位置、局部变量表等信息。当调用 next() 或 send() 时,帧对象被恢复执行,执行到 yield 后再次被挂起。
特点
核心原理
生成器与迭代器的关系
Python 的迭代器协议要求对象实现 __iter__() 和 __next__() 两个特殊方法。手动实现迭代器需要编写一个类,而生成器函数则是创建迭代器的简洁语法糖:
from typing import Iterator
# 方式一:手动实现迭代器协议(繁琐)
class CountDown:
"""手动实现的倒计时迭代器"""
def __init__(self, start: int):
self.current = start
def __iter__(self) -> Iterator[int]:
return self
def __next__(self) -> int:
if self.current <= 0:
raise StopIteration
val = self.current
self.current -= 1
return val
# 方式二:生成器函数(简洁,功能完全等价)
def countdown(start: int) -> Iterator[int]:
"""生成器实现的倒计时"""
while start > 0:
yield start
start -= 1
# 两者使用方式完全一致
for i in countdown(5):
print(i, end=" ") # 输出: 5 4 3 2 1
print()
# 验证生成器实现了迭代器协议
gen = countdown(3)
print(type(gen)) # <class 'generator'>
print(hasattr(gen, '__iter__')) # True
print(hasattr(gen, '__next__')) # True生成器的生命周期状态机
生成器对象在生命周期中会经历多个状态,理解这些状态有助于排查问题:
import inspect
def example_gen():
"""用于演示生成器生命周期的生成器"""
print("生成器开始执行")
x = yield 1
print(f"接收到值: {x}")
y = yield 2
print(f"接收到值: {y}")
return "完成"
gen = example_gen()
print(f"创建后状态: {inspect.getgeneratorstate(gen)}") # GEN_CREATED
result = next(gen)
print(f"第一次 next 后状态: {inspect.getgeneratorstate(gen)}") # GEN_SUSPENDED
print(f"产出的值: {result}") # 1
try:
result = gen.send("你好")
print(f"send 后状态: {inspect.getgeneratorstate(gen)}") # GEN_SUSPENDED
except StopIteration as e:
print(f"生成器结束: {e.value}")
try:
gen.send("世界") # 触发 StopIteration
except StopIteration as e:
print(f"生成器返回值: {e.value}") # 完成
print(f"结束状态: {inspect.getgeneratorstate(gen)}") # GEN_CLOSED实现
基础生成器:数据流处理
def read_large_file(file_path: str, chunk_size: int = 8192):
"""逐行读取大文件的生成器"""
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
stripped = line.strip()
if stripped: # 跳过空行
yield stripped
def parse_csv_lines(lines):
"""将 CSV 文本行解析为字典列表"""
header = None
for line in lines:
fields = line.split(",")
if header is None:
header = fields
continue
if len(fields) == len(header):
yield dict(zip(header, fields))
def filter_records(records, key: str, value: str):
"""过滤记录的生成器"""
for record in records:
if record.get(key) == value:
yield record
# 生成器管道:读取 -> 解析 -> 过滤
lines = read_large_file("data/users.csv")
records = parse_csv_lines(lines)
active_users = filter_records(records, "status", "active")
for user in active_users:
print(user) # 内存中始终只保留一条记录生成器表达式与无限序列
import itertools
# 生成器表达式(比列表推导式更省内存)
total = sum(x * x for x in range(1_000_000))
print(f"平方和: {total}")
# 无限序列生成器
def fibonacci():
"""无限斐波那契数列"""
a, b = 0, 1
while True:
yield a
a, b = b, a + b
# 取前 10 个斐波那契数
fibs = list(itertools.islice(fibonacci(), 10))
print(f"前 10 个斐波那契数: {fibs}")
# 批处理生成器
def batch_generator(iterable, batch_size: int):
"""将可迭代对象按批次切分"""
batch = []
for item in iterable:
batch.append(item)
if len(batch) == batch_size:
yield batch
batch = []
if batch:
yield batch
# 每批处理 100 条记录
for batch in batch_generator(range(1000), 100):
print(f"处理批次: {len(batch)} 条")send() 与协程式生成器
def accumulator(initial: float = 0):
"""带初始值的累加器生成器"""
total = initial
while True:
value = yield total
if value is None:
continue
total += value
def moving_average(window_size: int):
"""滑动平均生成器"""
window = []
while True:
value = yield
if value is None:
continue
window.append(value)
if len(window) > window_size:
window.pop(0)
avg = sum(window) / len(window)
print(f"窗口: {window}, 均值: {avg:.2f}")
# 使用 send() 与生成器通信
gen = accumulator(100)
next(gen) # 启动生成器,返回 100
print(gen.send(10)) # 110
print(gen.send(20)) # 130
print(gen.send(-5)) # 125
avg_gen = moving_average(3)
next(avg_gen) # 启动生成器
for val in [10, 20, 30, 40, 50]:
avg_gen.send(val)
# 窗口: [10], 均值: 10.00
# 窗口: [10, 20], 均值: 15.00
# 窗口: [10, 20, 30], 均值: 20.00
# 窗口: [20, 30, 40], 均值: 30.00
# 窗口: [30, 40, 50], 均值: 40.00yield from 与子生成器委托
from typing import Generator, Iterable
def flatten(nested: Iterable) -> Generator:
"""递归展平嵌套结构"""
for item in nested:
if isinstance(item, (list, tuple)):
yield from flatten(item) # 委托给子生成器
else:
yield item
def chain(*iterables):
"""串联多个可迭代对象"""
for iterable in iterables:
yield from iterable
# 展平嵌套列表
nested_data = [1, [2, 3], [[4, 5], 6], [7, [8, [9]]]]
flat = list(flatten(nested_data))
print(f"展平结果: {flat}") # [1, 2, 3, 4, 5, 6, 7, 8, 9]
# 串联多个数据源
data1 = [1, 2, 3]
data2 = {"a", "b", "c"}
data3 = range(3)
combined = list(chain(data1, data2, data3))
print(f"串联结果: {combined}")
# 实际场景:多文件合并处理
def process_multiple_files(*file_paths):
for path in file_paths:
yield from read_large_file(path)
def read_large_file(path):
with open(path, "r", encoding="utf-8") as f:
for line in f:
yield line.strip()深入实践
生成器表达式详解
生成器表达式是列表推导式的惰性版本,语法上只是把方括号 [] 换成了圆括号 ():
# 列表推导式 —— 立即计算,全部存入内存
squares_list = [x ** 2 for x in range(1000000)]
print(f"列表内存占用约: {squares_list.__sizeof__() / 1024 / 1024:.2f} MB")
# 生成器表达式 —— 惰性计算,几乎不占内存
squares_gen = (x ** 2 for x in range(1000000))
print(f"生成器内存占用约: {squares_gen.__sizeof__()} bytes")
# 生成器表达式可以直接作为函数参数,省略外层括号
# sum(), max(), min(), any(), all(), sorted() 等都支持
result = sum(x ** 2 for x in range(1000000))
print(f"平方和: {result}")
# 带条件的生成器表达式
even_squares = (x ** 2 for x in range(100) if x % 2 == 0)
print(f"前 5 个偶数的平方: {list(itertools.islice(even_squares, 5))}")
# 生成器表达式的局限性:不能包含复杂的控制流
# 如果逻辑复杂,请使用生成器函数生成器中的异常处理
生成器内部可以捕获和处理异常,也可以通过 throw() 方法从外部向生成器注入异常:
def safe_file_processor(file_path: str):
"""带有异常恢复机制的文件处理生成器"""
try:
with open(file_path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
try:
record = process_line(line)
yield record
except ValueError as e:
# 记录错误但继续处理后续行
print(f"第 {line_num} 行解析失败: {e}")
yield {"line": line_num, "error": str(e)}
except FileNotFoundError:
print(f"文件不存在: {file_path}")
return
except IOError as e:
print(f"IO 错误: {e}")
return
def process_line(line: str) -> dict:
"""模拟行处理逻辑"""
data = line.strip().split(",")
if len(data) < 2:
raise ValueError(f"字段不足: 期望至少 2 个,得到 {len(data)} 个")
return {"name": data[0], "value": int(data[1])}
# 使用 throw() 向生成器注入异常
def resilient_processor():
"""可从外部恢复的生成器"""
while True:
try:
data = yield
print(f"处理数据: {data}")
except ValueError:
print("检测到无效数据,已跳过")
except GeneratorExit:
print("生成器被关闭")
raise # GeneratorExit 不能被捕获后忽略
gen = resilient_processor()
next(gen) # 启动
gen.send({"id": 1, "name": "Alice"})
gen.throw(ValueError, "模拟数据异常") # 向生成器注入异常
gen.send({"id": 2, "name": "Bob"})
gen.close() # 触发 GeneratorExitclose() 与资源安全清理
当生成器被提前关闭时(通过 close() 或被垃圾回收),需要确保资源被正确释放:
import contextlib
# 方式一:使用 try/finally 确保资源释放
def db_query_generator(query: str, batch_size: int = 1000):
"""模拟数据库分批查询生成器"""
connection = create_connection() # 假设的数据库连接
cursor = connection.cursor()
try:
cursor.execute(query)
while True:
rows = cursor.fetchmany(batch_size)
if not rows:
break
yield rows
finally:
# 无论生成器是正常结束还是被 close(),finally 都会执行
cursor.close()
connection.close()
print("数据库连接已释放")
def create_connection():
"""模拟创建数据库连接"""
class FakeConnection:
def cursor(self):
return self
def execute(self, query):
pass
def fetchmany(self, size):
return []
def close(self):
pass
return FakeConnection()
# 方式二:使用 contextlib.contextmanager 装饰器
@contextlib.contextmanager
def managed_db_connection():
"""上下文管理器包装的数据库连接"""
conn = create_connection()
try:
yield conn
finally:
conn.close()
# 方式三:结合生成器和上下文管理器
def safe_data_stream(file_path: str):
"""安全的数据流生成器,确保文件关闭"""
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
yield line.strip()
# 即使消费者提前 break,with 语句也会确保文件关闭
# 演示提前关闭
gen = safe_data_stream("example.txt")
next(gen) # 读取第一行
gen.close() # 提前关闭,文件句柄仍会被正确释放生成器的高级类型注解
Python 的 typing 模块提供了 Generator 类型,可以精确描述生成器的输入输出类型:
from typing import Generator, Iterator, Iterable
# Generator[YieldType, SendType, ReturnType]
# YieldType: yield 产出的值的类型
# SendType: send() 传入的值的类型
# ReturnType: 生成器 return 的值的类型
def echo_server() -> Generator[str, str, None]:
"""回显服务器:接收字符串,原样返回"""
while True:
received = yield "等待输入..."
if received:
print(f"收到: {received}")
def range_generator(start: int, end: int) -> Iterator[int]:
"""简单的范围生成器,只产出不接收"""
current = start
while current < end:
yield current
current += 1
def tree_traversal(node: dict) -> Generator[dict, None, int]:
"""树遍历生成器,返回访问的节点总数"""
count = 0
stack = [node]
while stack:
current = stack.pop()
yield current
count += 1
for child in current.get("children", []):
stack.append(child)
return count
# 类型注解的实际应用
def process_stream(
source: Iterable[str]
) -> Generator[dict, None, None]:
"""类型明确的流处理器"""
for line in source:
yield {"raw": line, "length": len(line)}实用模式与生产场景
模式一:数据库分页查询
import time
class PaginationGenerator:
"""数据库分页查询生成器"""
def __init__(self, query_func, page_size: int = 100):
self.query_func = query_func
self.page_size = page_size
def iterate(self, **filters) -> Generator[dict, None, None]:
"""分页迭代,对外部调用者完全透明"""
offset = 0
while True:
results = self.query_func(
limit=self.page_size,
offset=offset,
**filters
)
if not results:
break
yield from results
if len(results) < self.page_size:
break
offset += self.page_size
time.sleep(0.1) # 避免过快请求
# 模拟查询函数
def mock_query(limit: int, offset: int, **filters) -> list:
"""模拟数据库查询"""
data = [
{"id": i, "name": f"用户{i}", "status": "active" if i % 3 else "inactive"}
for i in range(1, 501)
]
# 应用过滤
if "status" in filters:
data = [d for d in data if d["status"] == filters["status"]]
return data[offset:offset + limit]
# 使用:调用方完全不需要知道分页逻辑
paginator = PaginationGenerator(mock_query, page_size=50)
active_count = 0
for user in paginator.iterate(status="active"):
active_count += 1
print(f"活跃用户总数: {active_count}")模式二:ETL 数据管道
from typing import Callable, Any
from datetime import datetime
def extract_lines(file_path: str) -> Generator[str, None, None]:
"""ETL - 提取阶段:从文件读取原始行"""
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
yield line.strip()
def transform(
records: Generator[dict, None, None],
transformations: list[Callable]
) -> Generator[dict, None, None]:
"""ETL - 转换阶段:应用一系列转换函数"""
for record in records:
result = record
for transform_fn in transformations:
result = transform_fn(result)
if result is None: # 转换函数可以过滤掉记录
break
if result is not None:
yield result
def load_batch(
records: Generator[dict, None, None],
batch_size: int,
write_func: Callable
) -> int:
"""ETL - 加载阶段:批量写入目标"""
count = 0
batch = []
for record in records:
batch.append(record)
count += 1
if len(batch) >= batch_size:
write_func(batch)
batch = []
if batch:
write_func(batch)
return count
# 定义转换函数
def add_timestamp(record: dict) -> dict:
"""添加处理时间戳"""
record["processed_at"] = datetime.now().isoformat()
return record
def validate_email(record: dict) -> dict | None:
"""验证邮箱格式,无效则过滤"""
email = record.get("email", "")
if "@" not in email:
return None # 返回 None 表示过滤掉此记录
return record
def anonymize_phone(record: dict) -> dict:
"""手机号脱敏"""
phone = record.get("phone", "")
if len(phone) >= 7:
record["phone"] = phone[:3] + "****" + phone[-4:]
return record
# 组装完整管道
transformations = [add_timestamp, validate_email, anonymize_phone]
def mock_write(batch: list):
"""模拟批量写入"""
print(f"写入 {len(batch)} 条记录")
# 实际使用
# source = extract_lines("data/users.csv")
# parsed = parse_csv_lines(source)
# transformed = transform(parsed, transformations)
# total = load_batch(transformed, batch_size=100, write_func=mock_write)模式三:进度追踪生成器
import sys
from typing import Generator
def progress_generator(
iterable,
total: int | None = None,
prefix: str = "处理进度"
) -> Generator:
"""带进度显示的生成器包装器"""
if total is None:
try:
total = len(iterable)
except TypeError:
total = None
count = 0
start_time = time.time()
for item in iterable:
yield item
count += 1
# 每处理 1000 条或总条数较少时每次都更新进度
if total is not None and (count % 1000 == 0 or total < 2000):
elapsed = time.time() - start_time
rate = count / elapsed if elapsed > 0 else 0
remaining = (total - count) / rate if rate > 0 else 0
pct = count / total * 100
sys.stdout.write(
f"\r{prefix}: {count}/{total} "
f"({pct:.1f}%) "
f"[{elapsed:.1f}s elapsed, "
f"~{remaining:.1f}s remaining] "
)
sys.stdout.flush()
elapsed = time.time() - start_time
print(f"\n{prefix}: 完成! 共 {count} 条, 耗时 {elapsed:.2f}s")
# 使用示例
data = range(50000)
for item in progress_generator(data, total=50000, prefix="数据导入"):
pass # 实际处理逻辑模式四:合并排序多个有序流
import heapq
from typing import Generator
def merge_sorted_streams(
*streams: Generator[int, None, None]
) -> Generator[int, None, None]:
"""合并多个已排序的生成器流,保持有序性
使用堆来高效选择最小元素,时间复杂度 O(n log k),
其中 k 是流的数量,n 是总元素数量。
"""
heap = []
# 初始化堆:从每个流中取第一个元素
for stream_idx, stream in enumerate(streams):
try:
first_item = next(stream)
heapq.heappush(heap, (first_item, stream_idx, stream))
except StopIteration:
pass # 空流直接跳过
while heap:
value, stream_idx, stream = heapq.heappop(heap)
yield value
try:
next_item = next(stream)
heapq.heappush(heap, (next_item, stream_idx, stream))
except StopIteration:
pass # 流已耗尽
# 模拟多个有序数据流
def sorted_stream(start: int, step: int, count: int):
"""生成一个有序整数流"""
for i in range(count):
yield start + i * step
# 三个有序流
s1 = sorted_stream(0, 5, 100) # 0, 5, 10, 15, ...
s2 = sorted_stream(1, 3, 100) # 1, 4, 7, 10, ...
s3 = sorted_stream(2, 7, 100) # 2, 9, 16, 23, ...
# 合并后仍保持有序
merged = list(itertools.islice(merge_sorted_streams(s1, s2, s3), 30))
print(f"合并后前 30 个: {merged}")模式五:生成器实现缓存/缓冲
from collections import deque
import threading
class BufferedGenerator:
"""线程安全的缓冲生成器
生产者在后台线程中向缓冲区写入数据,
消费者通过迭代按需读取,实现生产-消费解耦。
"""
def __init__(self, source, buffer_size: int = 1000):
self.source = source
self.buffer: deque = deque(maxlen=buffer_size)
self exhausted = False
self.lock = threading.Lock()
self.condition = threading.Condition(self.lock)
self.error = None
def _produce(self):
"""生产者线程:从源读取数据放入缓冲区"""
try:
for item in self.source:
with self.condition:
while len(self.buffer) >= self.buffer.maxlen:
self.condition.wait()
self.buffer.append(item)
self.condition.notify()
except Exception as e:
self.error = e
finally:
with self.condition:
self.exhausted = True
self.condition.notify_all()
def __iter__(self):
# 启动生产者线程
thread = threading.Thread(target=self._produce, daemon=True)
thread.start()
while True:
with self.condition:
while not self.buffer and not self.exhausted:
self.condition.wait(timeout=1.0)
if self.error:
raise self.error
if self.buffer:
item = self.buffer.popleft()
self.condition.notify()
yield item
elif self.exhausted:
break
thread.join(timeout=5)
# 使用示例
def slow_data_source():
"""模拟慢速数据源"""
for i in range(100):
time.sleep(0.01) # 模拟耗时操作
yield f"item_{i}"
buffered = BufferedGenerator(slow_data_source(), buffer_size=50)
for item in buffered:
pass # 消费者可以快速消费生成器与 itertools 配合
Python 标准库 itertools 提供了大量与生成器配合使用的工具函数:
import itertools
from itertools import chain, islice, groupby, tee, filterfalse
# 1. chain.from_iterable: 展平一层嵌套
nested = [[1, 2], [3, 4], [5]]
flat = list(chain.from_iterable(nested))
print(f"展平: {flat}") # [1, 2, 3, 4, 5]
# 2. islice: 切片(不创建中间列表)
# 从第 100 个开始,每隔 2 个取 1 个,共取 5 个
sampled = list(islice(range(1000), 100, None, 2))
print(f"采样: {sampled[:5]}") # [100, 102, 104, 106, 108]
# 3. groupby: 分组(需要先排序)
data = [("apple", 3), ("banana", 2), ("apple", 1), ("banana", 5)]
data.sort(key=lambda x: x[0]) # groupby 要求有序
for key, group in groupby(data, key=lambda x: x[0]):
items = list(group)
total = sum(count for _, count in items)
print(f"{key}: {items}, 总计: {total}")
# 4. tee: 将一个迭代器复制为多个独立迭代器
nums = range(5)
a, b, c = tee(nums, 3)
print(f"tee-a: {list(a)}") # [0, 1, 2, 3, 4]
print(f"tee-b: {list(b)}") # [0, 1, 2, 3, 4]
print(f"tee-c: {list(c)}") # [0, 1, 2, 3, 4]
# 5. filterfalse: 反向过滤
data = [1, 2, 3, 4, 5, 6]
odd = filterfalse(lambda x: x % 2 == 0, data)
print(f"奇数: {list(odd)}") # [1, 3, 5]
# 6. accumulate: 累积计算
from itertools import accumulate
running_sum = list(accumulate([1, 2, 3, 4, 5]))
print(f"累积和: {running_sum}") # [1, 3, 6, 10, 15]
running_product = list(accumulate([1, 2, 3, 4], lambda a, b: a * b))
print(f"累积积: {running_product}") # [1, 2, 6, 24]
# 7. takewhile / dropwhile: 条件截断
from itertools import takewhile, dropwhile
data = [1, 2, 5, 3, 8, 1, 2]
head = list(takewhile(lambda x: x < 5, data))
print(f"满足条件的前缀: {head}") # [1, 2]
tail = list(dropwhile(lambda x: x < 5, data))
print(f"跳过满足条件的前缀: {tail}") # [5, 3, 8, 1, 2]
# 8. zip_longest: 不等长可迭代对象的 zip
from itertools import zip_longest
a = [1, 2, 3]
b = ["a", "b"]
print(f"zip_longest: {list(zip_longest(a, b, fillvalue='-'))}")
# [(1, 'a'), (2, 'b'), (3, '-')]性能对比与最佳实践
内存消耗对比
import sys
import tracemalloc
def measure_memory(func, *args, **kwargs):
"""测量函数执行的内存峰值"""
tracemalloc.start()
result = func(*args, **kwargs)
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
return result, peak
# 对比:列表 vs 生成器
def list_approach(n):
"""列表方式:一次性生成所有数据"""
return [x ** 2 for x in range(n)]
def generator_approach(n):
"""生成器方式:惰性生成"""
return (x ** 2 for x in range(n))
n = 10_000_000
_, list_peak = measure_memory(list_approach, n)
print(f"列表方式内存峰值: {list_peak / 1024 / 1024:.2f} MB")
_, gen_peak = measure_memory(generator_approach, n)
print(f"生成器方式内存峰值: {gen_peak / 1024 / 1024:.2f} MB")
# 对比:逐条消费时的差异
def consume_list(n):
"""列表方式逐条消费"""
data = [x ** 2 for x in range(n)]
for item in data:
pass
return len(data)
def consume_generator(n):
"""生成器方式逐条消费"""
data = (x ** 2 for x in range(n))
count = 0
for item in data:
count += 1
return count
_, list_consume_peak = measure_memory(consume_list, n)
_, gen_consume_peak = measure_memory(consume_generator, n)
print(f"\n逐条消费 - 列表: {list_consume_peak / 1024 / 1024:.2f} MB")
print(f"逐条消费 - 生成器: {gen_consume_peak / 1024 / 1024:.2f} MB")何时使用生成器 vs 列表
# 应该使用生成器的场景
def should_use_generator():
# 1. 数据量大,无法一次性装入内存
for record in read_large_file("huge_data.csv"):
process(record)
# 2. 只需要遍历一次
total = sum(x for x in range(1_000_000))
# 3. 数据是无限或未知的
for prime in prime_number_generator():
if prime > 10000:
break
# 4. 构建处理管道
pipeline = filter(lambda x: x > 0, map(int, sys.stdin))
# 应该使用列表的场景
def should_use_list():
# 1. 需要多次遍历
data = [1, 2, 3, 4, 5]
total = sum(data)
average = sum(data) / len(data)
max_val = max(data)
# 2. 需要随机访问
data = [10, 20, 30, 40, 50]
print(data[2]) # 30
# 3. 需要知道长度
data = list(range(100))
print(len(data)) # 100
# 4. 数据量小,内存不是问题
weekdays = ["周一", "周二", "周三", "周四", "周五"]
for day in weekdays:
print(day)常见陷阱与解决方案
陷阱一:重复迭代已耗尽的生成器
# 错误示范
numbers = (x * 2 for x in range(5))
print(list(numbers)) # [0, 2, 4, 6, 8]
print(list(numbers)) # [] -- 生成器已耗尽!
# 解决方案 1:需要多次使用时转为列表
numbers_list = list(x * 2 for x in range(5))
print(list(numbers_list)) # [0, 2, 4, 6, 8]
print(list(numbers_list)) # [0, 2, 4, 6, 8]
# 解决方案 2:使用 itertools.tee 创建独立副本
import itertools
source = (x * 2 for x in range(5))
a, b = itertools.tee(source, 2)
print(list(a)) # [0, 2, 4, 6, 8]
print(list(b)) # [0, 2, 4, 6, 8]
# 注意:tee 内部会缓存已消费的元素,两个副本差距越大内存越高
# 解决方案 3:使用工厂函数,每次创建新的生成器
def number_factory():
return (x * 2 for x in range(5))
print(list(number_factory())) # [0, 2, 4, 6, 8]
print(list(number_factory())) # [0, 2, 4, 6, 8]陷阱二:在生成器中使用可变默认参数
# 错误示范:默认参数在函数定义时创建,生成器之间共享
def buggy_accumulator(values, result=[]):
"""多个生成器实例共享同一个 result 列表"""
for v in values:
result.append(v)
yield sum(result)
# 每次调用都使用同一个默认列表
gen1 = list(buggy_accumulator([1, 2, 3])) # [1, 3, 6]
gen2 = list(buggy_accumulator([10, 20])) # [16, 36] -- 包含了 gen1 的残留!
# 正确做法:使用 None 作为默认值,在函数体内创建新列表
def correct_accumulator(values, result=None):
if result is None:
result = []
for v in values:
result.append(v)
yield sum(result)
gen3 = list(correct_accumulator([1, 2, 3])) # [1, 3, 6]
gen4 = list(correct_accumulator([10, 20])) # [10, 30] -- 独立的陷阱三:yield 在 try/except 中的行为
# 在 Python 3.7+ 中,StopIteration 不再被 except 捕获后传播
# 但 yield 在 try 块中仍有特殊行为需要注意
def tricky_generator():
try:
yield 1
yield 2
raise ValueError("测试异常")
yield 3 # 永远不会执行
except ValueError:
print("捕获了 ValueError")
yield 4 # 可以在 except 中继续 yield
finally:
print("finally 总是执行")
gen = tricky_generator()
print(next(gen)) # 1
print(next(gen)) # 2
print(next(gen)) # 4 (ValueError 被捕获)
try:
next(gen) # StopIteration
except StopIteration:
print("生成器结束")
# 输出顺序:
# 1
# 2
# 捕获了 ValueError
# 4
# finally 总是执行
# 生成器结束优点
缺点
总结
生成器是 Python 处理大数据和流式数据的核心工具,通过 yield 实现惰性求值,将内存消耗从 O(n) 降到 O(1)。掌握生成器表达式、yield from 委托和 send() 双向通信三种模式,就能构建高效的数据处理管道。在项目中优先使用生成器处理文件、数据库游标和网络流等场景。
关键知识点
- yield 会暂停函数执行并保存当前状态,下次调用时从暂停处继续
- yield from 可以将迭代委托给子生成器,简化嵌套遍历代码
- 生成器只能迭代一次,耗尽后需要重新创建才能再次使用
- send() 方法可以向生成器注入值,实现协程式双向通信
- throw() 可以向生成器注入异常,close() 可以提前关闭生成器
- 生成器对象有四种状态:GEN_CREATED、GEN_SUSPENDED、GEN_RUNNING、GEN_CLOSED
项目落地视角
- 处理大文件、大数据集时优先使用生成器管道,避免一次性加载到内存
- 数据库批量查询使用生成器分页获取,结合 yield from 实现透明分页
- 在 ETL 流水线中用生成器串联提取、转换、加载各阶段
- 为复杂生成器编写消费端测试,验证输出序列和边界行为
- 在函数签名中使用 Generator[YieldType, SendType, ReturnType] 明确类型
常见误区
- 把生成器当作列表使用,多次遍历同一个已耗尽的生成器
- 在生成器中遗漏 return 或 break 导致无限循环
- 在 yield 表达式中忽略 send() 传入的 None 值
- 过度使用生成器导致代码可读性下降,简单场景应该用列表
- 在生成器中使用可变默认参数导致多个实例共享状态
- 忘记在生成器中使用 try/finally 确保资源释放
进阶路线
- 学习 async generator(异步生成器)处理异步数据流
- 研究 itertools 标准库中的生成器工具函数
- 了解生成器在协程(asyncio)演进历史中的角色
- 探索基于生成器的状态机实现模式
- 掌握 contextlib.contextmanager 和 closing 等生成器工具
适用场景
- 处理 GB 级大文件或数据库导出,需要逐行处理避免内存溢出
- 构建多阶段数据处理管道(ETL),各阶段解耦且内存可控
- 实现无限序列或按需计算的懒加载模式
- 合并多个有序数据流(外部排序的归并阶段)
- 实现生产者-消费者模式中的缓冲区
落地建议
- 在项目工具模块中提供通用的分批、过滤、映射生成器工具
- 为生成器管道添加类型注解,使用 Generator[YieldType, SendType, ReturnType]
- 在文档中明确标注哪些函数返回生成器,提醒调用方注意一次性消费
- 为关键生成器添加进度追踪和日志,方便排查数据处理问题
- 编写单元测试时,使用 list() 消费生成器后对结果做断言
排错清单
- 检查生成器是否被意外消耗(如用 list() 转换后再次迭代)
- 确认 yield from 委托的子生成器是否正确处理了异常
- 排查生成器内部是否有未处理的 StopIteration 异常
- 检查生成器是否在 try/finally 中正确释放了资源
- 使用 inspect.getgeneratorstate() 检查生成器的当前状态
- 确认 send() 的第一次调用是否使用了 next() 预热
复盘问题
- 你的项目中哪些一次性加载大数据的场景可以改用生成器?改完后内存占用变化如何?
- 生成器管道中的某个环节抛出异常,如何保证已打开的资源(文件、连接)被正确释放?
- 团队成员是否理解生成器只能迭代一次的特性?是否有过因误解导致的 bug?
- 你的 ETL 管道是否使用了生成器?能否通过生成器管道降低峰值内存?
