import os
import shutil

import pytest
from langchain_chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter

from .fixtures import *
from code_rag.rag import RAG
from code_rag.doc_tracker import DocumentTracker
from code_rag.ollama_wrapper import OllamaWrapper


@pytest.fixture
def rag_pipeline(docs_dir, db_dir, tracker_file):
    """Create a RAG instance"""
    return RAG(docs_dir, db_dir, tracker_file)
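

# The fixtures `docs_dir`, `db_dir`, `tracker_file`, and `sample_docs` used below are
# wildcard-imported from `.fixtures` above. Judging by how the tests use them, they
# are assumed to provide a temporary documents directory, a temporary Chroma
# persistence directory, a path for the tracker state file, and three sample text
# files (doc1.txt, doc2.txt, doc3.txt) written into `docs_dir`. A purely illustrative
# sketch of one such fixture (not the actual definition in fixtures.py):
#
#     @pytest.fixture
#     def docs_dir(tmp_path):
#         path = tmp_path / "docs"
#         path.mkdir()
#         return str(path)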


# Tests for document processing
def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipeline):
    """Test processing documents into chunks with tracking"""
    files = [
        os.path.join(rag_pipeline.docs_dir, "doc1.txt"),
        os.path.join(rag_pipeline.docs_dir, "doc2.txt"),
        os.path.join(rag_pipeline.docs_dir, "doc3.txt"),
    ]

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    chunks, file_chunk_map = rag_pipeline.process_documents(files, text_splitter)

    # Verify chunks were created
    assert len(chunks) >= 3  # At least one chunk per document

    tracker = rag_pipeline.tracker
    # Verify chunk IDs were tracked
    for file_path in files:
        assert file_path in tracker.doc_info
        assert len(tracker.doc_info[file_path].chunk_ids) > 0

    # Verify metadata in chunks
    for chunk in chunks:
        assert "source" in chunk.metadata
        assert "chunk_id" in chunk.metadata
        assert chunk.metadata["source"] in files
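

# Each Ollama-backed test below writes a dict of {filename: content} files into the
# documents directory before (re)building the vector database. The helper below is a
# hypothetical convenience that is not referenced by the existing tests; it only
# illustrates that shared pattern in one place.
def _write_files(directory, files):
    """Write each {filename: content} entry into `directory` and return the paths."""
    paths = []
    for filename, content in files.items():
        filepath = os.path.join(directory, filename)
        with open(filepath, "w") as f:
            f.write(content)
        paths.append(filepath)
    return paths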


@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
    """Test creating a vector database"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)

    # Create files with different extensions
    files = {
        "test.py": "def hello():\n    print('Hello, World!')",
        "main.cpp": "#include <iostream>\nint main() { std::cout << \"Hello\"; return 0; }",
        "lib.rs": "fn main() { println!(\"Hello from Rust!\"); }",
        "config.toml": "[package]\nname = 'test'",
        "doc.md": "# Documentation\nThis is a test file.",
    }

    for filename, content in files.items():
        filepath = os.path.join(docs_dir, filename)
        with open(filepath, "w") as f:
            f.write(content)

    # Create vector database with default extensions (should include all file types)
    vectorstore = rag_pipeline.create_vector_db(force_refresh=True)

    # Verify it was created
    assert os.path.exists(rag_pipeline.db_dir)
    assert vectorstore is not None

    # Check the database has content from all file types
    loaded_db = Chroma(
        persist_directory=rag_pipeline.db_dir,
        embedding_function=rag_pipeline.ollama.embeddings,
    )
    # Should have content from all files
    assert loaded_db._collection.count() > 0

    # Verify each file type is included
    docs = loaded_db._collection.get()
    sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
    for filename in files.keys():
        assert filename in sources, f"File {filename} not found in vector store"
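

# Note on the Chroma assertions in these tests: `_collection` is the underlying
# chromadb collection, a private attribute of langchain-chroma's Chroma wrapper, used
# here only to count and inspect stored chunks. Depending on the installed version,
# the public `Chroma.get()` method may offer a less brittle way to fetch documents
# and metadata.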


@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_docs):
    """Test updating a vector database with document changes"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)

    # Create initial vector database with only Python files
    vectorstore = rag_pipeline.create_vector_db(extensions=[".py"], force_refresh=True)

    # Get initial count
    initial_db = Chroma(
        persist_directory=rag_pipeline.db_dir,
        embedding_function=rag_pipeline.ollama.embeddings,
    )
    initial_count = initial_db._collection.count()

    # Make changes to documents: add files of different types
    new_files = {
        "newdoc.cpp": "#include <iostream>\nint main() { return 0; }",
        "lib.rs": "fn main() { println!(\"Hello\"); }",
        "config.toml": "[package]\nname = 'test'",
    }

    for filename, content in new_files.items():
        filepath = os.path.join(docs_dir, filename)
        with open(filepath, "w") as f:
            f.write(content)

    # Update the vector database to include all supported extensions
    rag_pipeline.create_vector_db()  # Use default extensions

    # Check the database has been updated
    updated_db = Chroma(
        persist_directory=rag_pipeline.db_dir,
        embedding_function=rag_pipeline.ollama.embeddings,
    )
    assert updated_db._collection.count() > initial_count

    # Verify new files are included
    docs = updated_db._collection.get()
    sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
    for filename in new_files.keys():
        assert filename in sources, f"File {filename} not found in vector store"
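

# The count comparison in the test above relies on create_vector_db() being
# incremental when called without force_refresh: the DocumentTracker state is assumed
# to record which files were already embedded, so the second call only adds new or
# changed documents on top of the initial Python-only database.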


@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
    """Test the entire RAG pipeline from document processing to querying"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)

    # Create documents with mixed content types
    test_files = {
        "python_info.py": """# Python Information
def describe_python():
    \"\"\"Python is a high-level programming language known for its readability and versatility.\"\"\"
    pass""",
        "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation.",
    }

    for filename, content in test_files.items():
        filepath = os.path.join(rag_pipeline.docs_dir, filename)
        with open(filepath, "w") as f:
            f.write(content)

    # Create vector database with all default extensions
    rag_pipeline.create_vector_db(force_refresh=True)

    # Set up RAG
    rag_chain = rag_pipeline.setup_rag()

    # Query the system
    query = "What is Python?"
    response = rag_pipeline.query_rag(rag_chain, query)

    # Check if response contains relevant information
    # This is a soft test since the exact response will depend on the LLM
    assert response.strip() != ""
    assert "programming" in response.lower() or "language" in response.lower()
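

# For reference, the flow exercised by this module is assumed to look roughly like the
# following outside pytest (constructor arguments and method names are taken from the
# tests above; the paths are illustrative):
#
#     rag = RAG("docs/", "vector_db/", "tracker.json")
#     rag.create_vector_db(force_refresh=True)
#     chain = rag.setup_rag()
#     print(rag.query_rag(chain, "What is Python?"))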