RAG(Retrieval Augmented Generation)

RAG叫做检索增强生成系统,它通过知识库+大语言模型为用户提供信息。RAG系统能充分利用企业自身的数据,并结合LLM,因此适合特定知识的搜索,可用来作为智能客服,知识问答助手等等。它的基本构建方式分为两步:

  1. 知识库的构建,通常需要把企业数据或私有数据转存入向量数据库

  2. 根据用户问题,检索向量数据库,根据检索结果组合提示词并发给大模型,大模型生成回复

RAG系统具体实现

我之前在网上搜索了一些科技/IT等方面的新闻简报等,这些数据存在本地MySQL中,可以作为构建RAG系统的基础。

第一步,读取MySQL数据,并存入向量数据库ChromaDB;在这一步需要注意,不要使用默认的Embedding模型,而使用对中文支持比较好的一个小模型:BAAI/bge-small-zh-v1.5;并且需要解决文本块太大的问题。文本块拆分有多个策略,比如使用固定最大长度,或使用最大长度并在逗号或句号等地方分割。此处使用固定最大长度,为了防止信息丢失,块和块之间还有部分重叠字符,块大小是500字符,重叠50字符。具体代码如下:

import chromadb
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from load_mysql_data import query_data

import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
from appconfig import COLLECTION_NAME

# 加载 BGE-small-zh-v1.5 模型
model = SentenceTransformer('BAAI/bge-small-zh-v1.5')
client = chromadb.PersistentClient(path="./chroma_db")
text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,  # 每个块的最大字符数
        chunk_overlap=50,  # 块之间的重叠字符数
        length_function=len,
    )

def remove_collection():
    collections = client.list_collections()

    for collection in collections:
        print(f"Existing collection name: {collection.name}")
        if collection.name == COLLECTION_NAME:
            client.delete_collection(COLLECTION_NAME)
            break

# Run only once 
def save_data():
    results = query_data()
    if not results:
        print("No data found.")
        return 
        
    remove_collection() # remove existing collection which has the same name
    collection = client.get_or_create_collection(name=COLLECTION_NAME)

    documents = []
    document_ids = []
    embeddings = []
    i = 0
    while i < len(results):
        row = results[i]
        doc_id = str(row['id'])
        title = row['title']
        pub_date = row['pub_date']
        category = row['category'] if row['category'] else "N/A"
        content = row['content']
        
        # 拆分content 
        chunks = text_splitter.split_text(content)
        chunk_embeddings = model.encode(chunks).tolist()
        chunk_ids = [f"{doc_id}_chunk_{j}" for j in range(len(chunks))]

        documents.extend(chunks)
        document_ids.extend(chunk_ids) 
        embeddings.extend(chunk_embeddings)

        i += 1
        if i % 1000 == 0:
            print(f'Saving data {i} ...')
            collection.add(
                documents=documents,
                embeddings=embeddings,
                ids=document_ids
            )
            documents = []
            document_ids = []
            embeddings = []

    if i % 1000 > 0: 
        collection.add(
            documents=documents,
            embeddings=embeddings,
            ids=document_ids
        )

    print("数据已成功保存到Chroma DB!") 

if __name__ == "__main__":
    save_data()

第二步,结合用户的问题,进行向量搜索,并对结果进行重排。数据向量化之后,可以直接读取,以后只对增量数据写入即可。向量的搜索是一个高维空间的搜索,下面简化为二维表示。其中白色小点表示向量数据库中的数据,绿色小点是用户输入,向量搜索是我们可以搜索离绿色小点最近的n条数据,比如下图中的1~5号数据。

image

这之中,我们虽然可以找到各数据到用户问题的距离,但为了再准确一些,可以使用模型重新排序,通常做法是使用向量搜索更多条数据,重排后选择小部分数据发给大模型。这里使用的重排模型是BAAI/bge-reranker-base,具体代码如下:

import chromadb
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
from appconfig import COLLECTION_NAME

# 加载 BGE-small-zh-v1.5 模型
model = SentenceTransformer('BAAI/bge-small-zh-v1.5')
reranker_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-base')
reranker_model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-base') 
client = chromadb.PersistentClient(path="./chroma_db") 

def query_data(query_text): 
    collection = client.get_or_create_collection(name=COLLECTION_NAME)
    print(f"Chroma DB documents count: {collection.count()}")

    query_embedding = model.encode([query_text]).tolist()
    
    # 在集合中搜索相似文档
    results = collection.query(
        query_embeddings=query_embedding,
        n_results=20
    )

    pairs = []
    docs = results['documents'][0]
    for doc in docs:
        pairs.append([query_text, doc])

    # Rerank using bge-reranker
    with torch.no_grad():
        inputs = reranker_tokenizer(
            pairs,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=512
        )
        scores = reranker_model(**inputs).logits.squeeze()
        scores = torch.sigmoid(scores)  # Convert to probabilities


    # from 20 rows and picked the top 5 for GenAI   
    reranked_results = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)

    return reranked_results[:5]

第三步,把选中的数据发给大模型,基于提示词模版,让它返回一个回答。

import google.generativeai as genai
import os

def create_model(): 
    try:
        genai.configure(api_key=os.environ["GEMINI_API_KEY"])
    except KeyError:
        print("Error: The GEMINI_API_KEY environment variable is not set.")
        exit()

    # Configuration for the model
    generation_config = {
        "temperature": 0.7,
        "top_p": 0.95,
        "top_k": 64,
        "max_output_tokens": 8192,
    }

    # Initialize the Generative Model
    model = genai.GenerativeModel(
        model_name="gemini-2.5-flash",
        generation_config=generation_config,
    )

    return model 
 

def get_prompt(question, text_list):
    context = "\n".join([f"[文本{i+1}]: [{text}]" for i, text in enumerate(text_list)])

    return f"""
    你是一个严谨的信息分析员。请根据下面提供的五段原始文本,综合并回答用户的问题。

    请严格遵守以下指令:
    1. 你的回答必须完全依据提供的原始文本。
    2. 在生成答案时,你需要对关键信息进行引用,并在句末使用类似 [来自文本1]、[来自文本3, 文本4] 的格式注明来源。
    3. 综合所有相关信息,形成一个流畅、有逻辑的段落,而不是一个简单的列表。
    4. 如果没有任何一段文本能支持回答,请直接回复:“根据提供的文本,我无法回答这个问题。”

    [原始文本]
    {context}

    [用户的问题]
    [{question}] 
    """ 

最后,给系统加上UI,可以跟用户交互。这里使用的是streamlit,就不贴了,完整代码请戳这里

image