Initial commit

Alex Selimov 2025-03-21 10:09:07 -04:00
commit e0d0962fc5
20 changed files with 5731 additions and 0 deletions

6
README.md Normal file

@@ -0,0 +1,6 @@
# code_rag
This is intended to be a fully local Retrieval Augmented Generation (RAG) tool for software development.
It will be loadable as a neovim plugin that retrieves the current workspace files and adds them to a persistent Chroma vector store.
The tool will use the Ollama API to query locally running Ollama models with prompts augmented by relevant context from the current code.
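
As a rough sketch of the intended workflow, the RAG class added in this commit can be driven as below; the directory paths and the model name are placeholder assumptions, not settings shipped here.

# Minimal usage sketch; paths and model name are assumed values.
from code_rag.rag import RAG

rag = RAG(
    docs_dir="/path/to/workspace",         # assumed workspace to index
    db_dir="/path/to/chroma_db",           # assumed persistent Chroma directory
    tracker_file="/path/to/tracker.json",  # assumed document tracker state file
)
rag.create_vector_db()                     # indexes new/modified files only
rag_chain = rag.setup_rag(model_name="llama3")
print(rag.query_rag(rag_chain, "How is the vector store kept up to date?"))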

97
document_tracker.json Normal file

@@ -0,0 +1,97 @@
{
"/tmp/tmpje6soo0x/documents/python_info.txt": {
"chunk_ids": [
"c0b04224-f72e-4f6e-bef1-067f1cf5c716"
]
},
"/tmp/tmpy2siaefv/documents/python_info.txt": {
"chunk_ids": [
"27dc4179-c908-47dd-8e56-2e2b3c403faf"
]
},
"/tmp/tmp_fjaagdf/documents/python_info.txt": {
"chunk_ids": [
"3c147f7f-bd5a-456b-8bee-f9a43b6721c9"
]
},
"/tmp/tmpfbgc8zgm/documents/python_info.txt": {
"chunk_ids": [
"34eb0065-8a18-4144-890c-2e099d334b88"
]
},
"/tmp/tmp8u9_vv_c/documents/python_info.txt": {
"chunk_ids": [
"6cd7655c-8fa2-437c-b8fc-5690c928b89e"
]
},
"/tmp/tmpm_quzgav/documents/python_info.txt": {
"chunk_ids": [
"9e452f02-fe93-4b65-b883-f585df25dd4d"
]
},
"/tmp/tmpjrncrqpe/documents/doc1.txt": {
"chunk_ids": [
"81cea007-7a54-4d19-869d-3643fd8cf257"
]
},
"/tmp/tmpjrncrqpe/documents/doc2.txt": {
"chunk_ids": [
"eb1bd307-e065-49a9-8d73-007403d13bd0"
]
},
"/tmp/tmpjrncrqpe/documents/doc3.txt": {
"chunk_ids": [
"3adccb38-28df-496d-b5a1-fa4555ea1fb5"
]
},
"/tmp/tmptyee73q5/documents/doc1.txt": {
"chunk_ids": [
"319dc9c4-0a96-4c49-9306-c9c11e94b613"
]
},
"/tmp/tmptyee73q5/documents/doc2.txt": {
"chunk_ids": [
"d20230ce-0753-4303-a9a2-d14379bb3d91"
]
},
"/tmp/tmptyee73q5/documents/doc3.txt": {
"chunk_ids": [
"ceb9e3c7-c84f-49b4-8da4-ec96ac39f7e5"
]
},
"/tmp/tmpq7xsv0n6/documents/doc1.txt": {
"chunk_ids": [
"2a7d43bd-3f25-46b2-8f9c-b5108afa06df"
]
},
"/tmp/tmpq7xsv0n6/documents/doc2.txt": {
"chunk_ids": [
"6e38e121-14a6-4c07-9717-b7babb0fbb2f"
]
},
"/tmp/tmpq7xsv0n6/documents/doc3.txt": {
"chunk_ids": [
"cbc80fab-f4bc-47c4-a482-36afab9a649e"
]
},
"/tmp/tmpi80m0y6n/documents/doc1.txt": {
"chunk_ids": [
"b145e673-0a4c-4058-a6d1-e1237bd2d9af"
]
},
"/tmp/tmpi80m0y6n/documents/doc2.txt": {
"chunk_ids": [
"32082799-0089-4a3d-b4ad-ea3ff0084212"
]
},
"/tmp/tmpi80m0y6n/documents/doc3.txt": {
"chunk_ids": [
"0902fef4-ce95-4a73-9e0c-90c1581f80b2"
]
},
"/tmp/tmpi80m0y6n/documents/python_info.txt": {
"chunk_ids": [
"009da108-76ba-4ef7-9c6b-f862f577ce23"
]
}
}

4834
poetry.lock generated Normal file

File diff suppressed because it is too large

19
pyproject.toml Normal file

@@ -0,0 +1,19 @@
[project]
name = "code-rag"
version = "0.1.0"
description = "Simple RAG implementation for use with neovim"
authors = [{ name = "Alex Selimov", email = "alex@alexselimov.com" }]
readme = "README.md"
requires-python = ">=3.9,<4.0"
dependencies = ["langchain (>=0.3.21,<0.4.0)", "ollama (>=0.4.7,<0.5.0)", "langchain-community (>=0.3.20,<0.4.0)", "langchain-ollama (>=0.2.3,<0.3.0)", "chromadb (>=0.6.3,<0.7.0)", "unstructured (>=0.17.2,<0.18.0)"]

[tool.poetry]
packages = [{ include = "code_rag", from = "src" }]

[tool.poetry.group.dev.dependencies]
pytest = "^8.3.5"

[build-system]
requires = ["poetry-core>=2.0.0,<3.0.0"]
build-backend = "poetry.core.masonry.api"

0
src/code_rag/__init__.py Normal file


174
src/code_rag/doc_tracker.py Normal file

@@ -0,0 +1,174 @@
import os
import hashlib
import json
from datetime import datetime
import time


class DocMetaData:
    """
    Class that stores document meta data. Using this so we know that the MetaData is always
    initialized and can avoid a lot of checking whether keys exist or not
    """

    def __init__(self, mod_time, hash, last_updated, chunk_ids):
        self.mod_time = mod_time
        self.hash = hash
        self.last_updated = last_updated
        self.chunk_ids = chunk_ids

    def __eq__(self, other):
        if isinstance(other, DocMetaData):
            return (
                self.mod_time == other.mod_time
                and self.hash == other.hash
                and self.last_updated == other.last_updated
                and self.chunk_ids == other.chunk_ids
            )
        return NotImplemented

    def update_chunks(self, chunks):
        self.chunk_ids = chunks

    def to_dict(self):
        return {
            "mod_time": self.mod_time,
            "hash": self.hash,
            "last_updated": self.last_updated,
            "chunk_ids": self.chunk_ids,
        }

    @classmethod
    def from_file(cls, file_path):
        return cls(
            os.path.getmtime(file_path), calculate_file_hash(file_path), time.time(), []
        )

    @classmethod
    def from_dict(cls, input_dict):
        return cls(
            input_dict["mod_time"],
            input_dict["hash"],
            input_dict["last_updated"],
            input_dict["chunk_ids"],
        )


class DocumentTracker:
    """
    Tracks document changes using file hashes, modification times, and chunk IDs
    """

    def __init__(self, tracking_file):
        self.tracking_file = tracking_file
        self.doc_info = self._load_tracking_data()

    def _load_tracking_data(self):
        """Load existing tracking data if available"""
        doc_info = dict()
        if os.path.exists(self.tracking_file):
            with open(self.tracking_file, "r") as f:
                serialized = json.load(f)
                for k, v in serialized.items():
                    doc_info[k] = DocMetaData.from_dict(v)
        return doc_info

    def _save_tracking_data(self):
        """Save tracking data to file"""
        output = dict()
        for k, v in self.doc_info.items():
            output[k] = v.to_dict()
        with open(self.tracking_file, "w") as f:
            json.dump(output, f, indent=2)

    def get_changed_files(self, directory, file_extension=".txt"):
        """
        Detect new, modified, and deleted files
        Returns: dict with 'new', 'modified', and 'deleted' lists
        """
        current_file_mod_times = {}
        for root, _, files in os.walk(directory):
            for file in files:
                if file.endswith(file_extension):
                    file_path = os.path.join(root, file)
                    mod_time = os.path.getmtime(file_path)
                    current_file_mod_times[file_path] = mod_time

        new_files = []
        modified_files = []

        # Check for new or modified files
        for file_path, mod_time in current_file_mod_times.items():
            if file_path not in self.doc_info:
                new_files.append(file_path)
            elif mod_time > self.doc_info[file_path].mod_time:
                # Check if content actually changed using hash
                current_hash = calculate_file_hash(file_path)
                if current_hash != self.doc_info[file_path].hash:
                    modified_files.append(file_path)
                    self.doc_info[file_path].hash = current_hash
                    self.doc_info[file_path].mod_time = mod_time

        # Check for deleted files
        deleted_files = [f for f in self.doc_info if f not in current_file_mod_times]

        # Update tracking information for new files
        for file_path in new_files:
            self.doc_info[file_path] = DocMetaData(
                current_file_mod_times[file_path],
                calculate_file_hash(file_path),
                time.time(),
                [],  # Will store IDs of chunks derived from this document
            )

        # Update last_updated for modified files
        for file_path in modified_files:
            self.doc_info[file_path].last_updated = datetime.now().isoformat()

        self._save_tracking_data()

        return {"new": new_files, "modified": modified_files, "deleted": deleted_files}

    def update_chunk_mappings(self, file_path, chunk_ids):
        """Store the chunk IDs associated with a document"""
        if file_path not in self.doc_info:
            self.doc_info[file_path] = DocMetaData.from_file(file_path)
        self.doc_info[file_path].chunk_ids = chunk_ids
        self._save_tracking_data()

    def get_chunks_to_delete(self, deleted_files):
        """Get all chunk IDs associated with deleted files"""
        chunks_to_delete = []
        for file_path in deleted_files:
            if file_path in self.doc_info and self.doc_info[file_path].chunk_ids:
                chunks_to_delete.extend(self.doc_info[file_path].chunk_ids)
                # Remove the file from tracking after processing
                del self.doc_info[file_path]
        self._save_tracking_data()
        return chunks_to_delete

    def get_chunks_for_modified_files(self, modified_files):
        """Get chunk IDs for modified files that need to be deleted before re-indexing"""
        chunks_to_delete = []
        for file_path in modified_files:
            if file_path in self.doc_info and self.doc_info[file_path].chunk_ids:
                chunks_to_delete.extend(self.doc_info[file_path].chunk_ids)
                # Clear the chunk IDs (will be updated with new ones)
                self.doc_info[file_path].chunk_ids = []
        self._save_tracking_data()
        return chunks_to_delete


def calculate_file_hash(file_path):
    """Calculate MD5 hash of file contents"""
    hasher = hashlib.md5()
    with open(file_path, "rb") as f:
        buf = f.read(65536)
        while len(buf) > 0:
            hasher.update(buf)
            buf = f.read(65536)
    return hasher.hexdigest()
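
To make the change-detection flow concrete, here is an illustrative sketch of how DocumentTracker is meant to be driven; the tracker path, documents directory, and chunk IDs below are made-up values, not part of this commit.

# Illustrative driver for the DocumentTracker API above (paths and IDs are assumed).
from code_rag.doc_tracker import DocumentTracker

tracker = DocumentTracker(tracking_file="/path/to/tracker.json")
changes = tracker.get_changed_files("/path/to/docs", file_extension=".txt")

# Chunks belonging to deleted or modified files are collected for removal first,
stale_chunks = tracker.get_chunks_to_delete(changes["deleted"])
stale_chunks += tracker.get_chunks_for_modified_files(changes["modified"])

# then fresh chunk IDs are recorded for every file that gets re-indexed.
for path in changes["new"] + changes["modified"]:
    tracker.update_chunk_mappings(path, ["example-chunk-id-1", "example-chunk-id-2"])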

228
src/code_rag/rag.py Normal file

@@ -0,0 +1,228 @@
import os
import uuid

from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import Ollama
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

from code_rag.doc_tracker import DocumentTracker


class RAG:
    def __init__(self, docs_dir, db_dir, tracker_file):
        self.docs_dir = docs_dir
        self.db_dir = db_dir
        self.tracker = DocumentTracker(tracker_file)

    def process_documents(self, files, text_splitter):
        """Process document files into chunks with tracking metadata"""
        all_chunks = []
        file_chunk_map = {}

        for file_path in files:
            # Load the document
            loader = TextLoader(file_path)
            documents = loader.load()

            # Add source metadata
            for doc in documents:
                doc.metadata["source"] = file_path
                doc.metadata["source_id"] = file_path  # For easier identification

            # Split the document
            chunks = text_splitter.split_documents(documents)

            # Generate and track chunk IDs
            chunk_ids = []
            for chunk in chunks:
                chunk_id = str(uuid.uuid4())
                chunk.metadata["chunk_id"] = chunk_id
                chunk_ids.append(chunk_id)

            # Store chunk mappings if tracker is provided
            if self.tracker:
                self.tracker.update_chunk_mappings(file_path, chunk_ids)

            file_chunk_map[file_path] = chunks
            all_chunks.extend(chunks)

        return all_chunks, file_chunk_map

    def create_vector_db(self, extension=".txt", force_refresh=False):
        """
        Create or update a vector database, with complete handling of changes
        """
        # Text splitter configuration
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200
        )

        # Create embeddings
        embeddings = OllamaEmbeddings(model="nomic-embed-text")

        if force_refresh:
            print("Force refresh: Processing all documents")

            # Load all documents
            loader = DirectoryLoader(self.docs_dir, glob=f"**/*{extension}")
            all_documents = loader.load()

            # Add unique IDs to each document
            for doc in all_documents:
                doc.metadata["source"] = os.path.abspath(doc.metadata["source"])
                doc.metadata["source_id"] = doc.metadata["source"]

            # Split documents
            chunks = text_splitter.split_documents(all_documents)

            # Add chunk IDs and update tracker
            file_chunk_map = {}
            for chunk in chunks:
                chunk_id = str(uuid.uuid4())
                chunk.metadata["chunk_id"] = chunk_id
                source = chunk.metadata["source"]
                if source not in file_chunk_map:
                    file_chunk_map[source] = []
                file_chunk_map[source].append(chunk_id)

            # Update tracker with chunk mappings
            for file_path, chunk_ids in file_chunk_map.items():
                self.tracker.update_chunk_mappings(file_path, chunk_ids)

            print(
                f"Processing {len(all_documents)} documents with {len(chunks)} chunks"
            )

            # Create new vector store
            vectorstore = Chroma.from_documents(
                documents=chunks, embedding=embeddings, persist_directory=self.db_dir
            )
            vectorstore.persist()
            print(f"Created new vector database at {self.db_dir}")
            return vectorstore

        # Get changes since last update
        changed_files = self.tracker.get_changed_files(self.docs_dir)

        if not any(changed_files.values()):
            print("No document changes detected")
            # Load existing vector store if available
            if os.path.exists(self.db_dir):
                return Chroma(
                    persist_directory=self.db_dir, embedding_function=embeddings
                )
            else:
                print("No vector database exists. Creating from all documents...")
                return self.create_vector_db(force_refresh=True)

        # Process changes
        print(
            f"Changes detected - New: {len(changed_files['new'])}, Modified: {len(changed_files['modified'])}, Deleted: {len(changed_files['deleted'])}"
        )

        # Load existing vector store if it exists
        if os.path.exists(self.db_dir):
            vectorstore = Chroma(
                persist_directory=self.db_dir, embedding_function=embeddings
            )

            # 1. Handle deleted documents
            if changed_files["deleted"]:
                chunks_to_delete = self.tracker.get_chunks_to_delete(
                    changed_files["deleted"]
                )
                if chunks_to_delete:
                    print(
                        f"Removing {len(chunks_to_delete)} chunks from deleted documents"
                    )
                    # Delete the chunks from vector store
                    vectorstore._collection.delete(
                        where={"chunk_id": {"$in": chunks_to_delete}}
                    )

            # 2. Handle modified documents (delete old chunks first)
            chunks_to_delete_modified = self.tracker.get_chunks_for_modified_files(
                changed_files["modified"]
            )
            if chunks_to_delete_modified:
                print(
                    f"Removing {len(chunks_to_delete_modified)} chunks from modified documents"
                )
                vectorstore._collection.delete(
                    where={"chunk_id": {"$in": chunks_to_delete_modified}}
                )

            # 3. Process new and modified documents
            files_to_process = changed_files["new"] + changed_files["modified"]
            if files_to_process:
                chunks, _ = self.process_documents(files_to_process, text_splitter)
                print(f"Adding {len(chunks)} new chunks to the vector store")
                vectorstore.add_documents(chunks)
        else:
            # If no existing DB, create from all documents
            print("No existing vector database. Creating from all documents...")
            return self.create_vector_db(force_refresh=True)

        # Persist changes
        vectorstore.persist()
        print(f"Vector database updated at {self.db_dir}")
        return vectorstore

    def setup_rag(self, model_name="llama3"):
        """
        Set up the RAG system with an existing vector database
        """
        # Load the embeddings
        embeddings = OllamaEmbeddings(model="nomic-embed-text")

        # Load the vector store
        vectorstore = Chroma(
            persist_directory=self.db_dir, embedding_function=embeddings
        )

        # Create a retriever
        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

        # Set up the LLM
        llm = Ollama(model=model_name)

        # Create a custom prompt template
        template = """
        Answer the question based on the context provided. If you don't know the answer,
        just say you don't know. Don't try to make up an answer.
        Context: {context}
        Question: {question}
        Answer:
        """
        prompt = PromptTemplate(
            input_variables=["context", "question"], template=template
        )

        # Create the RAG chain
        rag_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt},
        )

        return rag_chain

    def query_rag(self, rag_chain, query):
        """
        Query the RAG system
        """
        response = rag_chain.invoke({"query": query})
        return response["result"]
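
For a quick sanity check of what the retriever will hand the LLM, the persisted store can also be queried directly; the database path below is an assumed value, and the metadata keys are the ones set during indexing in rag.py.

# Hedged sketch: inspect which chunks a query retrieves from an existing store.
from langchain_community.vectorstores import Chroma
from langchain_ollama import OllamaEmbeddings

embeddings = OllamaEmbeddings(model="nomic-embed-text")
vectorstore = Chroma(persist_directory="/path/to/chroma_db", embedding_function=embeddings)

for doc in vectorstore.similarity_search("How are deleted files handled?", k=4):
    # Each chunk carries the source path and chunk_id metadata assigned during indexing.
    print(doc.metadata["source"], doc.metadata["chunk_id"])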

0
tests/__init__.py Normal file


66
tests/fixtures.py Normal file

@@ -0,0 +1,66 @@
import pytest
import shutil
import tempfile
import os


@pytest.fixture
def temp_dir():
    """Create a temporary directory for test files"""
    temp_dir = tempfile.mkdtemp()
    yield temp_dir
    shutil.rmtree(temp_dir)


@pytest.fixture
def docs_dir(temp_dir):
    """Create a temporary documents directory"""
    docs_dir = os.path.join(temp_dir, "documents")
    os.makedirs(docs_dir)
    yield docs_dir


@pytest.fixture
def db_dir(temp_dir):
    """Create a temporary vector database directory"""
    db_dir = os.path.join(temp_dir, "vector_db")
    os.makedirs(db_dir)
    yield db_dir


@pytest.fixture
def tracker_file(temp_dir):
    """Create a temporary tracker file"""
    tracker_path = os.path.join(temp_dir, "test_tracker.json")
    yield tracker_path
    # Clean up after tests
    if os.path.exists(tracker_path):
        os.remove(tracker_path)


@pytest.fixture
def sample_docs(docs_dir):
    """Create sample text documents for testing"""
    # Create a few sample documents
    doc1_path = os.path.join(docs_dir, "doc1.txt")
    doc2_path = os.path.join(docs_dir, "doc2.txt")
    doc3_path = os.path.join(docs_dir, "doc3.txt")

    with open(doc1_path, "w") as f:
        f.write("This is a sample document about artificial intelligence. " * 10)
    with open(doc2_path, "w") as f:
        f.write("This document discusses machine learning concepts. " * 10)
    with open(doc3_path, "w") as f:
        f.write("Natural language processing is a field of AI. " * 10)

    return [doc1_path, doc2_path, doc3_path]


@pytest.fixture
def text_splitter():
    """Create a text splitter for testing"""
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    return RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)

157
tests/test_doc_tracker.py Normal file

@@ -0,0 +1,157 @@
import os
import time

import pytest

from .fixtures import *
from code_rag.doc_tracker import DocMetaData, DocumentTracker, calculate_file_hash


def doc_infos_are_equal(left, right):
    """Check to see if two doc_infos are the same"""
    for k, v in left.items():
        try:
            print(v.to_dict(), right[k].to_dict(), v.to_dict() == right[k].to_dict())
            if v != right[k]:
                return False
        except KeyError:
            return False
    return True


@pytest.fixture
def document_tracker(tracker_file):
    """Create a DocumentTracker instance"""
    return DocumentTracker(tracking_file=tracker_file)


# Tests for DocumentTracker
def test_init_new_tracker(tracker_file):
    """Test creating a new tracker"""
    tracker = DocumentTracker(tracking_file=tracker_file)
    assert tracker.doc_info == {}
    assert not os.path.exists(tracker_file)


def test_save_and_load_tracking_data(document_tracker, tracker_file):
    """Test saving and loading tracking data"""
    # Add some data
    update_time = time.time()
    document_tracker.doc_info = {
        "test.txt": DocMetaData(123456, "abcdef", update_time, ["1", "2"])
    }
    document_tracker._save_tracking_data()

    # Check file exists
    assert os.path.exists(tracker_file)

    # Create a new tracker that should load the data
    new_tracker = DocumentTracker(tracking_file=tracker_file)
    assert doc_infos_are_equal(
        new_tracker.doc_info,
        {"test.txt": DocMetaData(123456, "abcdef", update_time, ["1", "2"])},
    )


def test_calculate_file_hash(document_tracker, sample_docs):
    """Test hash calculation for a file"""
    file_path = sample_docs[0]
    hash1 = calculate_file_hash(file_path)

    # Same content should yield same hash
    hash2 = calculate_file_hash(file_path)
    assert hash1 == hash2

    # Different content should yield different hash
    with open(file_path, "a") as f:
        f.write("Additional content")
    hash3 = calculate_file_hash(file_path)
    assert hash1 != hash3


def test_get_changed_files_new(document_tracker, docs_dir, sample_docs):
    """Test detecting new files"""
    changes = document_tracker.get_changed_files(docs_dir)
    assert set(changes["new"]) == set(sample_docs)
    assert changes["modified"] == []
    assert changes["deleted"] == []

    # Verify tracking was updated
    for file_path in sample_docs:
        assert file_path in document_tracker.doc_info
        assert document_tracker.doc_info[file_path].chunk_ids == []


def test_get_changed_files_modified(document_tracker, docs_dir, sample_docs):
    """Test detecting modified files"""
    # First scan to establish tracking
    document_tracker.get_changed_files(docs_dir)

    # Modify a file and wait to ensure timestamp difference
    time.sleep(0.1)
    with open(sample_docs[0], "a") as f:
        f.write("Modified content")

    # Detect changes
    changes = document_tracker.get_changed_files(docs_dir)
    assert changes["new"] == []
    assert changes["modified"] == [sample_docs[0]]
    assert changes["deleted"] == []


def test_get_changed_files_deleted(document_tracker, docs_dir, sample_docs):
    """Test detecting deleted files"""
    # First scan to establish tracking
    document_tracker.get_changed_files(docs_dir)

    # Delete a file
    os.remove(sample_docs[0])

    # Detect changes
    changes = document_tracker.get_changed_files(docs_dir)
    assert changes["new"] == []
    assert changes["modified"] == []
    assert changes["deleted"] == [sample_docs[0]]


def test_update_chunk_mappings(document_tracker, sample_docs):
    """Test updating chunk mappings"""
    file_path = sample_docs[0]
    chunk_ids = ["chunk1", "chunk2", "chunk3"]

    # First make sure the file is tracked
    document_tracker.doc_info[file_path] = DocMetaData(
        123,
        "abc",
        "2023-01-01",
        [],
    )

    # Update chunk mappings
    document_tracker.update_chunk_mappings(file_path, chunk_ids)
    assert document_tracker.doc_info[file_path].chunk_ids == chunk_ids


def test_get_chunks_to_delete(document_tracker):
    """Test getting chunks to delete for deleted files"""
    # Setup tracking data
    document_tracker.doc_info = {
        "file1.txt": DocMetaData(0, "abc", 0, ["chunk1", "chunk2"]),
        "file2.txt": DocMetaData(0, "abc", 0, ["chunk3", "chunk4"]),
        "file3.txt": DocMetaData(0, "abc", 0, ["chunk5"]),
    }

    # Test with one deleted file
    chunks = document_tracker.get_chunks_to_delete(["file1.txt"])
    assert set(chunks) == {"chunk1", "chunk2"}

    # Verify file was removed from tracking
    assert "file1.txt" not in document_tracker.doc_info

    # Test with multiple deleted files
    chunks = document_tracker.get_chunks_to_delete(["file2.txt", "file3.txt"])
    assert set(chunks) == {"chunk3", "chunk4", "chunk5"}

    # Verify files were removed from tracking
    assert "file2.txt" not in document_tracker.doc_info
    assert "file3.txt" not in document_tracker.doc_info

124
tests/test_rag.py Normal file

@@ -0,0 +1,124 @@
import os
import shutil

import pytest
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import Chroma

from .fixtures import *
from .utility import *
from code_rag.rag import RAG
from code_rag.doc_tracker import DocumentTracker


@pytest.fixture
def rag_pipeline(docs_dir, db_dir, tracker_file):
    """Create a RAG instance"""
    return RAG(docs_dir, db_dir, tracker_file)


# Tests for document processing
def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipeline):
    """Test processing documents into chunks with tracking"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    files = [
        os.path.join(rag_pipeline.docs_dir, "doc1.txt"),
        os.path.join(rag_pipeline.docs_dir, "doc2.txt"),
    ]
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks, file_chunk_map = rag_pipeline.process_documents(files, text_splitter)

    # Verify chunks were created
    assert len(chunks) >= 2  # At least one chunk per document

    tracker = rag_pipeline.tracker
    # Verify chunk IDs were tracked (doc_info values are DocMetaData objects)
    for file_path in files:
        assert file_path in tracker.doc_info
        assert len(tracker.doc_info[file_path].chunk_ids) > 0

    # Verify metadata in chunks
    for chunk in chunks:
        assert "source" in chunk.metadata
        assert "chunk_id" in chunk.metadata
        assert chunk.metadata["source"] in files


@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
    """Test creating a vector database"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    # Create initial vector database
    vectorstore = rag_pipeline.create_vector_db(force_refresh=True)

    # Verify it was created
    assert os.path.exists(rag_pipeline.db_dir)
    assert vectorstore is not None

    # Check the database has content
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    loaded_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=embeddings
    )
    assert loaded_db._collection.count() > 0


@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_docs):
    """Test updating a vector database with document changes"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    # Create initial vector database
    rag_pipeline.create_vector_db(force_refresh=True)

    # Get initial count
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    initial_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=embeddings
    )
    initial_count = initial_db._collection.count()

    # Make changes to documents
    # Add a new document
    create_test_document(
        docs_dir, "newdoc.txt", "This is a brand new document for testing."
    )

    # Update the vector database
    rag_pipeline.create_vector_db()

    # Check the database has been updated
    updated_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=embeddings
    )
    assert updated_db._collection.count() > initial_count


# Final integration test - full RAG pipeline
@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
    """Test the entire RAG pipeline from document processing to querying"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    # Create a specific document with known content
    test_content = "Python is a high-level programming language known for its readability and versatility."
    create_test_document(rag_pipeline.docs_dir, "python_info.txt", test_content)

    # Create vector database
    rag_pipeline.create_vector_db(force_refresh=True)

    # Set up RAG
    rag_chain = rag_pipeline.setup_rag(model_name="llama3.2")

    # Query the system
    query = "What is Python?"
    response = rag_pipeline.query_rag(rag_chain, query)

    # Check if response contains relevant information
    # This is a soft test since the exact response will depend on the LLM
    assert response.strip() != ""
    assert "programming" in response.lower() or "language" in response.lower()

26
tests/utility.py Normal file

@@ -0,0 +1,26 @@
import time
import os


def create_test_document(base_path, filename, content):
    """Create a test document with specified content"""
    file_path = os.path.join(base_path, filename)
    # Create directory if it doesn't exist (for subdirectories)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "w") as f:
        f.write(content)
    return file_path


def modify_test_document(file_path, new_content):
    """Modify an existing test document"""
    with open(file_path, "w") as f:
        f.write(new_content)
    # Ensure modification time changes
    time.sleep(0.1)


def delete_test_document(file_path):
    """Delete a test document"""
    if os.path.exists(file_path):
        os.remove(file_path)