RAG 生产环境最佳实践
简介
检索增强生成(Retrieval-Augmented Generation, RAG)是将大语言模型与外部知识库结合的技术方案,通过在推理时检索相关文档来增强模型回答的准确性和时效性。然而,将 RAG 从原型推向生产环境需要解决文档处理、检索质量、响应延迟、系统监控等一系列工程挑战。
本文将系统性地介绍 RAG 系统在生产环境中的最佳实践,涵盖从文档处理到最终响应的完整链路,帮助团队构建高质量、高性能、可运维的 RAG 应用。
特点
RAG 生产系统的核心特征:
- 知识时效性: 通过外部知识库弥补模型训练数据的时效性不足
- 事实可溯源: 回答可以追溯到具体文档来源,提升可信度
- 领域适应性: 无需重新训练模型即可适配垂直领域知识
- 成本可控: 相比微调,知识更新成本更低
- 工程复杂度: 涉及文档处理、向量检索、Prompt 组装等多个工程环节
实现
1. 文档处理流水线
1.1 多格式文档解析器
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
@dataclass
class Document:
"""统一文档结构"""
content: str
metadata: dict
source: str
doc_type: str
page_number: Optional[int] = None
title: Optional[str] = None
created_at: Optional[str] = None
class DocumentParser(ABC):
"""文档解析器基类"""
@abstractmethod
def parse(self, file_path: str) -> list[Document]:
pass
def clean_text(self, text: str) -> str:
"""通用文本清洗"""
# 移除多余空白
text = re.sub(r'\s+', ' ', text)
# 移除特殊字符(保留中文、英文、数字、基本标点)
text = re.sub(r'[^\w\s\u4e00-\u9fff.,;:!?()()、。,;:!?\-""''·]', '', text)
# 移除连续重复标点
text = re.sub(r'([.,;:!?]){3,}', r'\1\1', text)
return text.strip()
class TextParser(DocumentParser):
"""纯文本解析器"""
def parse(self, file_path: str) -> list[Document]:
content = Path(file_path).read_text(encoding="utf-8")
content = self.clean_text(content)
return [Document(
content=content,
metadata={"file_path": file_path, "size": len(content)},
source=file_path,
doc_type="txt",
)]
class MarkdownParser(DocumentParser):
"""Markdown 解析器"""
def parse(self, file_path: str) -> list[Document]:
content = Path(file_path).read_text(encoding="utf-8")
# 按标题分割为多个部分
sections = re.split(r'(^#{1,6}\s+.+$)', content, flags=re.MULTILINE)
documents = []
current_title = Path(file_path).stem
for i, section in enumerate(sections):
if re.match(r'^#{1,6}\s+', section):
current_title = re.sub(r'^#{1,6}\s+', '', section).strip()
elif section.strip():
cleaned = self.clean_text(section)
if len(cleaned) > 50: # 过滤过短的内容
documents.append(Document(
content=cleaned,
metadata={
"file_path": file_path,
"section_title": current_title,
},
source=file_path,
doc_type="markdown",
title=current_title,
))
return documents
class PDFParser(DocumentParser):
"""PDF 解析器(使用 PyMuPDF)"""
def parse(self, file_path: str) -> list[Document]:
import fitz # PyMuPDF
documents = []
pdf = fitz.open(file_path)
for page_num in range(len(pdf)):
page = pdf[page_num]
# 提取文本
text = page.get_text()
# 提取表格(如果存在)
tables = page.find_tables()
table_texts = []
for table in tables:
df = table.to_pandas()
table_texts.append(df.to_string(index=False))
# 合并文本和表格
full_text = text
if table_texts:
full_text += "\n\n--- 表格 ---\n" + "\n".join(table_texts)
cleaned = self.clean_text(full_text)
if len(cleaned) > 20:
documents.append(Document(
content=cleaned,
metadata={
"file_path": file_path,
"page": page_num + 1,
"total_pages": len(pdf),
},
source=file_path,
doc_type="pdf",
page_number=page_num + 1,
))
pdf.close()
return documents
class HTMLParser(DocumentParser):
"""HTML 解析器"""
def parse(self, file_path: str) -> list[Document]:
from bs4 import BeautifulSoup
content = Path(file_path).read_text(encoding="utf-8")
soup = BeautifulSoup(content, "html.parser")
# 移除无关标签
for tag in soup(["script", "style", "nav", "footer", "header"]):
tag.decompose()
# 提取主体内容
main_content = soup.find("main") or soup.find("article") or soup.find("body")
text = main_content.get_text(separator="\n") if main_content else soup.get_text()
cleaned = self.clean_text(text)
title = soup.find("title")
title_text = title.get_text() if title else Path(file_path).stem
return [Document(
content=cleaned,
metadata={"file_path": file_path, "title": title_text},
source=file_path,
doc_type="html",
title=title_text,
)]
class DocumentProcessor:
"""文档处理流水线"""
PARSER_MAP = {
".txt": TextParser,
".md": MarkdownParser,
".pdf": PDFParser,
".html": HTMLParser,
".htm": HTMLParser,
}
def __init__(self):
self.parsers: dict[str, DocumentParser] = {}
def get_parser(self, file_extension: str) -> Optional[DocumentParser]:
if file_extension not in self.parsers:
parser_cls = self.PARSER_MAP.get(file_extension)
if parser_cls:
self.parsers[file_extension] = parser_cls()
else:
return None
return self.parsers[file_extension]
def process_file(self, file_path: str) -> list[Document]:
"""处理单个文件"""
ext = Path(file_path).suffix.lower()
parser = self.get_parser(ext)
if not parser:
raise ValueError(f"不支持的文件格式: {ext}")
return parser.parse(file_path)
def process_directory(self, dir_path: str) -> list[Document]:
"""处理目录下所有文件"""
all_documents = []
path = Path(dir_path)
for file_path in path.rglob("*"):
if file_path.is_file() and file_path.suffix.lower() in self.PARSER_MAP:
try:
docs = self.process_file(str(file_path))
all_documents.extend(docs)
print(f" 已处理: {file_path} ({len(docs)} 个文档块)")
except Exception as e:
print(f" 处理失败: {file_path} - {e}")
return all_documents
2. 文档分块策略
2.1 多种分块策略实现
import re
from dataclasses import dataclass
@dataclass
class Chunk:
"""文档块"""
content: str
metadata: dict
chunk_id: str
doc_source: str
start_char: int = 0
end_char: int = 0
overlap_tokens: int = 0
class ChunkingStrategy:
"""文档分块策略"""
@staticmethod
def fixed_size(
text: str,
chunk_size: int = 500,
overlap: int = 50,
) -> list[str]:
"""固定大小分块"""
chunks = []
start = 0
while start < len(text):
end = start + chunk_size
chunk_text = text[start:end]
# 尝试在句子边界切分
if end < len(text):
last_period = chunk_text.rfind('。')
last_newline = chunk_text.rfind('\n')
split_point = max(last_period, last_newline)
if split_point > chunk_size // 2:
chunk_text = text[start:start + split_point + 1]
end = start + split_point + 1
chunks.append(chunk_text)
start = end - overlap
return chunks
@staticmethod
def sentence_based(
text: str,
max_chunk_size: int = 500,
min_chunk_size: int = 100,
) -> list[str]:
"""基于句子的分块"""
# 中文句子分割
sentences = re.split(r'(?<=[。!?;\n])', text)
sentences = [s.strip() for s in sentences if s.strip()]
chunks = []
current_chunk = []
for sentence in sentences:
current_chunk.append(sentence)
chunk_text = ''.join(current_chunk)
if len(chunk_text) >= max_chunk_size:
chunks.append(chunk_text)
current_chunk = []
# 处理剩余内容
if current_chunk:
remaining = ''.join(current_chunk)
if chunks and len(remaining) < min_chunk_size:
chunks[-1] += remaining
else:
chunks.append(remaining)
return chunks
@staticmethod
def semantic_chunking(
text: str,
max_chunk_size: int = 800,
similarity_threshold: float = 0.5,
) -> list[str]:
"""语义分块(基于段落主题相似度)"""
# 先按段落分割
paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
if not paragraphs:
return [text] if text.strip() else []
# 简化版: 基于段落长度和关键词重叠的合并策略
chunks = []
current_group = [paragraphs[0]]
current_size = len(paragraphs[0])
for i in range(1, len(paragraphs)):
para = paragraphs[i]
# 判断是否应该开始新块
should_split = (
current_size + len(para) > max_chunk_size
or i > 0 and not ChunkingStrategy._has_keyword_overlap(
current_group[-1], para
)
)
if should_split:
chunks.append('\n\n'.join(current_group))
current_group = [para]
current_size = len(para)
else:
current_group.append(para)
current_size += len(para)
if current_group:
chunks.append('\n\n'.join(current_group))
return chunks
@staticmethod
def _has_keyword_overlap(text1: str, text2: str) -> bool:
"""检查两段文本是否有关键词重叠"""
# 简化版: 提取名词性关键词
stopwords = {'的', '了', '是', '在', '和', '与', '及', '等', '为', '中',
'对', '也', '都', '就', '不', '有', '而', '被', '从', '到'}
words1 = set(text1) - stopwords
words2 = set(text2) - stopwords
overlap = words1 & words2
return len(overlap) > min(len(words1), len(words2)) * 0.1
@staticmethod
def recursive_chunking(
text: str,
max_chunk_size: int = 500,
separators: list[str] = None,
) -> list[str]:
"""递归字符分块(LangChain 风格)"""
if separators is None:
separators = ["\n\n", "\n", "。", ";", " ", ""]
final_chunks = []
def _split(text: str, sep_idx: int):
if len(text) <= max_chunk_size:
final_chunks.append(text)
return
sep = separators[sep_idx] if sep_idx < len(separators) else ""
if sep == "":
# 最后手段: 强制切分
for i in range(0, len(text), max_chunk_size):
final_chunks.append(text[i:i + max_chunk_size])
return
parts = text.split(sep)
current = ""
for part in parts:
if len(current) + len(part) + len(sep) > max_chunk_size:
if current:
final_chunks.append(current)
current = part
else:
current = current + sep + part if current else part
if current:
if len(current) > max_chunk_size and sep_idx + 1 < len(separators):
_split(current, sep_idx + 1)
else:
final_chunks.append(current)
_split(text, 0)
return final_chunks
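# 用法示意(假设性的辅助函数): 对同一文本比较各分块策略的输出,便于选型与调参
def compare_chunking_strategies(sample_text: str, size: int = 200) -> None:
    """打印不同分块策略的块数与最大块长(仅用于调试观察)"""
    for name, chunks in [
        ("fixed", ChunkingStrategy.fixed_size(sample_text, chunk_size=size, overlap=20)),
        ("sentence", ChunkingStrategy.sentence_based(sample_text, max_chunk_size=size)),
        ("recursive", ChunkingStrategy.recursive_chunking(sample_text, max_chunk_size=size)),
    ]:
        print(name, "块数:", len(chunks), "最大块长:", max((len(c) for c in chunks), default=0))
# 示例: compare_chunking_strategies("RAG 需要把长文档切成大小合适的片段。" * 30)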
class DocumentChunker:
"""文档分块器"""
def __init__(
self,
strategy: str = "recursive",
chunk_size: int = 500,
overlap: int = 50,
):
self.strategy = strategy
self.chunk_size = chunk_size
self.overlap = overlap
def chunk_documents(self, documents: list[Document]) -> list[Chunk]:
"""将文档列表切分为块"""
all_chunks = []
for doc in documents:
if self.strategy == "fixed":
text_chunks = ChunkingStrategy.fixed_size(
doc.content, self.chunk_size, self.overlap
)
elif self.strategy == "sentence":
text_chunks = ChunkingStrategy.sentence_based(
doc.content, self.chunk_size
)
elif self.strategy == "semantic":
text_chunks = ChunkingStrategy.semantic_chunking(
doc.content, self.chunk_size
)
elif self.strategy == "recursive":
text_chunks = ChunkingStrategy.recursive_chunking(
doc.content, self.chunk_size
)
else:
text_chunks = ChunkingStrategy.fixed_size(
doc.content, self.chunk_size, self.overlap
)
for idx, chunk_text in enumerate(text_chunks):
chunk_id = f"{doc.source}_{idx}"
all_chunks.append(Chunk(
content=chunk_text,
metadata={**doc.metadata, "chunk_index": idx},
chunk_id=chunk_id,
doc_source=doc.source,
))
return all_chunks
3. Embedding 模型选择与向量存储
import hashlib
from abc import ABC, abstractmethod
from openai import OpenAI
class EmbeddingProvider(ABC):
"""Embedding 提供者基类"""
@abstractmethod
def embed(self, text: str) -> list[float]:
pass
@abstractmethod
def embed_batch(self, texts: list[str]) -> list[list[float]]:
pass
@abstractmethod
def get_dimension(self) -> int:
pass
class OpenAIEmbeddingProvider(EmbeddingProvider):
"""OpenAI Embedding"""
def __init__(self, model: str = "text-embedding-3-small"):
self.model = model
self.client = OpenAI()
def embed(self, text: str) -> list[float]:
response = self.client.embeddings.create(
input=text, model=self.model
)
return response.data[0].embedding
def embed_batch(self, texts: list[str]) -> list[list[float]]:
# OpenAI 批量限制
batch_size = 100
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
response = self.client.embeddings.create(
input=batch, model=self.model
)
all_embeddings.extend([d.embedding for d in response.data])
return all_embeddings
def get_dimension(self) -> int:
dims = {
"text-embedding-3-small": 1536,
"text-embedding-3-large": 3072,
"text-embedding-ada-002": 1536,
}
return dims.get(self.model, 1536)
class LocalEmbeddingProvider(EmbeddingProvider):
"""本地 Embedding (sentence-transformers)"""
def __init__(self, model_name: str = "BAAI/bge-small-zh-v1.5"):
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(model_name)
def embed(self, text: str) -> list[float]:
return self.model.encode(text).tolist()
def embed_batch(self, texts: list[str]) -> list[list[float]]:
embeddings = self.model.encode(texts, batch_size=32, show_progress_bar=False)
return embeddings.tolist()
def get_dimension(self) -> int:
return self.model.get_sentence_embedding_dimension()
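# 用法示意(假设性代码): 用余弦相似度粗略对比候选 Embedding 模型的语义区分度
import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """计算两个向量的余弦相似度"""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b))
    return dot / norm if norm else 0.0

# provider = LocalEmbeddingProvider("BAAI/bge-small-zh-v1.5")  # 需已安装 sentence-transformers
# print(cosine_similarity(provider.embed("如何重置账户密码"), provider.embed("忘记密码了怎么办")))  # 预期偏高
# print(cosine_similarity(provider.embed("如何重置账户密码"), provider.embed("发票抬头如何修改")))  # 预期偏低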
class VectorStore(ABC):
"""向量存储基类"""
@abstractmethod
def add(self, chunks: list[Chunk], embeddings: list[list[float]]):
pass
@abstractmethod
def search(self, query_embedding: list[float], top_k: int = 5) -> list[dict]:
pass
class ChromaVectorStore(VectorStore):
"""ChromaDB 向量存储"""
def __init__(self, collection_name: str = "rag_docs", persist_dir: str = "./chroma_db"):
import chromadb
self.client = chromadb.PersistentClient(path=persist_dir)
self.collection = self.client.get_or_create_collection(
name=collection_name,
metadata={"hnsw:space": "cosine"},
)
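# hnsw:space 设为 cosine 时,查询返回的 distance 为余弦距离,下方 search 中用 1 - distance 换算为相似度分数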
def add(self, chunks: list[Chunk], embeddings: list[list[float]]):
ids = [chunk.chunk_id for chunk in chunks]
documents = [chunk.content for chunk in chunks]
metadatas = [chunk.metadata for chunk in chunks]
# 分批添加
batch_size = 100
for i in range(0, len(ids), batch_size):
self.collection.add(
ids=ids[i:i + batch_size],
documents=documents[i:i + batch_size],
embeddings=embeddings[i:i + batch_size],
metadatas=metadatas[i:i + batch_size],
)
def search(self, query_embedding: list[float], top_k: int = 5) -> list[dict]:
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
include=["documents", "metadatas", "distances"],
)
search_results = []
for i in range(len(results["ids"][0])):
search_results.append({
"content": results["documents"][0][i],
"metadata": results["metadatas"][0][i],
"distance": results["distances"][0][i],
"score": 1 - results["distances"][0][i], # cosine distance -> similarity
})
return search_results
def delete_collection(self):
self.client.delete_collection(self.collection.name)
4. 检索质量优化
4.1 混合检索(关键词 + 向量)
from dataclasses import dataclass
@dataclass
class SearchResult:
"""统一搜索结果"""
content: str
metadata: dict
score: float
source: str = ""
retrieval_method: str = "" # "vector", "keyword", "hybrid"
class HybridSearchEngine:
"""混合检索引擎"""
def __init__(
self,
vector_store: VectorStore,
embedding_provider: EmbeddingProvider,
bm25_weight: float = 0.4,
vector_weight: float = 0.6,
):
self.vector_store = vector_store
self.embedding_provider = embedding_provider
self.bm25_weight = bm25_weight
self.vector_weight = vector_weight
self.bm25_index = None
self.documents = []
def build_bm25_index(self, documents: list[str]):
"""构建 BM25 关键词索引"""
try:
from rank_bm25 import BM25Okapi
import jieba
tokenized = [list(jieba.cut(doc)) for doc in documents]
self.bm25_index = BM25Okapi(tokenized)
self.documents = documents
except ImportError:
print("请安装 rank_bm25 和 jieba: pip install rank-bm25 jieba")
def _bm25_search(self, query: str, top_k: int = 10) -> list[SearchResult]:
"""BM25 关键词检索"""
import jieba
if not self.bm25_index:
return []
query_tokens = list(jieba.cut(query))
scores = self.bm25_index.get_scores(query_tokens)
# 归一化分数
max_score = max(scores) if scores.max() > 0 else 1
normalized_scores = scores / max_score
top_indices = normalized_scores.argsort()[-top_k:][::-1]
results = []
for idx in top_indices:
if normalized_scores[idx] > 0.1:
results.append(SearchResult(
content=self.documents[idx],
metadata={"index": int(idx)},
score=float(normalized_scores[idx]),
retrieval_method="keyword",
))
return results
def _vector_search(self, query: str, top_k: int = 10) -> list[SearchResult]:
"""向量检索"""
query_embedding = self.embedding_provider.embed(query)
raw_results = self.vector_store.search(query_embedding, top_k=top_k)
return [
SearchResult(
content=r["content"],
metadata=r.get("metadata", {}),
score=r["score"],
retrieval_method="vector",
)
for r in raw_results
]
def search(
self,
query: str,
top_k: int = 5,
method: str = "hybrid",
) -> list[SearchResult]:
"""
混合检索
Args:
query: 查询文本
top_k: 返回结果数量
method: 检索方法 ("vector", "keyword", "hybrid")
"""
if method == "vector":
return self._vector_search(query, top_k)
elif method == "keyword":
return self._bm25_search(query, top_k)
elif method == "hybrid":
# 分别检索
vector_results = self._vector_search(query, top_k * 2)
keyword_results = self._bm25_search(query, top_k * 2)
# 合并和重排序
return self._reciprocal_rank_fusion(
vector_results, keyword_results, top_k
)
return []
def _reciprocal_rank_fusion(
self,
vector_results: list[SearchResult],
keyword_results: list[SearchResult],
top_k: int,
k: int = 60,
) -> list[SearchResult]:
"""倒数排名融合(RRF)"""
scores = {}
# 向量检索分数
for rank, result in enumerate(vector_results):
content_key = hashlib.md5(result.content.encode()).hexdigest()
rrf_score = self.vector_weight / (k + rank + 1)
if content_key not in scores:
scores[content_key] = {
"content": result.content,
"metadata": result.metadata,
"score": 0.0,
}
scores[content_key]["score"] += rrf_score
# 关键词检索分数
for rank, result in enumerate(keyword_results):
content_key = hashlib.md5(result.content.encode()).hexdigest()
rrf_score = self.bm25_weight / (k + rank + 1)
if content_key not in scores:
scores[content_key] = {
"content": result.content,
"metadata": result.metadata,
"score": 0.0,
}
scores[content_key]["score"] += rrf_score
# 按分数排序
sorted_results = sorted(scores.values(), key=lambda x: x["score"], reverse=True)
return [
SearchResult(
content=r["content"],
metadata=r["metadata"],
score=r["score"],
retrieval_method="hybrid",
)
for r in sorted_results[:top_k]
]
4.2 查询扩展与重写
class QueryExpander:
"""查询扩展与重写"""
def __init__(self):
self.client = OpenAI()
def expand_query(self, original_query: str, num_expansions: int = 3) -> list[str]:
"""生成查询的多种表述"""
prompt = f"""请将以下查询重写为 {num_expansions} 种不同的表述,保持相同的语义但使用不同的词汇和句式。
每行一个重写结果,不要编号。
原始查询: {original_query}"""
completion = self.client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": prompt}],
temperature=0.7,
)
expansions = completion.choices[0].message.content.strip().split("\n")
expansions = [e.strip() for e in expansions if e.strip()]
return [original_query] + expansions[:num_expansions]
def decompose_query(self, complex_query: str) -> list[str]:
"""将复杂查询分解为子查询"""
prompt = f"""将以下复杂查询分解为多个简单的子查询,每行一个子查询:
复杂查询: {complex_query}"""
completion = self.client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": prompt}],
temperature=0.0,
)
sub_queries = completion.choices[0].message.content.strip().split("\n")
return [q.strip().lstrip("0123456789.-) ") for q in sub_queries if q.strip()]
def hyde_query(self, query: str) -> str:
"""假设性文档嵌入(HyDE) - 生成假设性回答用于检索"""
prompt = f"""请为以下问题写一个详细的回答(即使你不完全确定,也请尽力回答):
问题: {query}"""
completion = self.client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": prompt}],
temperature=0.0,
)
return completion.choices[0].message.content
4.3 重排序(Reranking)
class Reranker:
"""检索结果重排序"""
def __init__(self, method: str = "cross_encoder"):
self.method = method
self.client = OpenAI()
def rerank_with_llm(
self,
query: str,
results: list[SearchResult],
top_k: int = 5,
) -> list[SearchResult]:
"""使用 LLM 进行重排序"""
if not results:
return results
docs_text = "\n\n".join([
f"[文档{i+1}]: {r.content[:300]}"
for i, r in enumerate(results)
])
prompt = f"""请根据查询的相关性对以下文档进行重新排序。
查询: {query}
{docs_text}
请输出最相关的 {top_k} 个文档的编号(从1开始),按相关度从高到低排列,用逗号分隔。
只输出编号,不要其他内容。"""
completion = self.client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": prompt}],
temperature=0.0,
)
# 解析结果
response = completion.choices[0].message.content.strip()
try:
indices = [
int(x.strip()) - 1
for x in response.split(",")
if x.strip().isdigit()
]
reranked = []
for idx in indices:
if 0 <= idx < len(results):
results[idx].score = 1.0 - (len(reranked) * 0.1)
results[idx].retrieval_method = "reranked"
reranked.append(results[idx])
# 补充未排到的结果
ranked_set = set(indices)
for i, r in enumerate(results):
if i not in ranked_set and len(reranked) < top_k:
reranked.append(r)
return reranked[:top_k]
except (ValueError, IndexError):
return results[:top_k]
def rerank_with_cross_encoder(
self,
query: str,
results: list[SearchResult],
top_k: int = 5,
) -> list[SearchResult]:
"""使用交叉编码器重排序"""
try:
from sentence_transformers import CrossEncoder
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
pairs = [(query, r.content) for r in results]
scores = model.predict(pairs)
# 按分数排序
scored_results = list(zip(results, scores))
scored_results.sort(key=lambda x: x[1], reverse=True)
for i, (result, score) in enumerate(scored_results[:top_k]):
result.score = float(score)
result.retrieval_method = "reranked"
return [r for r, s in scored_results[:top_k]]
except ImportError:
print("请安装 sentence-transformers")
return results[:top_k]
5. RAG 完整管线
from dataclasses import dataclass
from typing import AsyncIterator
from openai import AsyncOpenAI, OpenAI
@dataclass
class RAGResponse:
"""RAG 响应"""
answer: str
sources: list[dict]
query: str
retrieval_count: int
latency_ms: float
class RAGPipeline:
"""RAG 完整管线"""
def __init__(
self,
embedding_provider: EmbeddingProvider,
vector_store: VectorStore,
model: str = "gpt-4",
chunk_size: int = 500,
retrieval_top_k: int = 5,
rerank_top_k: int = 3,
):
self.embedding_provider = embedding_provider
self.vector_store = vector_store
self.model = model
self.chunk_size = chunk_size
self.retrieval_top_k = retrieval_top_k
self.rerank_top_k = rerank_top_k
self.client = OpenAI()
self.async_client = AsyncOpenAI()  # 流式生成需要异步客户端
self.query_expander = QueryExpander()
self.reranker = Reranker()
def ingest(self, file_path: str):
"""摄入文档"""
# 1. 解析文档
processor = DocumentProcessor()
documents = processor.process_file(file_path)
print(f"解析完成: {len(documents)} 个文档段落")
# 2. 分块
chunker = DocumentChunker(chunk_size=self.chunk_size)
chunks = chunker.chunk_documents(documents)
print(f"分块完成: {len(chunks)} 个文档块")
# 3. 生成 Embedding
texts = [chunk.content for chunk in chunks]
embeddings = self.embedding_provider.embed_batch(texts)
print(f"Embedding 完成: {len(embeddings)} 个向量")
# 4. 存入向量数据库
self.vector_store.add(chunks, embeddings)
print(f"存储完成")
def query(
self,
user_query: str,
use_expansion: bool = False,
use_reranking: bool = True,
) -> RAGResponse:
"""执行 RAG 查询"""
import time
start_time = time.time()
# 1. 查询扩展(可选)
if use_expansion:
expanded_queries = self.query_expander.expand_query(user_query, num_expansions=2)
else:
expanded_queries = [user_query]
# 2. 检索
all_results = []
for q in expanded_queries:
query_embedding = self.embedding_provider.embed(q)
results = self.vector_store.search(query_embedding, top_k=self.retrieval_top_k)
all_results.extend(results)
# 去重
seen = set()
unique_results = []
for r in all_results:
key = hashlib.md5(r["content"].encode()).hexdigest()
if key not in seen:
seen.add(key)
unique_results.append(SearchResult(
content=r["content"],
metadata=r.get("metadata", {}),
score=r.get("score", 0),
))
# 3. 重排序(可选)
if use_reranking and unique_results:
ranked_results = self.reranker.rerank_with_llm(
user_query, unique_results, top_k=self.rerank_top_k
)
else:
ranked_results = sorted(
unique_results, key=lambda x: x.score, reverse=True
)[:self.rerank_top_k]
# 4. 构建 Prompt
context_parts = []
sources = []
for i, result in enumerate(ranked_results):
context_parts.append(f"[来源{i+1}] {result.content}")
sources.append({
"content": result.content[:200],
"metadata": result.metadata,
"score": result.score,
})
context = "\n\n".join(context_parts)
system_prompt = """你是一个专业的知识助手。请严格根据提供的参考资料回答用户问题。
要求:
1. 只基于提供的参考资料回答,不要编造信息
2. 如果参考资料中没有相关内容,明确告知用户
3. 引用信息时标注来源编号
4. 回答要准确、完整、有条理"""
user_prompt = f"""参考资料:
{context}
用户问题: {user_query}
请基于以上参考资料回答问题。"""
# 5. 生成回答
completion = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
temperature=0.0,
)
answer = completion.choices[0].message.content
latency_ms = (time.time() - start_time) * 1000
return RAGResponse(
answer=answer,
sources=sources,
query=user_query,
retrieval_count=len(ranked_results),
latency_ms=latency_ms,
)
async def query_stream(
self,
user_query: str,
use_expansion: bool = False,
use_reranking: bool = True,
) -> AsyncIterator[str]:
"""流式 RAG 查询"""
# 检索阶段(同步)
query_embedding = self.embedding_provider.embed(user_query)
results = self.vector_store.search(query_embedding, top_k=self.retrieval_top_k)
if use_reranking and results:
search_results = [
SearchResult(content=r["content"], metadata=r.get("metadata", {}),
score=r.get("score", 0))
for r in results
]
ranked = self.reranker.rerank_with_llm(user_query, search_results, self.rerank_top_k)
context = "\n\n".join([f"[来源{i+1}] {r.content}" for i, r in enumerate(ranked)])
else:
context = "\n\n".join([f"[来源{i+1}] {r['content']}" for i, r in enumerate(results[:self.rerank_top_k])])
# 流式生成
stream = await self.async_client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "你是一个专业的知识助手。请严格根据提供的参考资料回答。"},
{"role": "user", "content": f"参考资料:\n{context}\n\n问题: {user_query}"},
],
temperature=0.0,
stream=True,
)
async for chunk in stream:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
6. RAG 监控与调试
import json
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta
@dataclass
class RAGLogEntry:
"""RAG 查询日志"""
query_id: str
query: str
timestamp: str
retrieval_results: list[dict] = field(default_factory=list)
retrieval_latency_ms: float = 0
generation_latency_ms: float = 0
total_latency_ms: float = 0
answer: str = ""
sources_count: int = 0
top_score: float = 0
avg_score: float = 0
feedback: str = "" # positive/negative/neutral
class RAGMonitor:
"""RAG 系统监控器"""
def __init__(self, log_dir: str = "rag_logs"):
self.log_dir = Path(log_dir)
self.log_dir.mkdir(exist_ok=True)
self.metrics = {
"total_queries": 0,
"avg_latency": 0,
"avg_retrieval_score": 0,
"positive_feedback_rate": 0,
}
def log_query(self, entry: RAGLogEntry):
"""记录查询日志"""
log_file = self.log_dir / f"rag_{datetime.now().strftime('%Y%m%d')}.jsonl"
with open(log_file, "a", encoding="utf-8") as f:
f.write(json.dumps({
"query_id": entry.query_id,
"query": entry.query,
"timestamp": entry.timestamp,
"sources_count": entry.sources_count,
"top_score": entry.top_score,
"total_latency_ms": entry.total_latency_ms,
"retrieval_latency_ms": entry.retrieval_latency_ms,
"generation_latency_ms": entry.generation_latency_ms,
"feedback": entry.feedback,
}, ensure_ascii=False) + "\n")
def get_daily_stats(self, date: str = None) -> dict:
"""获取每日统计"""
if date is None:
date = datetime.now().strftime('%Y%m%d')
log_file = self.log_dir / f"rag_{date}.jsonl"
if not log_file.exists():
return {"error": f"没有 {date} 的日志数据"}
entries = []
with open(log_file, "r", encoding="utf-8") as f:
for line in f:
if line.strip():
entries.append(json.loads(line))
if not entries:
return {"total_queries": 0}
latencies = [e["total_latency_ms"] for e in entries]
scores = [e.get("top_score", 0) for e in entries]
feedbacks = [e.get("feedback", "") for e in entries]
return {
"date": date,
"total_queries": len(entries),
"avg_latency_ms": sum(latencies) / len(latencies),
"p50_latency_ms": sorted(latencies)[len(latencies) // 2],
"p99_latency_ms": sorted(latencies)[int(len(latencies) * 0.99)] if len(latencies) > 1 else latencies[0],
"avg_retrieval_score": sum(scores) / len(scores) if scores else 0,
"positive_feedback_rate": feedbacks.count("positive") / len(feedbacks),
}
def detect_quality_degradation(self, window_days: int = 7) -> list[str]:
"""检测质量退化"""
alerts = []
stats_list = []
for i in range(window_days):
date = (datetime.now() - timedelta(days=i)).strftime('%Y%m%d')
stats = self.get_daily_stats(date)
if "total_queries" in stats and stats["total_queries"] > 0:
stats_list.append(stats)
if len(stats_list) < 2:
return ["数据不足,无法进行退化检测"]
# 检查延迟退化
recent_latency = stats_list[0]["avg_latency_ms"]
baseline_latency = sum(s["avg_latency_ms"] for s in stats_list[1:]) / len(stats_list[1:])
if recent_latency > baseline_latency * 1.5:
alerts.append(
f"延迟退化告警: 最近平均延迟 {recent_latency:.0f}ms, "
f"基线 {baseline_latency:.0f}ms (增长 {(recent_latency/baseline_latency-1)*100:.1f}%)"
)
# 检查检索质量退化
recent_score = stats_list[0]["avg_retrieval_score"]
baseline_score = sum(s["avg_retrieval_score"] for s in stats_list[1:]) / len(stats_list[1:])
if recent_score < baseline_score * 0.8:
alerts.append(
f"检索质量退化告警: 最近平均分数 {recent_score:.3f}, "
f"基线 {baseline_score:.3f}"
)
return alerts
优点
- 准确性高: 通过外部知识检索减少模型幻觉
- 可溯源: 回答基于具体文档来源,增强可信度
- 灵活更新: 知识库更新无需重新训练模型
- 成本效益: 相比模型微调,维护成本更低
- 可扩展: 知识库可以持续扩展
缺点
- 检索延迟: 向量检索和重排序增加响应时间
- 依赖检索质量: 回答质量高度依赖检索的准确性
- 上下文窗口限制: 检索到的文档数量受上下文窗口限制
- 复杂度高: 涉及文档处理、Embedding、向量数据库等多个组件
- 冷启动问题: 需要足够的高质量文档才能获得好的效果
性能注意事项
检索延迟优化
# 1. Embedding 缓存
class CachedEmbeddingProvider(EmbeddingProvider):
"""带缓存的 Embedding 提供者"""
def __init__(self, base_provider: EmbeddingProvider, cache_size: int = 10000):
self.base_provider = base_provider
self.cache: dict[str, list[float]] = {}
self.cache_size = cache_size
def embed(self, text: str) -> list[float]:
cache_key = hashlib.md5(text.encode()).hexdigest()
if cache_key in self.cache:
return self.cache[cache_key]
embedding = self.base_provider.embed(text)
if len(self.cache) >= self.cache_size:
# 简单的 LRU: 移除最早的缓存
oldest_key = next(iter(self.cache))
del self.cache[oldest_key]
self.cache[cache_key] = embedding
return embedding
def embed_batch(self, texts: list[str]) -> list[list[float]]:
uncached = []
uncached_indices = []
for i, text in enumerate(texts):
cache_key = hashlib.md5(text.encode()).hexdigest()
if cache_key in self.cache:
continue
uncached.append(text)
uncached_indices.append(i)
if uncached:
new_embeddings = self.base_provider.embed_batch(uncached)
for text, embedding in zip(uncached, new_embeddings):
cache_key = hashlib.md5(text.encode()).hexdigest()
self.cache[cache_key] = embedding
return [self.cache[hashlib.md5(t.encode()).hexdigest()] for t in texts]
def get_dimension(self) -> int:
return self.base_provider.get_dimension()
上下文窗口管理
class ContextWindowManager:
"""上下文窗口管理器"""
def __init__(self, max_tokens: int = 4096, reserved_tokens: int = 1000):
self.max_tokens = max_tokens
self.reserved_tokens = reserved_tokens
self.available_tokens = max_tokens - reserved_tokens
@staticmethod
def estimate_tokens(text: str) -> int:
"""估算文本 token 数(简化版)"""
# 中文约 1.5 token/字符, 英文约 0.75 token/word
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
english_words = len(text.split()) - chinese_chars
return int(chinese_chars * 1.5 + english_words * 1.3)
def fit_context(
self,
results: list[SearchResult],
system_prompt: str = "",
query: str = "",
) -> list[SearchResult]:
"""在上下文窗口内选择最优的文档组合"""
used_tokens = self.estimate_tokens(system_prompt) + self.estimate_tokens(query)
remaining = self.available_tokens - used_tokens
selected = []
for result in results:
doc_tokens = self.estimate_tokens(result.content)
if doc_tokens <= remaining:
selected.append(result)
remaining -= doc_tokens
else:
# 尝试截断
if remaining > 100:
# 按字符估算截断位置
chars = int(remaining / 1.5)
result.content = result.content[:chars] + "..."
selected.append(result)
break
return selected
总结
RAG 生产环境的成功部署需要关注以下关键环节:
- 文档处理: 建立健壮的多格式文档解析和清洗流水线
- 分块策略: 根据文档类型选择合适的分块方法,平衡语义完整性和检索精度
- 混合检索: 结合关键词检索和向量检索,提升召回率
- 重排序优化: 通过 Reranking 提升检索结果的相关性
- 监控告警: 建立完整的监控体系,及时发现质量退化
关键知识点
| 概念 | 说明 |
|---|---|
| 文档分块 | 将长文档切分为适当大小的片段,便于检索和上下文组装 |
| Embedding | 将文本转换为高维向量,用于语义相似度计算 |
| 混合检索 | 结合关键词检索(BM25)和向量检索,提升召回率 |
| RRF | 倒数排名融合,用于合并多个检索结果列表 |
| HyDE | 假设性文档嵌入,先生成假设答案再用于检索 |
| Reranking | 对初检结果进行精细化排序,提升相关性 |
| 上下文窗口 | 模型能处理的最大 token 数,限制检索文档总量 |
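以表中的 RRF 为例做一个小的数值演算(k 与权重取前文实现中的默认值 60、0.6、0.4,名次为假设): 若某文档块在向量检索中排第 1、在关键词检索中排第 3,其融合分数约为 0.6/(60+1) + 0.4/(60+3) ≈ 0.0098 + 0.0063 ≈ 0.0162;只被单路召回且名次靠后的文档分数会明显更低,因此同时被两路召回的文档更容易留在最终 Top-K 中。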
常见误区
误区: 分块越大越好
- 过大的块会引入噪声,降低检索精度
- 解决: 根据内容类型调整,通常 300-800 字符为宜
误区: 向量检索足够,不需要关键词检索
- 向量检索对精确匹配(如产品型号、专有名词)效果较差
- 解决: 使用混合检索,兼顾语义和精确匹配(示意代码见本节末尾)
误区: 只用 Top-1 结果生成回答
- 单一检索结果可能不完整或有偏差
- 解决: 使用 Top-K 多结果,结合重排序选择最佳上下文
误区: 忽略文档质量
- 垃圾输入产生垃圾输出,文档质量直接影响 RAG 效果
- 解决: 建立文档质量检查机制,过滤低质量内容
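针对"向量检索足够"这一误区,补充一段前文 HybridSearchEngine 的示意用法(其中 store、embedding、chunks 假设已按前面章节构建完成,查询内容亦为虚构示例,需额外安装 rank-bm25 与 jieba)。对包含产品型号、专有名词等精确词的查询,混合检索通常比纯向量检索更稳:
engine = HybridSearchEngine(vector_store=store, embedding_provider=embedding)
engine.build_bm25_index([c.content for c in chunks])  # 用分块文本构建 BM25 关键词索引
for r in engine.search("XR-500 型号的保修期是多久", top_k=5, method="hybrid"):
    print(f"{r.score:.4f}", r.retrieval_method, r.content[:40])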
进阶路线
- 入门: 理解 RAG 原理,搭建基础检索生成流程
- 进阶: 实现混合检索、重排序、查询扩展
- 高级: 构建生产级监控、调试和质量评估体系
- 专家: 自研 Embedding 模型、设计多跳检索和知识图谱增强 RAG
适用场景
- 企业知识库问答系统
- 技术文档检索与问答
- 法律/医疗等专业领域知识检索
- 客服机器人
- 内部 wiki 智能搜索
落地建议
- 先跑通再优化: 先用最简单的方案验证业务可行性(最小链路示例见本节末尾)
- 关注数据质量: 知识库质量比算法优化更重要
- 持续收集反馈: 建立用户反馈机制,持续迭代
- 监控检索质量: 定期检查检索分数分布,发现退化
- A/B 测试分块策略: 不同文档类型可能需要不同的分块参数
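下面给出一条"最小可跑"链路的示意代码,串联前文的 LocalEmbeddingProvider、ChromaVectorStore 与 RAGPipeline;其中模型名、集合名与文档路径均为假设值,需按实际环境替换,并预先安装 chromadb、sentence-transformers 且配置好 OpenAI API Key:
embedding = LocalEmbeddingProvider("BAAI/bge-small-zh-v1.5")
store = ChromaVectorStore(collection_name="kb_demo", persist_dir="./chroma_demo")
pipeline = RAGPipeline(
    embedding_provider=embedding,
    vector_store=store,
    model="gpt-4",
    retrieval_top_k=5,
    rerank_top_k=3,
)
pipeline.ingest("./docs/faq.md")  # 假设的本地文档路径
resp = pipeline.query("如何申请退款?", use_reranking=True)
print(resp.answer)
for src in resp.sources:
    print(f"score={src['score']:.3f}", src["metadata"])
print(f"总耗时 {resp.latency_ms:.0f} ms")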
排错清单
| 问题 | 可能原因 | 解决方案 |
|---|---|---|
| 回答与问题无关 | 检索质量差,返回了不相关文档 | 检查 Embedding 模型,增加重排序 |
| 回答"不知道" | 检索未命中相关文档 | 扩展查询,调整分块大小,增加检索数量 |
| 回答包含错误信息 | 检索到过时或矛盾的文档 | 建立文档更新机制,增加时效性权重 |
| 响应延迟高 | Embedding 计算慢或重排序耗时 | 使用缓存,预计算 Embedding,优化重排序 |
| 上下文溢出 | 检索文档过多或过长 | 使用上下文窗口管理器,动态裁剪 |
| 文档解析乱码 | 编码问题或格式不支持 | 检查编码,增加格式预处理 |
复盘问题
- 上周 RAG 系统的平均检索分数是多少?是否有退化趋势?
- 用户反馈中"回答不准确"的占比是多少?主要原因是什么?
- 文档分块的平均大小是否合理?是否有过大或过小的块?
- 混合检索中,关键词检索和向量检索各自的命中贡献率是多少?
- 重排序后结果的质量提升幅度是多少?
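上述复盘问题中的延迟、检索分数与反馈占比,大多可以直接用第 6 节 RAGMonitor 的日志统计回答。下面是一个示意(假设查询日志已按 rag_logs/rag_YYYYMMDD.jsonl 的格式落盘):
monitor = RAGMonitor(log_dir="rag_logs")
stats = monitor.get_daily_stats()  # 默认统计当天
print("平均检索分数:", stats.get("avg_retrieval_score"))
print("正向反馈率:", stats.get("positive_feedback_rate"))
for alert in monitor.detect_quality_degradation(window_days=7):
    print("告警:", alert)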
延伸阅读
- LangChain RAG 文档
- LlamaIndex - RAG 框架
- BGE Embedding - 中文 Embedding 模型
- Rank-BM25 - BM25 实现
- ChromaDB - 向量数据库
- DSPy - 斯坦福 RAG 编程框架
- Advanced RAG Techniques - 高级 RAG 技术综述
