Implementing Retrieval-Augmented Generation with LangChain
Understanding Retrieval-Augmented Generation (RAG)
RAG is a methodology introduced to address knowledge-intensive NLP tasks. It merges two distinct forms of memory: parametric memory, encoded within pre-trained models like large language models (LLMs), and non-parametric memory, which consists of external, searchable information stored as vector representations. This hybrid approach enables models to access and incorporate specific, up-to-date information beyond their pre-trained knowledge base, significantly improving accuracy for tasks like question answering.
Addressing LLM Shortcomings with RAG
Large language models excel in general conversation but face notable limitations:
- Generating plausible but incorrect information, often called "hallucinations".
- Providing insufficient or outdated answers in specialized or rapidly evolving domains due to training data constraints.
- Exhibiting inconsistency in responses to identical queries.
- Lacking awareness of recent events or private knowledge.
RAG mitigates these issues by dynamically retrieving relevant context from an external knowledge source, analogous to providing a new employee with access to internal documentation. The model leverages its inherent generation capabilities alongside this retrieved context to formulate accurate and current responses.
Building RAG Systems with LangChain
LangChain provides a structured framework for implementing RAG workflows.
RAG Architecture Overview
A typical RAG pipeline consists of two main phases: offline indexing and online retrieval-generation.
Offline Indexing Phase
- Data Loading: Extract textual content from various sources (PDFs, Word docs, databases).
- Chunking: Split the extracted text into smaller, manageable segments (chunks). This granularity aids in precise retrieval and reduces context length.
- Embedding: Generate vector representations (embeddings) for each text chunk using an embedding model (see the short sketch after this list).
- Vector Storage: Index and store these embeddings in a dedicated vector database for efficient similarity search. This process creates the non-parametric, vector-based memory component.
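As a quick illustration of the embedding step, an embedding model maps a piece of text to a fixed-length numeric vector; the model path below is a placeholder, not a file referenced by this article:
# Illustrative sketch: embed a single text snippet and inspect the vector's length
from langchain_community.embeddings import LlamaCppEmbeddings
embedding_model = LlamaCppEmbeddings(model_path="/path/to/embedding/model.gguf")
vector = embedding_model.embed_query("Sun Wukong is also known as the Monkey King.")
print(len(vector))  # dimensionality of the embedding vector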
Online Retrieval-Generation Phase
When a user query is received:
- The query is converted into an embedding vector.
- This query embedding is used to search the vector database for the most semantically relevant text chunks (context).
- The retrieved context and the original query are combined into a structured prompt (a rough sketch of this step follows the list).
- The LLM processes this augmented prompt to generate the final answer.
Beyond improving answer quality, this architecture offers an engineering benefit: by reducing the amount of context the model must process, it lowers token usage and computational cost.
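As a rough sketch of the augmentation step, independent of the LangChain implementation shown later (the function name and prompt wording are illustrative, not prescribed by this article):
# Illustrative only: merge retrieved chunks and the user query into one prompt string
def build_augmented_prompt(question, retrieved_chunks):
    context = "\n\n".join(retrieved_chunks)
    return (
        "Answer the question using only the context below.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}"
    )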
Prototyping a RAG System with LangChain
First, consider the baseline performance of an LLM without RAG.
from langchain_community.llms import LlamaCpp
model_path = "/path/to/model/mixtral-8x7b-instruct-v0.1.Q8_0.gguf"
llm = LlamaCpp(model_path=model_path)
response = llm.invoke("How many times did Sun Wukong battle the White Bone Demon?")
print(response)
The model's response may be inaccurate or fail to directly answer the question, highlighting the need for external grounding.
Implementing RAG requires a vector store (e.g., Chroma) and an embedding model.
Offline Indexing Implementation:
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings import LlamaCppEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
# Load documents
loader = DirectoryLoader('/path/to/knowledge/files', glob="**/*.txt")
documents = loader.load()
# Split documents into chunks
text_splitter = RecursiveCharacterTextSplitter(chunk_size=256, chunk_overlap=200)
chunks = text_splitter.split_documents(documents)
# Create embeddings and vector store
embedding_model = LlamaCppEmbeddings(model_path=model_path)
vector_db = Chroma.from_documents(documents=chunks, embedding=embedding_model)
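If the index needs to survive across runs, Chroma can also write to disk; a minimal variant of the call above, with a placeholder directory, might look like this:
# Optional: persist the index to disk so it can be reloaded later (directory is a placeholder)
vector_db = Chroma.from_documents(
    documents=chunks,
    embedding=embedding_model,
    persist_directory="/path/to/chroma_db",
)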
Online Retrieval and Generation Implementation:
import os
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
# Configure LangSmith for tracing (optional)
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_API_KEY"] = "your_api_key"
# Create a retriever from the vector store
retriever = vector_db.as_retriever()
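# Optionally limit how many chunks are retrieved per query (the value of k is illustrative):
# retriever = vector_db.as_retriever(search_kwargs={"k": 4})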
# Fetch a pre-defined RAG prompt template
rag_prompt = hub.pull("rlm/rag-prompt")
# Combine retrieved documents into a single context string
def combine_context(retrieved_docs):
    return "\n\n".join([doc.page_content for doc in retrieved_docs])
# Construct the RAG chain
rag_chain = (
    {"context": retriever | combine_context, "question": RunnablePassthrough()}
    | rag_prompt
    | llm
    | StrOutputParser()
)
# Invoke the chain
answer = rag_chain.invoke("How many times did Sun Wukong battle the White Bone Demon?")
print(answer)
The quality of the answer depends on the relevance of the indexed knowledge. Observability tools like LangSmith provide valuable insights into the retrieval and generation steps.
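To check which chunks are actually being pulled in for a given query, the retriever built above can also be invoked directly; this is a small diagnostic sketch, not part of the chain itself:
# Inspect the retrieved chunks before they reach the LLM
retrieved_docs = retriever.invoke("How many times did Sun Wukong battle the White Bone Demon?")
for doc in retrieved_docs:
    print(doc.page_content[:100])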
Integrating a retriever enhances LLM capabilities but introduces new design considerations:
- Selecting appropriate external knowledge sources for the specific domain.
- Preprocessing data effectively, including optimal chunking strategies and size selection.
- Designing effective prompt templates that efficiently utilize both the retrieved context and the LLM's generative power, while managing context length.
These challenges require domain-specific experimentation and iterative refinement based on performance feedback.
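As a starting point for such experimentation, one simple check is how different chunk sizes affect the number of chunks produced from the same documents; the sizes below are arbitrary examples, not recommendations from this article:
# Compare chunk counts for a few candidate chunk sizes (values are illustrative)
from langchain_text_splitters import RecursiveCharacterTextSplitter
for size in (128, 256, 512):
    splitter = RecursiveCharacterTextSplitter(chunk_size=size, chunk_overlap=size // 10)
    print(size, len(splitter.split_documents(documents)))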