What is RAG?
RAG, or Retrieval-Augmented Generation, enhances AI models by combining a retrieval mechanism with a generative model. Instead of relying solely on knowledge learned during training, the model fetches relevant external data at query time and uses it to ground its output, producing responses that are more accurate and contextually enriched. This hybrid approach draws on a database of pre-existing knowledge, and its ability to tap that repository in real time makes RAG an invaluable tool for generating reliable, informed responses.
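Conceptually, the pipeline is: score each stored document against the query, keep the best matches, and hand them to the generator as context. A dependency-free sketch of that idea, with a toy word-overlap score standing in for real vector similarity (the corpus and prompt template here are illustrative assumptions, not part of the implementation below):

```python
def score(query, doc):
    # Toy relevance score: fraction of query words appearing in the document.
    # A real RAG system compares embedding vectors instead.
    q_words = set(query.lower().split())
    d_words = set(doc.lower().split())
    return len(q_words & d_words) / max(len(q_words), 1)

def retrieve(query, corpus, top_k=2):
    # Rank documents by relevance and keep the top_k best matches.
    ranked = sorted(corpus, key=lambda d: score(query, d), reverse=True)
    return ranked[:top_k]

corpus = [
    "Python is a popular programming language.",
    "ChromaDB is a vector database for embeddings.",
    "RAG combines retrieval with text generation.",
]
context = retrieve("what is a vector database", corpus)
# The retrieved context is prepended to the question for the generative model.
prompt = "Context:\n" + "\n".join(context) + "\n\nQuestion: what is a vector database"
```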
Why Local RAG in Core Python Matters
Implementing RAG locally in core Python offers several significant advantages:
- Total Control: By building RAG systems from the ground up using core Python, developers gain complete control over every aspect of the implementation. This level of control allows for fine-tuning and customization that's not always possible with pre-built solutions.
- Explainability: A local implementation enables developers to understand and explain every step of the RAG process. This transparency is crucial for debugging, optimization, and building trust in AI systems, especially in sensitive applications.
- Privacy and Security: Local implementations can process data without sending it to external services, ensuring data privacy and compliance with security regulations.
- Cost-Effectiveness: By leveraging local resources and open-source libraries, developers can create powerful RAG systems without relying on potentially expensive cloud-based services.
- Learning and Innovation: Building RAG from scratch in Python provides invaluable insights into the inner workings of these systems, fostering innovation and enabling developers to experiment with novel approaches.
The Code Structure
Let's dive into the key components of implementing RAG locally in core Python, using the provided code as a foundation.
Initializing the Vector DB
import os
import chromadb
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE, Settings
class Raghelper():
def __init__(self, path):
self.DB_PATH = path
os.makedirs(self.DB_PATH, exist_ok=True)
self.client = chromadb.PersistentClient(
path=self.DB_PATH,
settings=Settings(
allow_reset=True,
is_persistent=True
),
tenant=DEFAULT_TENANT,
database=DEFAULT_DATABASE,
)
        self.collection = self.client.get_or_create_collection(name=self.collection_name, embedding_function=self.default_ef)

In this section, we initialize ChromaDB, a vector database crucial for efficient similarity searches. The PersistentClient ensures that our data persists across sessions, providing a robust foundation for our RAG system. We create a collection within the database, which will store our document embeddings.
This initialization stores the database at the specified path, and the persistent client guarantees the data remains available even after the program terminates.
Defining the Chunk Size
def split_documents(self, document, chunk_size="m"):
sentences = sent_tokenize(document)
chunk_sizes = {"s": 30, "m": 50, "l": 75, "xl": 100, "xxl": 130, "all": float('inf')}
max_length = chunk_sizes.get(chunk_size, 30)
    # ... (chunking logic)

Chunk size plays a crucial role in RAG systems. It determines how documents are split before being embedded and stored in the vector database. The choice of chunk size affects:
- Retrieval Granularity: Smaller chunks allow for more precise retrieval but may lack context, while larger chunks provide more context but might introduce noise.
- Processing Efficiency: Smaller chunks are faster to process and embed but result in more database entries.
- Query Relevance: The chunk size should align with the expected query length and complexity to ensure relevant retrievals.
Our implementation offers flexible chunk sizes, allowing users to optimize for their specific use case.
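The greedy, sentence-level packing that split_documents performs can be sketched in plain Python. Here a naive period-based splitter stands in for NLTK's sent_tokenize, so the snippet runs without extra downloads:

```python
import re

def split_into_chunks(document, max_words=50):
    # Naive sentence split on ., !, or ? followed by whitespace
    # (NLTK's sent_tokenize is more robust; this keeps the sketch dependency-free).
    sentences = re.split(r'(?<=[.!?])\s+', document.strip())
    chunks, current, current_len = [], [], 0
    for sentence in sentences:
        n = len(sentence.split())
        if current and current_len + n > max_words:
            # Current chunk is full: flush it and start a new one.
            chunks.append(" ".join(current))
            current, current_len = [], 0
        current.append(sentence)
        current_len += n
    if current:
        chunks.append(" ".join(current))
    return chunks
```

With max_words=50 (the "m" setting in the mapping above), each chunk holds as many whole sentences as fit in roughly 50 words; no sentence is ever split across chunks, which preserves local context for the embedding step.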
Converting Text to Vectors
default_ef = embedding_functions.DefaultEmbeddingFunction()
def add_chunk(self, chunk, index, fname, meta_data):
        self.collection.add(documents=[chunk], ids=[f"{index}_{fname}"], metadatas=[meta_data])

The conversion of text to vectors is handled by ChromaDB's embedding function. We use the DefaultEmbeddingFunction, which employs a pre-trained model to convert text into high-dimensional vectors. These vectors capture the semantic meaning of the text, enabling efficient similarity searches.
When adding chunks to our collection, each chunk is automatically converted to a vector and stored alongside metadata, creating a searchable knowledge base for our RAG system.
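Under the hood, similarity search reduces to comparing vectors. A minimal sketch with cosine similarity; the three-dimensional "embeddings" here are made up for illustration, whereas real embedding models produce hundreds of dimensions:

```python
import math

def cosine_similarity(a, b):
    # Cosine of the angle between two vectors: 1.0 = same direction, 0.0 = orthogonal.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

# Toy "embeddings": semantically similar texts get nearby vectors.
store = {
    "the cat sat on the mat": [0.9, 0.1, 0.2],
    "a feline rested on a rug": [0.85, 0.15, 0.25],
    "stock prices fell sharply": [0.1, 0.9, 0.3],
}
query_vec = [0.88, 0.12, 0.22]  # pretend embedding of "cat on a mat"
best = max(store, key=lambda text: cosine_similarity(store[text], query_vec))
```

This is why vector search finds the cat sentences and not the finance one: similarity is measured in embedding space, not by shared keywords.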
Retrieving Based on Vector Queries and Distance Queries
Retrieval in RAG means querying the vector database for the pieces of information most relevant to the input, ranked by a vector similarity measure. The sortDocumentsByDistance function plays a crucial role in this process, filtering and sorting retrieved documents by their proximity to the query vector so that generated responses stay closely aligned with the input query.
def sortDocumentsByDistance(self, data, threshold=0.9):
ids = data['ids'][0]
distances = data['distances'][0]
metadatas = data['metadatas'][0]
documents = data['documents'][0]
# Zip the data together for sorting and filtering
zipped_data = list(zip(distances, ids, metadatas, documents))
filtered_data = [item for item in zipped_data if item[0] <= threshold]
sorted_data = sorted(filtered_data, key=lambda x: x[0])
if sorted_data:
sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*sorted_data)
data['ids'][0] = list(sorted_ids)
data['distances'][0] = list(sorted_distances)
data['metadatas'][0] = list(sorted_metadatas)
            data['documents'][0] = list(sorted_documents)

This function keeps only documents whose distance score is at or below a specified threshold, discarding irrelevant results. Sorting the survivors by distance lets developers prioritize the responses that are most contextually appropriate.
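The same filter-and-sort step can be exercised in isolation with a mocked query result (the nested-list dict shape mirrors what a ChromaDB query returns for a single query text; the ids, distances, and documents below are made up):

```python
def sort_by_distance(data, threshold=0.9):
    # Keep only entries at or below the distance threshold, then sort ascending
    # (smaller distance = more similar to the query).
    zipped = zip(data['distances'][0], data['ids'][0], data['metadatas'][0], data['documents'][0])
    kept = sorted((row for row in zipped if row[0] <= threshold), key=lambda r: r[0])
    dists, ids, metas, docs = zip(*kept) if kept else ((), (), (), ())
    return {'distances': [list(dists)], 'ids': [list(ids)],
            'metadatas': [list(metas)], 'documents': [list(docs)]}

# Mocked ChromaDB-style result: one query, three hits.
mock = {
    'ids': [['a', 'b', 'c']],
    'distances': [[1.2, 0.3, 0.7]],
    'metadatas': [[{'f': 1}, {'f': 2}, {'f': 3}]],
    'documents': [['far doc', 'close doc', 'middling doc']],
}
result = sort_by_distance(mock, threshold=0.9)
# 'far doc' (distance 1.2) is dropped; the rest come back sorted nearest-first.
```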
Retrieving for Context
The queryCollection function is the heart of the retrieval process in this RAG implementation. It allows querying both the main text collection and an image collection, providing a comprehensive retrieval mechanism.
def queryCollection(self, query, token_size=4000, search_string=""):
# print("Query", query)
docs = []
try:
n_results = int(self.collection.count() * (0.95))
if search_string == "":
results = self.collection.query(
query_texts=[query],
n_results=n_results,
include=["documents", "distances", "metadatas"],
)
else:
results = self.collection.query(
query_texts=[query],
where_document={"$contains": search_string},
include=["documents", "distances", "metadatas"],
)
results = self.sortDocumentsByDistance(results, threshold=2)
total_tokens = 0
for doc in results["documents"][0]:
count_tokens = self.count_tokens(doc)
if total_tokens + count_tokens > token_size:
break
total_tokens += count_tokens
docs.extend(self.vector_to_sentences(doc))
except Exception as e:
print("Exception in Query", e)
image_results=[]
try:
n_results = int(self.image_collection.count()*0.9)
            # print("Results to get", n_results)
image_data = self.image_collection.query(
query_texts=[query],
n_results=n_results, # Adjust this number as needed
include=["documents","distances", "metadatas"]
)
            dt_Ctr = 0
            # Iterate over the returned image documents, not the result dict's keys
            for doc in image_data["documents"][0]:
                desc = image_data["metadatas"][0][dt_Ctr]["desc"]
                image_results.append({"src": doc, "desc": desc})
                dt_Ctr = dt_Ctr + 1
                if dt_Ctr > 15:
                    break
except Exception as e:
print("error in getting images",e)
        return docs, image_results

The function begins by querying the main collection with the provided query text. It uses the query_texts parameter to search the database and retrieve the most relevant documents. If a search_string is provided, the query is further filtered to include only documents containing that string. The retrieved documents are then sorted by their distance from the query vector to prioritize the most relevant results.
After sorting, the documents are tokenized, and their total length is checked against the specified token_size limit. This ensures that the retrieved content does not exceed the desired length, which is particularly important for LLM inputs where context length is a concern.
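The token budgeting relies on a cheap heuristic rather than a real tokenizer: roughly 1.3 tokens per word plus one per punctuation mark. A self-contained sketch of both the estimate and the budget loop:

```python
import re

def count_tokens(text):
    text = text.strip()
    # ~1.3 tokens per word plus one per punctuation mark; a real tokenizer
    # (e.g. tiktoken) would be exact, but this heuristic avoids the dependency.
    words = re.findall(r'\w+', text.lower())
    punctuation = re.findall(r'[^\w\s]', text)
    return int(len(words) * 1.3) + len(punctuation)

def take_within_budget(documents, token_size):
    # Accumulate documents in order until adding one more would exceed the budget.
    kept, total = [], 0
    for doc in documents:
        n = count_tokens(doc)
        if total + n > token_size:
            break
        total += n
        kept.append(doc)
    return kept
```

Because the documents arrive sorted nearest-first, stopping at the budget keeps the most relevant chunks and drops the weakest matches.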
The function also queries an image collection, retrieving images related to the query and returning them alongside the text results. This dual retrieval approach makes the system versatile, capable of handling both text and image data.
Finally, the vector_to_sentences function is used to clean and tokenize the retrieved text, splitting it into individual sentences for easier processing and analysis. This function ensures that the final output is well-structured and ready for use in downstream applications.
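That cleanup step can be reproduced without NLTK: normalize whitespace, drop non-printable characters, then split on sentence boundaries (a regex split stands in for nltk.sent_tokenize here):

```python
import re

def to_sentences(text):
    # Collapse runs of whitespace (newlines, tabs) into single spaces.
    text = re.sub(r'\s+', ' ', text)
    # Drop non-printable characters left over from file extraction.
    text = ''.join(ch for ch in text if ch.isprintable())
    # Naive sentence split on ., !, or ? followed by a space.
    sentences = re.split(r'(?<=[.!?])\s+', text)
    return [s.strip() for s in sentences if s.strip()]
```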
The Source Code: Free for Use
Here is my source code, shared for free use. Please do comment and let me know if you spot flaws I can improve on, or better alternatives for any part of the process.
I have put this code live to help me build some crazy LLM-based applications. One of them produces well-structured, long-form (6000+ word) articles backed by an up-to-date knowledge base, but that's a story for another time. For now, feel free to use this code and play with local RAG.
The approach is to store text and images in separate collections. For images I store only the URLs and the ALT tags, and the ALT tag serves as the query string for finding relevant images.
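As a toy illustration of that split, each image record carries only a source URL plus the ALT text used for matching. The word-overlap match below is a stand-in for the embedding search ChromaDB performs, and the URLs and ALT strings are made up:

```python
def find_images(query, images, limit=2):
    # Score each image by how many query words appear in its ALT text,
    # then return the top matches as {src, desc} records.
    q = set(query.lower().split())
    scored = sorted(
        images,
        key=lambda img: len(q & set(img["alt"].lower().split())),
        reverse=True,
    )
    return [{"src": img["src"], "desc": img["alt"]} for img in scored[:limit]]

images = [
    {"src": "https://example.com/cat.png", "alt": "a cat sleeping on a sofa"},
    {"src": "https://example.com/chart.png", "alt": "stock price chart for 2024"},
]
```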
import json
import re
import time
import chromadb
import chromadb.utils.embedding_functions as embedding_functions
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE, Settings
from lib.filereader import FileReader
import nltk
from nltk.tokenize import sent_tokenize
# sent_tokenize requires the punkt tokenizer models to be downloaded once
nltk.download('punkt', quiet=True)
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
import tempfile
class Raghelper():
# Class variables for database path, table, and collection names
DB_PATH = ""
table_name = "rag_data"
collection_name = "rag_collection"
image_collection_name="rag_images"
# Default embedding function used for ChromaDB collections
default_ef = embedding_functions.DefaultEmbeddingFunction()
# ChromaDB client and collections
client = None
collection = None
image_collection=None
# Thread pool executor for handling asynchronous tasks
executor = None
def __init__(self, path):
"""Initialize the Raghelper instance with a database path and set up the ChromaDB client and collections."""
self.DB_PATH = path
# Ensure the specified directory exists
os.makedirs(self.DB_PATH, exist_ok=True)
self.client = chromadb.PersistentClient(
path=self.DB_PATH,
settings=Settings(
allow_reset=True,
is_persistent=True
),
tenant=DEFAULT_TENANT,
database=DEFAULT_DATABASE,
)
# Get or create the main collection and image collection
self.collection = self.client.get_or_create_collection(name=self.collection_name, embedding_function=self.default_ef)
self.image_collection = self.client.get_or_create_collection(
name=self.image_collection_name,
embedding_function=self.default_ef
)
# Initialize a thread pool with 4 worker threads
self.executor = ThreadPoolExecutor(max_workers=4)
@contextmanager
def get_client(self):
"""Context manager for safely obtaining and releasing the ChromaDB client."""
try:
if self.client is None:
self.client = chromadb.PersistentClient(
path=self.DB_PATH,
settings=Settings(
allow_reset=True,
is_persistent=True
),
tenant=DEFAULT_TENANT,
database=DEFAULT_DATABASE,
)
yield self.client
finally:
if self.client:
self.client.clear_system_cache()
self.client = None
def close(self):
"""Shutdown the executor and close all database connections."""
if self.executor:
self.executor.shutdown(wait=True)
self.close_all_connections()
def __del__(self):
"""Ensure cleanup by calling close method when the instance is deleted."""
self.close()
def deleteRag(self):
"""Delete the existing collections and reset the ChromaDB client."""
try:
self.client.delete_collection(self.collection_name)
self.client.delete_collection(self.image_collection_name)
self.client.reset()
self.client.clear_system_cache()
except Exception as e:
print("error deleting collection", e)
# Recreate the collections after deletion
self.collection = self.client.get_or_create_collection(name=self.collection_name, embedding_function=self.default_ef)
self.image_collection = self.client.get_or_create_collection(name=self.image_collection_name, embedding_function=self.default_ef)
def deleteFileFromRag(self, file_name):
"""Delete a specific file from the main collection."""
try:
self.collection.delete(where={"filename": file_name})
except Exception as e:
print("error deleting", file_name, e)
def add_images_to_collection(self, file_name, file_path):
"""Add images and their metadata to the image collection."""
images_data = []
try:
if isinstance(file_path, str):
full_path = os.path.join(file_path, file_name)
with open(full_path, 'r', encoding='utf-8') as file:
data = file.read()
images_data = json.loads(data)
images_data = images_data[:500] # Limit to first 500 images
docs = []
metadatas = []
ids = []
ctr = 1
for image in images_data:
docs.append(image["src"])
metadatas.append({"desc": image["alt"]})
ids.append("id_" + str(ctr))
ctr += 1
# Add the images to the collection
self.image_collection.add(documents=docs, metadatas=metadatas, ids=ids)
except Exception as e:
print("Error adding images", e)
def addFileToRag(self, file_name, file_path, chunk_size):
"""Process and add a file to the main collection in chunks."""
if self.isFileAddedinVectorDB(file_name):
return
filereader = FileReader()
content = ""
temp_file_created = False
temp_file_path = ""
try:
if isinstance(file_path, str):
full_path = os.path.join(file_path, file_name)
with open(full_path, 'r', encoding='utf-8') as file:
content = file.read()
else:
# Handle file-like objects by writing to a temporary file
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, file_name)
with open(temp_file_path, 'w', encoding='utf-8') as temp_file:
file_path.seek(0)
temp_file.write(file_path.read())
temp_file_created = True
with open(temp_file_path, 'r', encoding='utf-8') as file:
content = file.read()
content = filereader.clean_data(content)
chunks = self.split_documents(content, chunk_size=chunk_size)
meta_data = {"filename": file_name}
self.add_chunks_to_collection(chunks, file_name, meta_data)
except Exception as e:
print(f"Error processing file {file_name}: {e}")
finally:
if temp_file_created and temp_file_path:
try:
os.remove(temp_file_path)
except Exception as e:
print(f"Error removing temporary file {temp_file_path}: {e}")
def add_chunks_to_collection(self, chunks, fname, meta_data):
"""Add chunks of a document to the collection in batches."""
batch_size = 30 # Number of chunks to add at once
for i in range(0, len(chunks), batch_size):
batch = chunks[i:i + batch_size]
# Submit tasks to the executor for asynchronous processing
futures = {self.executor.submit(self.add_chunk, chunk, f"{j + i}", fname, meta_data): j for j, chunk in enumerate(batch)}
for future in as_completed(futures):
try:
future.result() # Ensure each future completes without error
except Exception as e:
print(f"Error in future: {e}")
time.sleep(0.1) # Small delay between batches to avoid overloading
def add_chunk(self, chunk, index, fname, meta_data):
"""Add a single chunk to the collection."""
try:
print(f"Adding chunk {index} for {fname}")
self.collection.add(documents=[chunk], ids=[f"{index}_{fname}"], metadatas=[meta_data])
except Exception as e:
print(f"Error adding chunk {index} for {fname}: {e}")
def isFileAddedinVectorDB(self, filename):
"""Check if a file has already been added to the collection."""
res = self.collection.get(where={"filename": filename}, include=["metadatas"])
file_exist = len(res["ids"]) > 0
print(filename, file_exist)
return file_exist
def split_documents(self, document, chunk_size="m"):
"""Split a document into chunks based on sentence tokenization."""
sentences = sent_tokenize(document)
# Define chunk sizes
chunk_sizes = {"s": 30, "m": 50, "l": 75, "xl": 100, "xxl": 130, "all": float('inf')}
max_length = chunk_sizes.get(chunk_size, 30)
if chunk_size == "all":
return [" ".join(sentences)]
segments = []
current_segment = []
current_length = 0
for sentence in sentences:
sentence_length = len(sentence.split())
if current_length + sentence_length > max_length:
segments.append(" ".join(current_segment))
current_segment = []
current_length = 0
current_segment.append(sentence)
current_length += sentence_length
if current_segment:
segments.append(" ".join(current_segment))
return segments
def sortDocumentsByDistance(self, data, threshold=0.9):
"""Sort and filter documents by their distance score."""
ids = data['ids'][0]
distances = data['distances'][0]
metadatas = data['metadatas'][0]
documents = data['documents'][0]
# Zip the data together for sorting and filtering
zipped_data = list(zip(distances, ids, metadatas, documents))
filtered_data = [item for item in zipped_data if item[0] <= threshold]
sorted_data = sorted(filtered_data, key=lambda x: x[0])
if sorted_data:
sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*sorted_data)
data['ids'][0] = list(sorted_ids)
data['distances'][0] = list(sorted_distances)
data['metadatas'][0] = list(sorted_metadatas)
data['documents'][0] = list(sorted_documents)
else:
data['ids'][0] = []
data['distances'][0] = []
data['metadatas'][0] = []
data['documents'][0] = []
return data
    def count_tokens(self, text):
        """Estimate token count: roughly 1.3 tokens per word plus one per punctuation mark."""
        text = text.strip()
words = re.findall(r'\w+', text.lower())
word_count = len(words)
punctuation_count = len(re.findall(r'[^\w\s]', text))
estimated_tokens = int(word_count * 1.3) + punctuation_count
return estimated_tokens
    def queryCollection(self, query, token_size=4000, search_string=""):
        """Query the text and image collections and return relevant sentences plus matching images."""
# print("Query", query)
docs = []
try:
n_results = int(self.collection.count() * (0.95))
if search_string == "":
results = self.collection.query(
query_texts=[query],
n_results=n_results,
include=["documents", "distances", "metadatas"],
)
else:
results = self.collection.query(
query_texts=[query],
where_document={"$contains": search_string},
include=["documents", "distances", "metadatas"],
)
results = self.sortDocumentsByDistance(results, threshold=2)
total_tokens = 0
for doc in results["documents"][0]:
count_tokens = self.count_tokens(doc)
if total_tokens + count_tokens > token_size:
break
total_tokens += count_tokens
docs.extend(self.vector_to_sentences(doc))
except Exception as e:
print("Exception in Query", e)
image_results=[]
try:
n_results = int(self.image_collection.count()*0.9)
            # print("Results to get", n_results)
image_data = self.image_collection.query(
query_texts=[query],
n_results=n_results, # Adjust this number as needed
include=["documents","distances", "metadatas"]
)
            dt_Ctr = 0
            # Iterate over the returned image documents, not the result dict's keys
            for doc in image_data["documents"][0]:
                desc = image_data["metadatas"][0][dt_Ctr]["desc"]
                image_results.append({"src": doc, "desc": desc})
                dt_Ctr = dt_Ctr + 1
                if dt_Ctr > 15:
                    break
except Exception as e:
print("error in getting images",e)
return docs,image_results
    def vector_to_sentences(self, text):
        """Normalize whitespace, drop non-printable characters, and split text into clean sentences."""
text = re.sub(r'\s+', ' ', text)
text = ''.join(char for char in text if char.isprintable())
sentences = nltk.sent_tokenize(text)
return [sent.strip() for sent in sentences if sent.strip()]
    def close_all_connections(self):
        """Clear the ChromaDB client's system cache to release database resources."""
if self.client:
self.client.clear_system_cache()