# code_rag/tests/test_rag.py
# Snapshot metadata: 2025-03-21 10:09:07 -04:00 · 125 lines · 4.4 KiB · Python

import os
import shutil

import pytest
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import Chroma

from code_rag.rag import RAG
from code_rag.doc_tracker import DocumentTracker
from .fixtures import *
from .utility import *
@pytest.fixture
def rag_pipeline(docs_dir, db_dir, tracker_file):
    """Provide a RAG pipeline wired to the per-test docs/db/tracker paths."""
    pipeline = RAG(docs_dir, db_dir, tracker_file)
    return pipeline
# Tests for document processing
def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipeline):
    """Test processing documents into chunks with tracking.

    Uses the ``rag_pipeline`` fixture directly; the previous version rebuilt
    a second RAG instance with the same arguments, shadowing the fixture.
    """
    files = [
        os.path.join(rag_pipeline.docs_dir, "doc1.txt"),
        os.path.join(rag_pipeline.docs_dir, "doc2.txt"),
    ]
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks, file_chunk_map = rag_pipeline.process_documents(files, text_splitter)

    # Verify chunks were created: at least one chunk per input document.
    assert len(chunks) >= 2

    # Verify the tracker recorded chunk IDs for every processed file.
    tracker = rag_pipeline.tracker
    for file_path in files:
        assert file_path in tracker.doc_info
        assert "chunk_ids" in tracker.doc_info[file_path]
        assert len(tracker.doc_info[file_path]["chunk_ids"]) > 0

    # Verify each chunk carries provenance metadata pointing back to a file.
    for chunk in chunks:
        assert "source" in chunk.metadata
        assert "chunk_id" in chunk.metadata
        assert chunk.metadata["source"] in files
@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
    """Test creating a vector database"""
    pipeline = RAG(docs_dir, db_dir, tracker_file)

    # Build the vector store from scratch.
    store = pipeline.create_vector_db(force_refresh=True)
    assert store is not None
    assert os.path.exists(pipeline.db_dir)

    # Reload the persisted store from disk and confirm it holds content.
    reloaded = Chroma(
        persist_directory=pipeline.db_dir,
        embedding_function=OllamaEmbeddings(model="nomic-embed-text"),
    )
    assert reloaded._collection.count() > 0
@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_docs):
    """Test updating a vector database with document changes"""
    pipeline = RAG(docs_dir, db_dir, tracker_file)
    pipeline.create_vector_db(force_refresh=True)

    # Record the chunk count of the freshly built store.
    embedder = OllamaEmbeddings(model="nomic-embed-text")
    before = Chroma(
        persist_directory=pipeline.db_dir, embedding_function=embedder
    )._collection.count()

    # Introduce a new document, then run an incremental (non-forced) update.
    create_test_document(
        docs_dir, "newdoc.txt", "This is a brand new document for testing."
    )
    pipeline.create_vector_db()

    # The updated store must have grown to include the new document.
    after = Chroma(
        persist_directory=pipeline.db_dir, embedding_function=embedder
    )._collection.count()
    assert after > before
# Final integration test - full RAG pipeline
@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
    """Test the entire RAG pipeline from document processing to querying"""
    pipeline = RAG(docs_dir, db_dir, tracker_file)

    # Seed a document with known content so the answer is predictable.
    test_content = "Python is a high-level programming language known for its readability and versatility."
    create_test_document(pipeline.docs_dir, "python_info.txt", test_content)

    # Index the documents and wire up the retrieval chain.
    pipeline.create_vector_db(force_refresh=True)
    chain = pipeline.setup_rag(model_name="llama3.2")

    answer = pipeline.query_rag(chain, "What is Python?")

    # Soft checks only: the exact wording depends on the LLM.
    assert answer.strip() != ""
    lowered = answer.lower()
    assert "programming" in lowered or "language" in lowered