Skip to content

功能说明

  1. 文档上传与处理
    • 支持PDF/TXT格式
    • 自动处理多文件
    • 可配置分块参数
  2. 检索增强功能
    • 使用FAISS向量数据库
    • Hugging Face句子嵌入
    • 可配置检索结果数量
  3. 生成功能
    • 使用FLAN-T5基础模型
    • 可调节生成温度
    • 结合检索上下文生成答案
  4. 界面功能
    • 实时问答交互
    • 显示参考文档片段
    • 参数侧边栏配置
    • 处理状态提示

image-20250410221459431

rag_app.py

重页面轻功能

python
import streamlit as st
from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import os

# 初始化模型
@st.cache_resource
def load_models():
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    
    # 使用一个小型的开源生成模型(示例使用FLAN-T5)
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=512,
        temperature=0.3
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    
    return embeddings, llm

def main():
    st.title("RAG 问答系统 💬")
    
    # 初始化session state
    if "vector_store" not in st.session_state:
        st.session_state.vector_store = None
    
    # 侧边栏设置
    with st.sidebar:
        st.header("配置")
        uploaded_files = st.file_uploader(
            "上传文档(PDF/TXT)",
            type=["pdf", "txt"],
            accept_multiple_files=True
        )
        
        chunk_size = st.number_input("分块大小", 100, 2000, 500)
        chunk_overlap = st.number_input("分块重叠", 0, 200, 50)
        search_k = st.slider("检索结果数量", 1, 5, 3)
        temperature = st.slider("生成温度", 0.0, 1.0, 0.3)
        
        process_button = st.button("处理文档")
    
    # 模型加载
    embeddings, llm = load_models()
    
    # 文档处理
    if process_button and uploaded_files:
        with st.spinner("处理文档中..."):
            docs = []
            for file in uploaded_files:
                # 保存临时文件
                temp_path = f"temp_{file.name}"
                with open(temp_path, "wb") as f:
                    f.write(file.getbuffer())
                
                # 根据文件类型加载
                if file.name.endswith(".pdf"):
                    loader = PyPDFLoader(temp_path)
                elif file.name.endswith(".txt"):
                    loader = TextLoader(temp_path)
                else:
                    st.error("不支持的格式")
                    continue
                
                docs.extend(loader.load())
                os.remove(temp_path)  # 删除临时文件
            
            # 文本分割
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap
            )
            chunks = text_splitter.split_documents(docs)
            
            # 创建向量存储
            if st.session_state.vector_store is None:
                st.session_state.vector_store = FAISS.from_documents(
                    chunks, embeddings
                )
            else:
                st.session_state.vector_store.add_documents(chunks)
            
            st.success(f"已处理 {len(docs)} 个文档,生成 {len(chunks)} 个文本块")
    
    # 问答界面
    query = st.text_input("输入问题:")
    if query and st.session_state.vector_store:
        with st.spinner("搜索中..."):
            # 检索相关文档
            docs = st.session_state.vector_store.similarity_search(
                query, k=search_k
            )
            
            # 构建上下文
            context = "\n".join([d.page_content for d in docs])
            
            # 生成回答
            prompt = f"""
            根据以下上下文回答问题:
            {context}
            
            问题:{query}
            答案:
            """
            
            answer = llm(prompt, temperature=temperature)
            
            # 显示结果
            st.subheader("回答:")
            st.write(answer)
            
            with st.expander("查看参考文档"):
                for i, doc in enumerate(docs, 1):
                    st.write(f"**文档片段 {i}**")
                    st.write(doc.page_content)
                    st.divider()

if __name__ == "__main__":
    main()

运行项目

shell
streamlit run rag_app.py

模型升级

python
# 更换为更强大的模型
model_name = "google/flan-t5-xxl"
# 或者使用OpenAI API
# from langchain.llms import OpenAI
# llm = OpenAI(temperature=0.3)

混合检索

python
from langchain.retrievers import BM25Retriever, EnsembleRetriever

对话历史

python
# 在session_state中保存对话历史
if "history" not in st.session_state:
    st.session_state.history = []

结果缓存

python
@st.cache_data
def get_answer(query):
    # 缓存常见问题答案
    return processed_answer

高级分块策略

python
# 使用语义分块
from langchain_experimental.text_splitter import SemanticChunker