LangChain RAG 实战:从零构建企业级智能文档问答系统

基于最新技术栈,实现高效、准确的文档智能问答

在信息爆炸的时代,企业积累了海量文档却难以高效利用。传统的全文检索技术已经无法满足精准问答的需求,而基于检索增强生成(Retrieval-Augmented Generation, RAG)的技术方案正成为解决这一痛点的最佳实践。本文将深入讲解如何使用 LangChain 构建一个完整的 RAG 系统,通过详细的代码示例和技术拆解,协助中高级开发者掌握这一核心技术。

1 RAG 技术架构概述

1.1 什么是 RAG 技术

检索增强生成(RAG)是一种将检索系统与生成模型相结合的技术架构。与传统生成模型不同,RAG 在生成答案前会先从外部知识库中检索相关文档片段,然后将这些信息与原始问题一起提供给生成模型,从而生成更准确、更可靠的回答。

1.2 LangChain 在 RAG 中的角色

LangChain 是一个用于构建大语言模型应用的框架,它在 RAG 系统中扮演着流程编排者的角色:

  • 文档处理流水线:统一管理文档的加载、分割和向量化
  • 智能检索编排:协调检索器、向量数据库和大模型之间的交互
  • 模块化设计:提供可插拔的组件,便于定制和扩

LangChain RAG 实战:从零构建企业级智能文档问答系统

2 环境准备与项目搭建

2.1 环境配置

第一,确保安装了必要的 Python 包:

# 创建并激活虚拟环境
python -m venv rag_env
source rag_env/bin/activate  # Linux/Mac
# 或 rag_envScriptsactivate  # Windows

# 安装核心依赖
pip install langchain langchain-community langchain-openai chromadb tiktoken
pip install pypdf python-dotenv  # 文档处理和配置管理
pip install sentence-transformers  # 本地嵌入模型
pip install faiss-cpu  # Facebook AI 类似度搜索

2.2 配置 API 密钥

创建 .env 文件管理敏感信息:

# config.py - 配置管理模块
import os
from dotenv import load_dotenv

load_dotenv()

class Config:
    # OpenAI 配置
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
    OPENAI_MODEL = "gpt-3.5-turbo"  # 或 "gpt-4"
    
    # 向量数据库配置
    VECTOR_STORE_PATH = "./data/vector_store"
    
    # 本地嵌入模型配置
    LOCAL_EMBEDDING_MODEL = "all-MiniLM-L6-v2"
    
    # 文档处理配置
    CHUNK_SIZE = 500
    CHUNK_OVERLAP = 50
    
    @classmethod
    def validate(cls):
        if not cls.OPENAI_API_KEY:
            raise ValueError("OPENAI_API_KEY 未设置")
        print("✓ 配置验证通过")

3 完整 RAG 系统实现

3.1 文档加载与处理模块

# document_processor.py - 文档处理器
import os
from typing import List, Dict, Any
from langchain_community.document_loaders import (
    PyPDFLoader,
    TextLoader,
    CSVLoader,
    UnstructuredMarkdownLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document

class DocumentProcessor:
    """文档处理流水线"""
    
    def __init__(self, chunk_size=500, chunk_overlap=50):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=["

", "
", "。", "!", "?", ";", ",", " ", ""]
        )
    
    def load_documents(self, file_paths: List[str]) -> List[Document]:
        """加载多种格式的文档"""
        all_documents = []
        
        for file_path in file_paths:
            if not os.path.exists(file_path):
                print(f"警告: 文件不存在 {file_path}")
                continue
            
            loader = self._get_loader(file_path)
            if loader:
                try:
                    documents = loader.load()
                    print(f"✓ 已加载 {file_path}: {len(documents)} 个文档")
                    all_documents.extend(documents)
                except Exception as e:
                    print(f"✗ 加载失败 {file_path}: {str(e)}")
        
        return all_documents
    
    def _get_loader(self, file_path: str):
        """根据文件扩展名选择加载器"""
        ext = os.path.splitext(file_path)[1].lower()
        
        loader_map = {
            '.pdf': PyPDFLoader,
            '.txt': TextLoader,
            '.csv': CSVLoader,
            '.md': UnstructuredMarkdownLoader,
        }
        
        loader_class = loader_map.get(ext)
        if loader_class:
            return loader_class(file_path)
        return None
    
    def split_documents(self, documents: List[Document]) -> List[Document]:
        """分割文档为适合处理的片段"""
        if not documents:
            return []
        
        # 添加元数据增强
        for doc in documents:
            if not doc.metadata:
                doc.metadata = {}
            doc.metadata['source'] = doc.metadata.get('source', 'unknown')
            doc.metadata['chunk_id'] = hash(doc.page_content) % 10000
        
        # 执行分割
        split_docs = self.text_splitter.split_documents(documents)
        print(f"文档分割完成: {len(documents)} -> {len(split_docs)} 个片段")
        
        # 添加顺序索引
        for i, doc in enumerate(split_docs):
            doc.metadata['chunk_index'] = i
        
        return split_docs
    
    def create_document_chunks(self, file_paths: List[str]) -> List[Document]:
        """完整文档处理流水线"""
        print("开始文档处理流程...")
        
        # 1. 加载文档
        raw_docs = self.load_documents(file_paths)
        if not raw_docs:
            raise ValueError("未加载到任何文档")
        
        # 2. 分割文档
        chunked_docs = self.split_documents(raw_docs)
        
        # 3. 质量检查
        self._quality_check(chunked_docs)
        
        return chunked_docs
    
    def _quality_check(self, documents: List[Document]):
        """检查分割质量"""
        if not documents:
            return
        
        avg_length = sum(len(doc.page_content) for doc in documents) / len(documents)
        min_length = min(len(doc.page_content) for doc in documents)
        max_length = max(len(doc.page_content) for doc in documents)
        
        print(f"质量报告:")
        print(f"  - 片段数量: {len(documents)}")
        print(f"  - 平均长度: {avg_length:.1f} 字符")
        print(f"  - 最短长度: {min_length} 字符")
        print(f"  - 最长长度: {max_length} 字符")
        
        # 检查是否有过短的片段
        short_chunks = [doc for doc in documents if len(doc.page_content) < 20]
        if short_chunks:
            print(f"  - 警告: 发现 {len(short_chunks)} 个过短片段")

# 使用示例
if __name__ == "__main__":
    processor = DocumentProcessor(chunk_size=500, chunk_overlap=50)
    
    # 假设我们有这些文档
    sample_files = [
        "./docs/technical_guide.pdf",
        "./docs/api_reference.md",
        "./docs/faq.txt"
    ]
    
    # 创建模拟文件用于测试
    import tempfile
    with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
        f.write("这是测试文档内容。" * 50)
        sample_files = [f.name]
    
    try:
        chunks = processor.create_document_chunks(sample_files)
        print(f"
前3个片段预览:")
        for i, chunk in enumerate(chunks[:3]):
            print(f"
--- 片段 {i+1} ---")
            print(f"内容: {chunk.page_content[:100]}...")
            print(f"元数据: {chunk.metadata}")
    finally:
        os.unlink(sample_files[0])

3.2 向量化与存储模块

# vector_store.py - 向量存储管理器
import os
import hashlib
from typing import List, Optional, Tuple
from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma, FAISS
from langchain.schema import Document
import pickle

class VectorStoreManager:
    """向量存储管理器,支持多种向量数据库"""
    
    def __init__(self, config):
        self.config = config
        self.embeddings = self._init_embeddings()
        self.vector_store = None
    
    def _init_embeddings(self):
        """初始化嵌入模型"""
        # 选项1: 使用 OpenAI 嵌入(需要 API 密钥)
        if self.config.OPENAI_API_KEY:
            print("使用 OpenAI 嵌入模型")
            return OpenAIEmbeddings(
                model="text-embedding-ada-002",
                openai_api_key=self.config.OPENAI_API_KEY
            )
        
        # 选项2: 使用本地 HuggingFace 模型
        else:
            print(f"使用本地嵌入模型: {self.config.LOCAL_EMBEDDING_MODEL}")
            return HuggingFaceEmbeddings(
                model_name=f"sentence-transformers/{self.config.LOCAL_EMBEDDING_MODEL}",
                model_kwargs={'device': 'cpu'},
                encode_kwargs={'normalize_embeddings': True}
            )
    
    def create_vector_store(self, 
                           documents: List[Document], 
                           store_type: str = "chroma",
                           persist: bool = True) -> None:
        """创建向量存储"""
        
        if not documents:
            raise ValueError("文档列表为空")
        
        print(f"开始创建向量存储 ({store_type})...")
        print(f"处理 {len(documents)} 个文档片段")
        
        # 计算文档指纹用于缓存
        docs_hash = self._calculate_documents_hash(documents)
        cache_path = f"{self.config.VECTOR_STORE_PATH}_{store_type}_{docs_hash}.pkl"
        
        # 检查缓存
        if persist and os.path.exists(cache_path):
            print("发现缓存,加载已有向量存储...")
            self._load_from_cache(cache_path)
            return
        
        # 创建新的向量存储
        if store_type == "chroma":
            self.vector_store = Chroma.from_documents(
                documents=documents,
                embedding=self.embeddings,
                persist_directory=self.config.VECTOR_STORE_PATH if persist else None
            )
        elif store_type == "faiss":
            self.vector_store = FAISS.from_documents(
                documents=documents,
                embedding=self.embeddings
            )
            
            if persist:
                self.vector_store.save_local(self.config.VECTOR_STORE_PATH)
        else:
            raise ValueError(f"不支持的存储类型: {store_type}")
        
        print("✓ 向量存储创建完成")
        
        # 保存缓存
        if persist:
            self._save_to_cache(cache_path)
    
    def _calculate_documents_hash(self, documents: List[Document]) -> str:
        """计算文档集合的哈希值"""
        content = "".join(doc.page_content for doc in documents[:100])  # 采样前100个
        return hashlib.md5(content.encode()).hexdigest()[:8]
    
    def _save_to_cache(self, cache_path: str):
        """保存向量存储到缓存"""
        try:
            # Chroma 有自己的持久化机制,这里只保存元数据
            metadata = {
                "store_type": "chroma" if isinstance(self.vector_store, Chroma) else "faiss",
                "document_count": self.vector_store._collection.count() if hasattr(self.vector_store, '_collection') else "unknown"
            }
            
            with open(cache_path, 'wb') as f:
                pickle.dump(metadata, f)
            print(f"缓存已保存: {cache_path}")
        except Exception as e:
            print(f"缓存保存失败: {e}")
    
    def _load_from_cache(self, cache_path: str):
        """从缓存加载向量存储"""
        try:
            if not os.path.exists(self.config.VECTOR_STORE_PATH):
                print("缓存元数据存在,但向量存储文件不存在")
                return False
            
            # 对于 Chroma,直接加载持久化目录
            with open(cache_path, 'rb') as f:
                metadata = pickle.load(f)
            
            if metadata["store_type"] == "chroma":
                self.vector_store = Chroma(
                    persist_directory=self.config.VECTOR_STORE_PATH,
                    embedding_function=self.embeddings
                )
            else:
                self.vector_store = FAISS.load_local(
                    self.config.VECTOR_STORE_PATH,
                    self.embeddings,
                    allow_dangerous_deserialization=True
                )
            
            print(f"✓ 从缓存加载向量存储,包含 {metadata.get('document_count', 'unknown')} 个文档")
            return True
        except Exception as e:
            print(f"缓存加载失败: {e}")
            return False
    
    def similarity_search(self, 
                         query: str, 
                         k: int = 4,
                         score_threshold: float = 0.0) -> List[Tuple[Document, float]]:
        """类似度搜索"""
        if not self.vector_store:
            raise ValueError("向量存储未初始化")
        
        # 根据向量存储类型调用不同的搜索方法
        if isinstance(self.vector_store, Chroma):
            results = self.vector_store.similarity_search_with_relevance_scores(
                query=query,
                k=k
            )
        else:  # FAISS
            results = self.vector_store.similarity_search_with_score(
                query=query,
                k=k
            )
        
        # 过滤低于阈值的結果
        filtered_results = [(doc, score) for doc, score in results if score >= score_threshold]
        
        print(f"搜索查询: '{query}'")
        print(f"返回 {len(filtered_results)}/{len(results)} 个相关片段")
        
        if filtered_results:
            print("最佳匹配:")
            print(f"  分数: {filtered_results[0][1]:.3f}")
            print(f"  内容: {filtered_results[0][0].page_content[:100]}...")
        
        return filtered_results
    
    def get_retriever(self, search_type: str = "similarity", **kwargs):
        """获取检索器对象"""
        if not self.vector_store:
            raise ValueError("向量存储未初始化")
        
        # 配置检索器参数
        search_kwargs = {"k": kwargs.get("k", 4)}
        
        if search_type == "similarity_score_threshold":
            search_kwargs["score_threshold"] = kwargs.get("score_threshold", 0.5)
        
        # 创建检索器
        retriever = self.vector_store.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        )
        
        return retriever

# 使用示例
if __name__ == "__main__":
    from config import Config
    Config.validate()
    
    # 创建示例文档
    sample_docs = [
        Document(
            page_content="LangChain 是一个用于构建大语言模型应用的框架",
            metadata={"source": "guide", "page": 1}
        ),
        Document(
            page_content="RAG 技术结合了检索系统和生成模型的优势",
            metadata={"source": "paper", "page": 2}
        ),
        Document(
            page_content="向量数据库用于高效存储和检索嵌入向量",
            metadata={"source": "tutorial", "page": 3}
        )
    ]
    
    # 创建向量存储
    manager = VectorStoreManager(Config)
    
    # 使用 Chroma
    manager.create_vector_store(
        documents=sample_docs,
        store_type="chroma",
        persist=True
    )
    
    # 测试搜索
    results = manager.similarity_search("什么是 LangChain?", k=2)
    
    # 获取检索器
    retriever = manager.get_retriever(
        search_type="similarity",
        k=3
    )
    print(f"检索器类型: {type(retriever)}")

3.3 RAG 问答链实现

# rag_chain.py - RAG 问答链核心实现
import logging
from typing import List, Dict, Any, Optional
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.schema.runnable import RunnablePassthrough, RunnableParallel
from langchain.schema.output_parser import StrOutputParser
from langchain.schema import Document
from langchain.callbacks import get_openai_callback

# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class RAGChain:
    """RAG 问答链核心类"""
    
    def __init__(self, config, retriever=None):
        self.config = config
        self.retriever = retriever
        self.llm = self._init_llm()
        self.chain = None
        self._setup_chain()
    
    def _init_llm(self):
        """初始化语言模型"""
        return ChatOpenAI(
            model_name=self.config.OPENAI_MODEL,
            temperature=0.1,  # 低温度以获得更确定的回答
            max_tokens=1000,
            openai_api_key=self.config.OPENAI_API_KEY,
            streaming=False  # 非流式响应,便于调试
        )
    
    def _setup_chain(self):
        """设置 RAG 链"""
        if not self.retriever:
            logger.warning("未设置检索器,使用空检索器")
            self.retriever = lambda x: []
        
        # 1. 定义提示模板
        prompt_template = self._create_prompt_template()
        
        # 2. 构建处理链
        self.chain = (
            RunnableParallel({
                "context": self.retriever,
                "question": RunnablePassthrough()
            })
            .assign(answer=self._create_qa_chain(prompt_template))
            .pick(["answer", "context"])
        )
    
    def _create_prompt_template(self) -> ChatPromptTemplate:
        """创建提示模板"""
        
        # 系统提示词
        system_prompt = """你是一个专业的AI助手,基于提供的上下文信息回答问题。
        
        请遵循以下规则:
        1. 只使用提供的上下文信息回答问题
        2. 如果上下文信息不足,请诚实说明
        3. 保持回答准确、简洁、专业
        4. 使用中文回答,除非问题明确要求其他语言
        
        上下文信息:
        {context}
        """
        
        # 用户问题模板
        human_prompt = "问题:{question}"
        
        # 创建提示模板
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", human_prompt)
        ])
        
        return prompt
    
    def _create_qa_chain(self, prompt_template: ChatPromptTemplate):
        """创建问答链"""
        return (
            {"context": lambda x: self._format_documents(x["context"]), 
             "question": lambda x: x["question"]}
            | prompt_template
            | self.llm
            | StrOutputParser()
        )
    
    def _format_documents(self, documents: List[Document]) -> str:
        """格式化文档为字符串"""
        if not documents:
            return "没有找到相关上下文信息。"
        
        formatted = []
        for i, doc in enumerate(documents):
            # 添加来源信息
            source = doc.metadata.get('source', '未知来源')
            page = doc.metadata.get('page', '未知页码')
            
            content = doc.page_content.strip()
            formatted.append(f"[文档 {i+1}, 来源: {source}, 页码: {page}]
{content}")
        
        return "

".join(formatted)
    
    def query(self, question: str, verbose: bool = True) -> Dict[str, Any]:
        """执行查询"""
        if not self.chain:
            raise ValueError("问答链未初始化")
        
        try:
            with get_openai_callback() as cb:
                # 执行查询
                result = self.chain.invoke(question)
                
                # 记录使用情况
                if verbose:
                    logger.info(f"查询完成")
                    logger.info(f"Token 使用: {cb.total_tokens}")
                    logger.info(f"成本: ${cb.total_cost:.4f}")
                    logger.info(f"检索到 {len(result.get('context', []))} 个相关片段")
                
                # 添加元数据
                result['tokens_used'] = cb.total_tokens
                result['cost'] = cb.total_cost
                
                return result
                
        except Exception as e:
            logger.error(f"查询失败: {str(e)}")
            return {
                "answer": f"查询过程中出现错误: {str(e)}",
                "context": [],
                "tokens_used": 0,
                "cost": 0
            }
    
    def batch_query(self, questions: List[str]) -> List[Dict[str, Any]]:
        """批量查询"""
        results = []
        total_tokens = 0
        total_cost = 0
        
        for i, question in enumerate(questions, 1):
            logger.info(f"处理问题 {i}/{len(questions)}: {question[:50]}...")
            
            result = self.query(question, verbose=False)
            results.append(result)
            
            total_tokens += result.get('tokens_used', 0)
            total_cost += result.get('cost', 0)
            
            # 添加延迟避免速率限制
            import time
            time.sleep(0.5)
        
        logger.info(f"批量查询完成")
        logger.info(f"总 Token 使用: {total_tokens}")
        logger.info(f"总成本: ${total_cost:.4f}")
        
        return results

class AdvancedRAGChain(RAGChain):
    """增强版 RAG 链,支持更多高级功能"""
    
    def __init__(self, config, retriever=None):
        super().__init__(config, retriever)
        self.history = []
    
    def _create_prompt_template(self) -> ChatPromptTemplate:
        """创建增强提示模板"""
        
        system_prompt = """你是一个专业的AI助手,基于提供的上下文信息回答问题。
        
        请遵循以下规则:
        1. 仔细分析上下文信息与问题的相关性
        2. 如果上下文信息不足或完全不相关,请明确说明
        3. 从多个相关片段中综合信息
        4. 必要时指出信息的来源
        5. 使用清晰、结构化的格式回答
        6. 如果问题涉及多步推理,展示思考过程
        
        上下文信息:
        {context}
        """
        
        human_prompt = """问题:{question}

请基于以上上下文信息,提供最准确的回答。"""
        
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", human_prompt)
        ])
        
        return prompt
    
    def query_with_citations(self, question: str) -> Dict[str, Any]:
        """带引用的查询"""
        result = self.query(question, verbose=False)
        
        # 添加引用信息
        if result.get('context'):
            citations = []
            for doc in result['context']:
                source = doc.metadata.get('source', '未知')
                page = doc.metadata.get('page', '')
                citation = f"[{source}"
                if page:
                    citation += f" p.{page}"
                citation += "]"
                citations.append(citation)
            
            result['citations'] = citations
            result['answer'] = f"{result['answer']}

参考文献: {', '.join(set(citations))}"
        
        return result
    
    def _format_documents(self, documents: List[Document]) -> str:
        """格式化文档并添加相关性标记"""
        if not documents:
            return "没有找到相关上下文信息。"
        
        formatted = []
        for i, doc in enumerate(documents):
            source = doc.metadata.get('source', '未知来源')
            page = doc.metadata.get('page', '未知页码')
            
            # 添加相关性评分(如果有)
            score = doc.metadata.get('score', 0)
            score_marker = f" [相关性: {score:.3f}]" if score > 0 else ""
            
            content = doc.page_content.strip()
            formatted.append(
                f"[片段 {i+1}{score_marker}, 来源: {source}, 页码: {page}]
{content}"
            )
        
        return "

".join(formatted)

# 使用示例
if __name__ == "__main__":
    from config import Config
    from vector_store import VectorStoreManager
    
    Config.validate()
    
    # 创建示例文档
    sample_docs = [
        Document(
            page_content="LangChain 是一个用于构建大语言模型应用的框架,提供了模块化的组件和工具。",
            metadata={"source": "官方文档", "page": 1, "score": 0.95}
        ),
        Document(
            page_content="RAG(检索增强生成)技术通过结合检索系统和生成模型,提高了回答的准确性和可靠性。",
            metadata={"source": "技术论文", "page": 2, "score": 0.88}
        ),
    ]
    
    # 创建向量存储和检索器
    vs_manager = VectorStoreManager(Config)
    vs_manager.create_vector_store(sample_docs, store_type="chroma", persist=False)
    retriever = vs_manager.get_retriever(k=2)
    
    # 创建 RAG 链
    rag_chain = AdvancedRAGChain(Config, retriever)
    
    # 测试查询
    questions = [
        "LangChain 是什么?",
        "RAG 技术有什么优势?"
    ]
    
    print("=== 测试 RAG 问答链 ===
")
    
    for question in questions:
        print(f"问题: {question}")
        result = rag_chain.query_with_citations(question)
        
        print(f"回答: {result['answer']}")
        print(f"使用 Token: {result.get('tokens_used', 0)}")
        print(f"引用: {result.get('citations', [])}")
        print("-" * 50)

3.4 完整应用集成

# main.py - 完整的 RAG 应用
import os
import argparse
from typing import List
from config import Config
from document_processor import DocumentProcessor
from vector_store import VectorStoreManager
from rag_chain import AdvancedRAGChain

class RAGApplication:
    """完整的 RAG 应用"""
    
    def __init__(self, config_path: str = None):
        if config_path:
            os.environ["ENV_FILE"] = config_path
        
        self.config = Config()
        self.config.validate()
        
        self.processor = None
        self.vector_manager = None
        self.rag_chain = None
        self.is_initialized = False
    
    def initialize(self, 
                  document_paths: List[str] = None,
                  force_rebuild: bool = False):
        """初始化 RAG 系统"""
        print("=" * 60)
        print("初始化 RAG 系统")
        print("=" * 60)
        
        # 1. 处理文档
        print("
[1/3] 处理文档...")
        self.processor = DocumentProcessor(
            chunk_size=self.config.CHUNK_SIZE,
            chunk_overlap=self.config.CHUNK_OVERLAP
        )
        
        if document_paths:
            documents = self.processor.create_document_chunks(document_paths)
        else:
            # 使用示例文档
            documents = self._create_sample_documents()
        
        # 2. 创建向量存储
        print("
[2/3] 创建向量存储...")
        self.vector_manager = VectorStoreManager(self.config)
        
        if force_rebuild or not os.path.exists(self.config.VECTOR_STORE_PATH):
            self.vector_manager.create_vector_store(
                documents=documents,
                store_type="chroma",
                persist=True
            )
        else:
            print("使用现有的向量存储")
            self.vector_manager.vector_store = self.vector_manager._load_from_cache(
                f"{self.config.VECTOR_STORE_PATH}_cache.pkl"
            )
        
        # 3. 创建 RAG 链
        print("
[3/3] 创建 RAG 问答链...")
        retriever = self.vector_manager.get_retriever(
            search_type="similarity",
            k=4
        )
        
        self.rag_chain = AdvancedRAGChain(self.config, retriever)
        
        self.is_initialized = True
        print("
✓ RAG 系统初始化完成")
    
    def _create_sample_documents(self):
        """创建示例文档"""
        from langchain.schema import Document
        
        sample_texts = [
            "LangChain 是一个用于构建大语言模型应用的框架。它提供了一套工具、组件和接口,简化了基于大语言模型的应用程序开发过程。",
            "RAG(Retrieval-Augmented Generation)技术通过结合检索系统和生成模型,使大语言模型能够访问和利用外部知识库,从而生成更准确、更可靠的回答。",
            "向量数据库(如 Chroma、FAISS)专门用于高效存储和检索高维向量。它们使用近似最近邻(ANN)算法来快速找到类似的向量。",
            "嵌入模型(如 OpenAI Embeddings、Sentence Transformers)将文本转换为高维向量表明,使得语义类似的文本在向量空间中也相近。",
            "LangChain 的主要组件包括:模型(LLMs、Chat Models)、提示(Prompts)、链(Chains)、代理(Agents)和记忆(Memory)。",
            "RAG 系统的工作流程包括:文档处理(加载、分割)、向量化(创建嵌入)、存储(向量数据库)、检索(类似度搜索)和生成(LLM 回答)。",
        ]
        
        documents = []
        for i, text in enumerate(sample_texts):
            documents.append(Document(
                page_content=text,
                metadata={
                    "source": "示例文档",
                    "page": i + 1,
                    "category": ["framework", "rag", "vector_db", "embeddings"][i % 4]
                }
            ))
        
        return documents
    
    def query(self, question: str, verbose: bool = True) -> dict:
        """查询 RAG 系统"""
        if not self.is_initialized:
            raise RuntimeError("RAG 系统未初始化,请先调用 initialize() 方法")
        
        print(f"
查询: {question}")
        print("-" * 40)
        
        result = self.rag_chain.query_with_citations(question)
        
        if verbose:
            print(f"回答: {result['answer']}")
            
            if result.get('context'):
                print(f"
参考文档 ({len(result['context'])} 个片段):")
                for i, doc in enumerate(result['context'], 1):
                    print(f"  {i}. [{doc.metadata.get('source', '未知')}]")
                    print(f"     {doc.page_content[:80]}...")
        
        return result
    
    def interactive_mode(self):
        """交互式查询模式"""
        if not self.is_initialized:
            print("正在初始化 RAG 系统(使用示例文档)...")
            self.initialize()
        
        print("
" + "=" * 60)
        print("RAG 交互式问答系统")
        print("输入 'quit' 或 'exit' 退出")
        print("输入 'reset' 重新初始化系统")
        print("=" * 60)
        
        while True:
            try:
                question = input("
请输入问题: ").strip()
                
                if question.lower() in ['quit', 'exit', 'q']:
                    print("感谢使用,再见!")
                    break
                
                elif question.lower() == 'reset':
                    confirm = input("确定要重新初始化系统吗?(y/n): ")
                    if confirm.lower() == 'y':
                        self.initialize(force_rebuild=True)
                    continue
                
                elif not question:
                    continue
                
                # 执行查询
                self.query(question, verbose=True)
                
            except KeyboardInterrupt:
                print("

程序被中断")
                break
            except Exception as e:
                print(f"错误: {str(e)}")

def main():
    """主函数"""
    parser = argparse.ArgumentParser(description="RAG 问答系统")
    parser.add_argument("--mode", choices=["interactive", "single"], 
                       default="interactive", help="运行模式")
    parser.add_argument("--question", type=str, help="单个问题(single 模式使用)")
    parser.add_argument("--documents", type=str, nargs="+", 
                       help="文档路径列表")
    parser.add_argument("--rebuild", action="store_true", 
                       help="强制重建向量存储")
    parser.add_argument("--config", type=str, default=".env", 
                       help="配置文件路径")
    
    args = parser.parse_args()
    
    # 创建应用
    app = RAGApplication(args.config)
    
    # 初始化
    app.initialize(
        document_paths=args.documents,
        force_rebuild=args.rebuild
    )
    
    # 运行模式
    if args.mode == "interactive":
        app.interactive_mode()
    elif args.mode == "single" and args.question:
        result = app.query(args.question)
        print(f"
最终回答: {result['answer']}")
    else:
        parser.print_help()

if __name__ == "__main__":
    main()

4 关键技术深度解析

4.1 文本分块策略详解

文本分块是 RAG 系统的关键环节,不同的分块策略会影响检索质量:

# chunking_strategies.py - 分块策略比较
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    CharacterTextSplitter,
    TokenTextSplitter,
    MarkdownHeaderTextSplitter,
    SentenceTransformersTokenTextSplitter
)

class ChunkingStrategies:
    """分块策略比较"""
    
    @staticmethod
    def compare_strategies(text: str):
        """比较不同的分块策略"""
        
        strategies = {
            "递归字符分块": RecursiveCharacterTextSplitter(
                chunk_size=500,
                chunk_overlap=50,
                separators=["

", "
", "。", "!", "?", ";", ",", " ", ""]
            ),
            "字符分块": CharacterTextSplitter(
                chunk_size=500,
                chunk_overlap=50,
                separator="
"
            ),
            "Token分块": TokenTextSplitter(
                chunk_size=500,
                chunk_overlap=50
            ),
        }
        
        results = {}
        for name, splitter in strategies.items():
            chunks = splitter.split_text(text)
            results[name] = {
                "chunk_count": len(chunks),
                "avg_length": sum(len(c) for c in chunks) / len(chunks) if chunks else 0,
                "sample_chunk": chunks[0][:100] + "..." if chunks else ""
            }
        
        return results

4.2 检索增强技术

高级检索技术可以显著提升 RAG 性能:

# advanced_retrieval.py - 高级检索技术
from typing import List, Dict
from langchain.retrievers import (
    ContextualCompressionRetriever,
    EnsembleRetriever,
    MultiQueryRetriever
)
from langchain.retrievers.document_compressors import LLMChainExtractor

class AdvancedRetrieval:
    """高级检索技术"""
    
    def __init__(self, base_retriever, llm):
        self.base_retriever = base_retriever
        self.llm = llm
    
    def create_compression_retriever(self):
        """创建上下文压缩检索器"""
        compressor = LLMChainExtractor.from_llm(self.llm)
        return ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=self.base_retriever
        )
    
    def create_multi_query_retriever(self):
        """创建多查询检索器"""
        return MultiQueryRetriever.from_llm(
            retriever=self.base_retriever,
            llm=self.llm
        )

5 性能优化与监控

5.1 性能监控装饰器

# performance_monitor.py - 性能监控
import time
import functools
from typing import Callable, Any
import psutil
import os

def monitor_performance(func: Callable) -> Callable:
    """性能监控装饰器"""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # 记录开始时间
        start_time = time.time()
        
        # 记录开始时的内存使用
        process = psutil.Process(os.getpid())
        start_memory = process.memory_info().rss / 1024 / 1024  # MB
        
        try:
            result = func(*args, **kwargs)
            
            # 计算执行时间
            execution_time = time.time() - start_time
            
            # 计算内存使用
            end_memory = process.memory_info().rss / 1024 / 1024
            memory_used = end_memory - start_memory
            
            print(f"性能统计 - {func.__name__}:")
            print(f"  执行时间: {execution_time:.2f} 秒")
            print(f"  内存使用: {memory_used:.2f} MB")
            
            return result
            
        except Exception as e:
            execution_time = time.time() - start_time
            print(f"错误 - {func.__name__}: {str(e)} (耗时: {execution_time:.2f}s)")
            raise
    
    return wrapper

6 总结与最佳实践

6.1 技术要点总结

通过本文的完整实现,我们构建了一个企业级的 RAG 系统,涵盖以下关键技术:

  1. 模块化架构设计:清晰的文档处理、向量存储、问答链分离
  2. 灵活的分块策略:支持多种分块算法,适应不同文档类型
  3. 多向量数据库支持:集成 Chroma 和 FAISS,各有优势
  4. 智能检索优化:实现类似度检索、多查询扩展等技术
  5. 完整的应用集成:提供 CLI 和交互式界面

6.2 性能优化提议

根据实践经验,以下优化措施可以显著提升 RAG 系统性能:

优化方向

具体措施

预期效果

分块优化

根据文档类型动态调整分块大小和重叠

提升 20-30% 检索准确率

检索优化

使用多查询扩展 + 重排序

提升 15-25% 回答质量

缓存策略

实现查询结果缓存和向量存储缓存

减少 40-60% 响应时间

并行处理

批量文档处理的并行化

提升 3-5 倍处理速度

6.3 扩展方向

对于中高级开发者,可以思考以下扩展方向:

  1. 多模态 RAG:支持图像、表格等非文本内容
  2. 实时更新:实现增量学习和实时知识库更新
  3. 混合检索:结合语义检索和关键词检索
  4. 智能路由:根据问题类型选择不同的检索策略
  5. 联邦学习:在保护隐私的前提下利用多方知识

6.4 关键代码片段回顾

以下是 RAG 系统的核心代码逻辑总结:

# RAG 系统核心逻辑
class RAGSystem:
    def __init__(self):
        self.pipeline = self._build_pipeline()
    
    def _build_pipeline(self):
        """构建处理流水线"""
        return (
            # 1. 文档加载与处理
            self.load_documents
            | self.split_documents
            # 2. 向量化与存储
            | self.create_embeddings
            | self.store_in_vector_db
            # 3. 检索与生成
            | self.retrieve_relevant_chunks
            | self.generate_answer
        )
    
    def query(self, question: str):
        """端到端查询"""
        return self.pipeline(question)

6.5 部署提议

对于生产环境部署,提议:

  1. 容器化部署:使用 Docker 封装整个应用
  2. 水平扩展:对检索服务和生成服务分别扩展
  3. 监控告警:实现全面的性能监控和异常告警
  4. A/B 测试:对比不同检索策略和模型的效果
  5. 成本控制:监控 API 使用成本,设置使用上限

本文提供的 RAG 实现方案具有高度的可扩展性和可定制性,开发者可以根据具体业务需求调整和优化各个组件。通过深入理解每个模块的工作原理,中高级开发者可以构建出适合自己业务场景的高效 RAG 系统。

© 版权声明
THE END
如果内容对您有所帮助,就支持一下吧!
点赞0 分享
评论 抢沙发

请登录后发表评论

    暂无评论内容