主题
功能说明
- 文档上传与处理:
  - 支持PDF/TXT格式
  - 自动处理多文件
  - 可配置分块参数
- 检索增强功能:
  - 使用FAISS向量数据库
  - Hugging Face句子嵌入
  - 可配置检索结果数量
- 生成功能:
  - 使用FLAN-T5基础模型
  - 可调节生成温度
  - 结合检索上下文生成答案
- 界面功能:
  - 实时问答交互
  - 显示参考文档片段
  - 参数侧边栏配置
  - 处理状态提示
rag_app.py
说明:本示例重页面、轻功能,侧重界面演示,功能实现从简。
python
import streamlit as st
from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import os
# 初始化模型
@st.cache_resource
def load_models():
    """Build the embedding model and the generation LLM used by the app.

    Decorated with ``st.cache_resource`` so the models are constructed once
    and reused instead of being reloaded on every Streamlit rerun.

    Returns:
        A ``(embeddings, llm)`` tuple: the sentence-embedding wrapper and a
        ``HuggingFacePipeline`` around a FLAN-T5 text2text pipeline.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2"
    )

    # A small open-source generation model (FLAN-T5 base as the example).
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=512,
        # Sampling must be enabled for ``temperature`` to have any effect;
        # with greedy decoding the value is ignored.
        do_sample=True,
        temperature=0.3,
    )
    return embeddings, HuggingFacePipeline(pipeline=pipe)
def _load_uploaded_file(file):
    """Load one uploaded PDF/TXT file into a list of documents.

    The upload is spooled to a uniquely named temporary file because the
    loaders take a filesystem path. The temporary file is always removed,
    even when loading fails.

    Args:
        file: An uploaded-file object exposing ``name`` and ``getbuffer()``.

    Returns:
        The documents produced by the loader, or an empty list when the
        extension is not supported (an error is shown in the UI).
    """
    import tempfile  # local import: only this helper needs it

    # Compare the extension case-insensitively so "REPORT.PDF" is accepted.
    suffix = os.path.splitext(file.name)[1].lower()
    if suffix == ".pdf":
        loader_cls = PyPDFLoader
    elif suffix == ".txt":
        loader_cls = TextLoader
    else:
        st.error("不支持的格式")
        return []

    # delete=False: the file is closed before the loader reopens it by path,
    # and removed explicitly in the ``finally`` below.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(file.getbuffer())
        temp_path = tmp.name
    try:
        loaded = loader_cls(temp_path).load()
    finally:
        os.remove(temp_path)

    # Record the original upload name instead of the throwaway temp path.
    for doc in loaded:
        doc.metadata["source"] = file.name
    return loaded


def main():
    """Render the RAG question-answering page.

    Flow: sidebar configuration -> optional document ingestion into the
    session's FAISS store -> question input -> retrieval -> generation ->
    display of the answer and the retrieved snippets.
    """
    st.title("RAG 问答系统 💬")

    # The vector store lives in session state so it survives reruns.
    if "vector_store" not in st.session_state:
        st.session_state.vector_store = None

    # Sidebar settings.
    with st.sidebar:
        st.header("配置")
        uploaded_files = st.file_uploader(
            "上传文档(PDF/TXT)",
            type=["pdf", "txt"],
            accept_multiple_files=True,
        )
        chunk_size = st.number_input("分块大小", 100, 2000, 500)
        chunk_overlap = st.number_input("分块重叠", 0, 200, 50)
        search_k = st.slider("检索结果数量", 1, 5, 3)
        temperature = st.slider("生成温度", 0.0, 1.0, 0.3)
        process_button = st.button("处理文档")

    embeddings, llm = load_models()

    # Document ingestion.
    if process_button and uploaded_files:
        with st.spinner("处理文档中..."):
            docs = []
            for file in uploaded_files:
                docs.extend(_load_uploaded_file(file))

            # The widgets allow overlap (up to 200) to exceed the chunk size
            # (from 100), which the splitter rejects; clamp it below the size.
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=min(chunk_overlap, chunk_size - 1),
            )
            chunks = text_splitter.split_documents(docs)

            # Skip indexing when nothing was extracted; building a FAISS
            # index from zero documents fails.
            if chunks:
                if st.session_state.vector_store is None:
                    st.session_state.vector_store = FAISS.from_documents(
                        chunks, embeddings
                    )
                else:
                    st.session_state.vector_store.add_documents(chunks)
        st.success(f"已处理 {len(docs)} 个文档,生成 {len(chunks)} 个文本块")

    # Question answering.
    query = st.text_input("输入问题:")
    if query and st.session_state.vector_store is not None:
        with st.spinner("搜索中..."):
            # Retrieve the most similar chunks and join them into a context.
            docs = st.session_state.vector_store.similarity_search(
                query, k=search_k
            )
            context = "\n".join(d.page_content for d in docs)
            prompt = (
                "\n根据以下上下文回答问题:\n"
                f"{context}\n"
                f"问题:{query}\n"
                "答案:\n"
            )

            # A temperature of 0 means greedy decoding; sampling with a
            # non-positive temperature is rejected by the generator.
            generation_kwargs = {"do_sample": temperature > 0}
            if temperature > 0:
                generation_kwargs["temperature"] = temperature
            # NOTE(review): forwarded through ``pipeline_kwargs``, which is
            # what HuggingFacePipeline reads per call in recent langchain
            # releases -- confirm against the installed version.
            answer = llm(prompt, pipeline_kwargs=generation_kwargs)

        st.subheader("回答:")
        st.write(answer)
        with st.expander("查看参考文档"):
            for i, doc in enumerate(docs, 1):
                st.write(f"**文档片段 {i}**")
                st.write(doc.page_content)
                st.divider()
# Script entry point: build the page only when this file is executed directly.
if __name__ == "__main__":
    main()
运行项目
shell
streamlit run rag_app.py
模型升级
python
# Switch to a more capable model
model_name = "google/flan-t5-xxl"
# Or use the OpenAI API instead
# from langchain.llms import OpenAI
# llm = OpenAI(temperature=0.3)
混合检索
python
from langchain.retrievers import BM25Retriever, EnsembleRetriever
对话历史
python
# Keep the conversation history in session_state; create it on first use.
if "history" not in st.session_state:
    st.session_state["history"] = []
结果缓存
python
@st.cache_data
def get_answer(query):
    """Return the answer for ``query``; results are cached by ``st.cache_data``.

    Intended for caching answers to frequently asked questions.

    NOTE(review): ``processed_answer`` is not defined in this snippet; it
    looks like a placeholder for the real answer-producing logic -- confirm.
    """
    return processed_answer
高级分块策略
python
# 使用语义分块
from langchain_experimental.text_splitter import SemanticChunker