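"""Tests for the code_rag RAG pipeline: document chunking with change
tracking, vector-database creation and incremental updates, and an
end-to-end query against a local Ollama model."""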
import os
import shutil

import pytest

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import Chroma

from .fixtures import *
from .utility import *
from code_rag.rag import RAG
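# NOTE: the docs_dir, db_dir, tracker_file, and sample_docs fixtures, as well
# as the create_test_document() helper used below, are assumed to be supplied
# by the star imports from .fixtures and .utility.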


@pytest.fixture
def rag_pipeline(docs_dir, db_dir, tracker_file):
    """Create a RAG instance"""
    return RAG(docs_dir, db_dir, tracker_file)


# Tests for document processing
def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipeline):
    """Test processing documents into chunks with tracking"""
    files = [
        os.path.join(rag_pipeline.docs_dir, "doc1.txt"),
        os.path.join(rag_pipeline.docs_dir, "doc2.txt"),
    ]

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    chunks, file_chunk_map = rag_pipeline.process_documents(files, text_splitter)
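    # file_chunk_map presumably maps each source file to the IDs of the chunks
    # produced from it; the tracker assertions below check that bookkeeping.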

    # Verify chunks were created
    assert len(chunks) >= 2  # At least one chunk per document

    # Verify chunk IDs were tracked
    tracker = rag_pipeline.tracker
    for file_path in files:
        assert file_path in tracker.doc_info
        assert "chunk_ids" in tracker.doc_info[file_path]
        assert len(tracker.doc_info[file_path]["chunk_ids"]) > 0

    # Verify metadata in chunks
    for chunk in chunks:
        assert "source" in chunk.metadata
        assert "chunk_id" in chunk.metadata
        assert chunk.metadata["source"] in files
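

# The remaining tests exercise Ollama-backed embeddings and generation; they
# assume a local Ollama install with the referenced models already pulled
# (e.g. `ollama pull nomic-embed-text` and `ollama pull llama3.2`).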
@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
    """Test creating a vector database"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    # Create initial vector database
    vectorstore = rag_pipeline.create_vector_db(force_refresh=True)

    # Verify it was created
    assert os.path.exists(rag_pipeline.db_dir)
    assert vectorstore is not None

    # Check the database has content
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    loaded_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=embeddings
    )
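    # _collection is a private Chroma attribute; counting through it is a
    # convenient, if unofficial, way to confirm the store was populated.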
    assert loaded_db._collection.count() > 0


@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_docs):
    """Test updating a vector database with document changes"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    # Create initial vector database
    rag_pipeline.create_vector_db(force_refresh=True)

    # Get initial count
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    initial_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=embeddings
    )
    initial_count = initial_db._collection.count()

    # Make changes to documents: add a new document
    create_test_document(
        docs_dir, "newdoc.txt", "This is a brand new document for testing."
    )
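
    # Calling create_vector_db() without force_refresh exercises the
    # incremental path: only documents the tracker reports as new or changed
    # should be (re-)embedded on top of the existing collection.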
    # Update the vector database
    rag_pipeline.create_vector_db()

    # Check the database has been updated
    updated_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=embeddings
    )
    assert updated_db._collection.count() > initial_count


# Final integration test - full RAG pipeline
@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
    """Test the entire RAG pipeline, from document processing to querying"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    # Create a specific document with known content
    test_content = "Python is a high-level programming language known for its readability and versatility."
    create_test_document(rag_pipeline.docs_dir, "python_info.txt", test_content)

    # Create the vector database
    rag_pipeline.create_vector_db(force_refresh=True)

    # Set up the RAG chain
    rag_chain = rag_pipeline.setup_rag(model_name="llama3.2")

    # Query the system
    query = "What is Python?"
    response = rag_pipeline.query_rag(rag_chain, query)

    # Check that the response contains relevant information.
    # This is a soft check, since the exact wording depends on the LLM.
    assert response.strip() != ""
    assert "programming" in response.lower() or "language" in response.lower()