# code_rag/tests/test_rag.py

import os
import shutil
import pytest
from langchain_chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from .fixtures import *
from code_rag.rag import RAG
from code_rag.doc_tracker import DocumentTracker
from code_rag.ollama_wrapper import OllamaWrapper
@pytest.fixture
def rag_pipeline(docs_dir, db_dir, tracker_file):
    """Provide a RAG instance wired to the temporary docs/db/tracker paths."""
    return RAG(docs_dir, db_dir, tracker_file)
# Tests for document processing
def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipeline):
    """Test processing documents into chunks with per-file chunk tracking."""
    # Fix: use the rag_pipeline fixture directly instead of shadowing it
    # with an identical second RAG built from the same paths.
    files = [
        os.path.join(rag_pipeline.docs_dir, "doc1.txt"),
        os.path.join(rag_pipeline.docs_dir, "doc2.txt"),
        os.path.join(rag_pipeline.docs_dir, "doc3.txt"),
    ]
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks, file_chunk_map = rag_pipeline.process_documents(files, text_splitter)

    # Verify chunks were created: at least one chunk per document.
    assert len(chunks) >= 3

    # Verify chunk IDs were recorded for every source file.
    tracker = rag_pipeline.tracker
    for file_path in files:
        assert file_path in tracker.doc_info
        assert len(tracker.doc_info[file_path].chunk_ids) > 0

    # Verify tracking metadata was attached to each chunk.
    for chunk in chunks:
        assert "source" in chunk.metadata
        assert "chunk_id" in chunk.metadata
        assert chunk.metadata["source"] in files
@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
    """Test creating a vector database that indexes several file types."""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)

    # Create files with different extensions.
    files = {
        "test.py": "def hello():\n print('Hello, World!')",
        "main.cpp": "#include <iostream>\nint main() { std::cout << 'Hello'; return 0; }",
        "lib.rs": "fn main() { println!('Hello from Rust!'); }",
        "config.toml": "[package]\nname = 'test'",
        "doc.md": "# Documentation\nThis is a test file.",
    }
    for filename, content in files.items():
        filepath = os.path.join(docs_dir, filename)
        with open(filepath, "w") as f:
            f.write(content)

    # Create vector database with default extensions (should include all file types).
    vectorstore = rag_pipeline.create_vector_db(force_refresh=True)

    # Verify it was created.
    assert os.path.exists(rag_pipeline.db_dir)
    assert vectorstore is not None

    # Check the database has content from all file types.
    loaded_db = Chroma(
        persist_directory=rag_pipeline.db_dir,
        embedding_function=rag_pipeline.ollama.embeddings,
    )
    assert loaded_db._collection.count() > 0

    # Verify each file type is included.
    docs = loaded_db._collection.get()
    sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
    for filename in files:
        # Bug fix: original f-string had no placeholder, so every failure
        # read "File (unknown) not found in vector store".
        assert filename in sources, f"File {filename} not found in vector store"
@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_docs):
    """Test updating a vector database with document changes."""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)

    # Create initial vector database with only Python files.
    vectorstore = rag_pipeline.create_vector_db(extensions=['.py'], force_refresh=True)

    # Get initial count.
    initial_db = Chroma(
        persist_directory=rag_pipeline.db_dir,
        embedding_function=rag_pipeline.ollama.embeddings,
    )
    initial_count = initial_db._collection.count()

    # Make changes to documents: add files of different types.
    new_files = {
        "newdoc.cpp": "#include <iostream>\nint main() { return 0; }",
        "lib.rs": "fn main() { println!('Hello'); }",
        "config.toml": "[package]\nname = 'test'",
    }
    for filename, content in new_files.items():
        filepath = os.path.join(docs_dir, filename)
        with open(filepath, "w") as f:
            f.write(content)

    # Update the vector database to include all supported extensions.
    rag_pipeline.create_vector_db()  # Use default extensions

    # Check the database has been updated.
    updated_db = Chroma(
        persist_directory=rag_pipeline.db_dir,
        embedding_function=rag_pipeline.ollama.embeddings,
    )
    assert updated_db._collection.count() > initial_count

    # Verify new files are included.
    docs = updated_db._collection.get()
    sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
    for filename in new_files:
        # Bug fix: original f-string had no placeholder, so every failure
        # read "File (unknown) not found in vector store".
        assert filename in sources, f"File {filename} not found in vector store"
@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
    """End-to-end RAG check: index mixed-type documents, then query them."""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)

    # Seed the corpus with documents of mixed content types.
    test_files = {
        "python_info.py": """# Python Information
def describe_python():
\"\"\"Python is a high-level programming language known for its readability and versatility.\"\"\"
pass""",
        "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation.",
    }
    for name, body in test_files.items():
        target = os.path.join(rag_pipeline.docs_dir, name)
        with open(target, "w") as handle:
            handle.write(body)

    # Index with all default extensions, then assemble the RAG chain.
    rag_pipeline.create_vector_db(force_refresh=True)
    rag_chain = rag_pipeline.setup_rag()

    # Query the system. This is a soft check: the exact wording of the
    # answer depends on the LLM, so only look for relevant keywords.
    query = "What is Python?"
    response = rag_pipeline.query_rag(rag_chain, query)
    assert response.strip() != ""
    assert "programming" in response.lower() or "language" in response.lower()