Initial commit
Commit e0d0962fc5
README.md | 6 lines (new file)
@@ -0,0 +1,6 @@
# code_rag

This is intended to be a fully local Retrieval-Augmented Generation (RAG) tool for software development.
It will be loadable as a Neovim plugin, which retrieves the current workspace files and adds them to a persistent Chroma vector store.
The tool will expose an Ollama API to query locally running Ollama models with prompts augmented by relevant context from the current code.
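For orientation, the workflow the README describes maps onto the `RAG` class added in `src/code_rag/rag.py` further down in this commit. A minimal usage sketch (the paths and model name below are placeholders, not part of this commit):

```python
from code_rag.rag import RAG

# Placeholder locations; the intended Neovim integration would point these at the
# current workspace and a persistent Chroma database directory.
rag = RAG(docs_dir="documents/", db_dir="vector_db/", tracker_file="tracker.json")

rag.create_vector_db(extension=".txt")      # index the workspace into Chroma
chain = rag.setup_rag(model_name="llama3")  # retriever + locally running Ollama model
print(rag.query_rag(chain, "What does the document tracker store?"))
```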
document_tracker.json | 97 lines (new file)
@@ -0,0 +1,97 @@
{
  "/tmp/tmpje6soo0x/documents/python_info.txt": {
    "chunk_ids": [
      "c0b04224-f72e-4f6e-bef1-067f1cf5c716"
    ]
  },
  "/tmp/tmpy2siaefv/documents/python_info.txt": {
    "chunk_ids": [
      "27dc4179-c908-47dd-8e56-2e2b3c403faf"
    ]
  },
  "/tmp/tmp_fjaagdf/documents/python_info.txt": {
    "chunk_ids": [
      "3c147f7f-bd5a-456b-8bee-f9a43b6721c9"
    ]
  },
  "/tmp/tmpfbgc8zgm/documents/python_info.txt": {
    "chunk_ids": [
      "34eb0065-8a18-4144-890c-2e099d334b88"
    ]
  },
  "/tmp/tmp8u9_vv_c/documents/python_info.txt": {
    "chunk_ids": [
      "6cd7655c-8fa2-437c-b8fc-5690c928b89e"
    ]
  },
  "/tmp/tmpm_quzgav/documents/python_info.txt": {
    "chunk_ids": [
      "9e452f02-fe93-4b65-b883-f585df25dd4d"
    ]
  },
  "/tmp/tmpjrncrqpe/documents/doc1.txt": {
    "chunk_ids": [
      "81cea007-7a54-4d19-869d-3643fd8cf257"
    ]
  },
  "/tmp/tmpjrncrqpe/documents/doc2.txt": {
    "chunk_ids": [
      "eb1bd307-e065-49a9-8d73-007403d13bd0"
    ]
  },
  "/tmp/tmpjrncrqpe/documents/doc3.txt": {
    "chunk_ids": [
      "3adccb38-28df-496d-b5a1-fa4555ea1fb5"
    ]
  },
  "/tmp/tmptyee73q5/documents/doc1.txt": {
    "chunk_ids": [
      "319dc9c4-0a96-4c49-9306-c9c11e94b613"
    ]
  },
  "/tmp/tmptyee73q5/documents/doc2.txt": {
    "chunk_ids": [
      "d20230ce-0753-4303-a9a2-d14379bb3d91"
    ]
  },
  "/tmp/tmptyee73q5/documents/doc3.txt": {
    "chunk_ids": [
      "ceb9e3c7-c84f-49b4-8da4-ec96ac39f7e5"
    ]
  },
  "/tmp/tmpq7xsv0n6/documents/doc1.txt": {
    "chunk_ids": [
      "2a7d43bd-3f25-46b2-8f9c-b5108afa06df"
    ]
  },
  "/tmp/tmpq7xsv0n6/documents/doc2.txt": {
    "chunk_ids": [
      "6e38e121-14a6-4c07-9717-b7babb0fbb2f"
    ]
  },
  "/tmp/tmpq7xsv0n6/documents/doc3.txt": {
    "chunk_ids": [
      "cbc80fab-f4bc-47c4-a482-36afab9a649e"
    ]
  },
  "/tmp/tmpi80m0y6n/documents/doc1.txt": {
    "chunk_ids": [
      "b145e673-0a4c-4058-a6d1-e1237bd2d9af"
    ]
  },
  "/tmp/tmpi80m0y6n/documents/doc2.txt": {
    "chunk_ids": [
      "32082799-0089-4a3d-b4ad-ea3ff0084212"
    ]
  },
  "/tmp/tmpi80m0y6n/documents/doc3.txt": {
    "chunk_ids": [
      "0902fef4-ce95-4a73-9e0c-90c1581f80b2"
    ]
  },
  "/tmp/tmpi80m0y6n/documents/python_info.txt": {
    "chunk_ids": [
      "009da108-76ba-4ef7-9c6b-f862f577ce23"
    ]
  }
}
poetry.lock | 4834 lines (generated, new file)
(File diff suppressed because it is too large)
pyproject.toml | 19 lines (new file)
@@ -0,0 +1,19 @@
[project]
name = "code-rag"
version = "0.1.0"
description = "Simple RAG implementation for use with neovim"
authors = [{ name = "Alex Selimov", email = "alex@alexselimov.com" }]
readme = "README.md"
requires-python = ">=3.9,<4.0"
dependencies = ["langchain (>=0.3.21,<0.4.0)", "ollama (>=0.4.7,<0.5.0)", "langchain-community (>=0.3.20,<0.4.0)", "langchain-ollama (>=0.2.3,<0.3.0)", "chromadb (>=0.6.3,<0.7.0)", "unstructured (>=0.17.2,<0.18.0)"]

[tool.poetry]
packages = [{ include = "code_rag", from = "src" }]

[tool.poetry.group.dev.dependencies]
pytest = "^8.3.5"

[build-system]
requires = ["poetry-core>=2.0.0,<3.0.0"]
build-backend = "poetry.core.masonry.api"
src/code_rag/__init__.py | 0 lines (new file, empty)
src/code_rag/__pycache__/__init__.cpython-313.pyc | BIN (new file, binary not shown)
src/code_rag/__pycache__/doc_tracker.cpython-313.pyc | BIN (new file, binary not shown)
src/code_rag/__pycache__/rag.cpython-313.pyc | BIN (new file, binary not shown)
src/code_rag/doc_tracker.py | 174 lines (new file)
@@ -0,0 +1,174 @@
import os
import hashlib
import json
from datetime import datetime
import time


class DocMetaData:
    """
    Class that stores document metadata. Using this so we know that the metadata is always
    initialized and can avoid a lot of checking whether keys exist or not.
    """

    def __init__(self, mod_time, hash, last_updated, chunk_ids):
        self.mod_time = mod_time
        self.hash = hash
        self.last_updated = last_updated
        self.chunk_ids = chunk_ids

    def __eq__(self, other):
        if isinstance(other, DocMetaData):
            return (
                self.mod_time == other.mod_time
                and self.hash == other.hash
                and self.last_updated == other.last_updated
                and self.chunk_ids == other.chunk_ids
            )
        return NotImplemented

    def update_chunks(self, chunks):
        self.chunk_ids = chunks

    def to_dict(self):
        return {
            "mod_time": self.mod_time,
            "hash": self.hash,
            "last_updated": self.last_updated,
            "chunk_ids": self.chunk_ids,
        }

    @classmethod
    def from_file(cls, file_path):
        return cls(
            os.path.getmtime(file_path), calculate_file_hash(file_path), time.time(), []
        )

    @classmethod
    def from_dict(cls, input_dict):
        return cls(
            input_dict["mod_time"],
            input_dict["hash"],
            input_dict["last_updated"],
            input_dict["chunk_ids"],
        )


class DocumentTracker:
    """
    Tracks document changes using file hashes, modification times, and chunk IDs
    """

    def __init__(self, tracking_file):
        self.tracking_file = tracking_file
        self.doc_info = self._load_tracking_data()

    def _load_tracking_data(self):
        """Load existing tracking data if available"""
        doc_info = dict()
        if os.path.exists(self.tracking_file):
            with open(self.tracking_file, "r") as f:
                serialized = json.load(f)
                for k, v in serialized.items():
                    doc_info[k] = DocMetaData.from_dict(v)

        return doc_info

    def _save_tracking_data(self):
        """Save tracking data to file"""
        output = dict()
        for k, v in self.doc_info.items():
            output[k] = v.to_dict()
        with open(self.tracking_file, "w") as f:
            json.dump(output, f, indent=2)

    def get_changed_files(self, directory, file_extension=".txt"):
        """
        Detect new, modified, and deleted files
        Returns: dict with 'new', 'modified', and 'deleted' lists
        """
        current_file_mod_times = {}
        for root, _, files in os.walk(directory):
            for file in files:
                if file.endswith(file_extension):
                    file_path = os.path.join(root, file)
                    mod_time = os.path.getmtime(file_path)
                    current_file_mod_times[file_path] = mod_time

        new_files = []
        modified_files = []

        # Check for new or modified files
        for file_path, mod_time in current_file_mod_times.items():
            if file_path not in self.doc_info:
                new_files.append(file_path)
            elif mod_time > self.doc_info[file_path].mod_time:
                # Check if content actually changed using hash
                current_hash = calculate_file_hash(file_path)
                if current_hash != self.doc_info[file_path].hash:
                    modified_files.append(file_path)
                    self.doc_info[file_path].hash = current_hash
                    self.doc_info[file_path].mod_time = mod_time

        # Check for deleted files
        deleted_files = [f for f in self.doc_info if f not in current_file_mod_times]

        # Update tracking information for new files
        for file_path in new_files:
            self.doc_info[file_path] = DocMetaData(
                current_file_mod_times[file_path],
                calculate_file_hash(file_path),
                time.time(),
                [],  # Will store IDs of chunks derived from this document
            )

        # Update last_updated for modified files
        for file_path in modified_files:
            self.doc_info[file_path].last_updated = datetime.now().isoformat()

        self._save_tracking_data()

        return {"new": new_files, "modified": modified_files, "deleted": deleted_files}

    def update_chunk_mappings(self, file_path, chunk_ids):
        """Store the chunk IDs associated with a document"""
        if file_path not in self.doc_info:
            self.doc_info[file_path] = DocMetaData.from_file(file_path)

        self.doc_info[file_path].chunk_ids = chunk_ids
        self._save_tracking_data()

    def get_chunks_to_delete(self, deleted_files):
        """Get all chunk IDs associated with deleted files"""
        chunks_to_delete = []
        for file_path in deleted_files:
            if file_path in self.doc_info and self.doc_info[file_path].chunk_ids:
                chunks_to_delete.extend(self.doc_info[file_path].chunk_ids)
                # Remove the file from tracking after processing
                del self.doc_info[file_path]

        self._save_tracking_data()
        return chunks_to_delete

    def get_chunks_for_modified_files(self, modified_files):
        """Get chunk IDs for modified files that need to be deleted before re-indexing"""
        chunks_to_delete = []
        for file_path in modified_files:
            if file_path in self.doc_info and self.doc_info[file_path].chunk_ids:
                chunks_to_delete.extend(self.doc_info[file_path].chunk_ids)
                # Clear the chunk IDs (will be updated with new ones)
                self.doc_info[file_path].chunk_ids = []

        self._save_tracking_data()
        return chunks_to_delete


def calculate_file_hash(file_path):
    """Calculate MD5 hash of file contents"""
    hasher = hashlib.md5()
    with open(file_path, "rb") as f:
        buf = f.read(65536)
        while len(buf) > 0:
            hasher.update(buf)
            buf = f.read(65536)
    return hasher.hexdigest()
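As a reference for how the tracker above is meant to be driven, a minimal sketch (the paths and chunk IDs below are placeholders):

```python
from code_rag.doc_tracker import DocumentTracker

tracker = DocumentTracker("tracker.json")           # loads prior state if the file exists
changes = tracker.get_changed_files("documents/")   # {'new': [...], 'modified': [...], 'deleted': [...]}

# Chunk IDs for deleted and modified files are handed back so the caller can
# purge them from the vector store before re-indexing.
stale_chunks = tracker.get_chunks_to_delete(changes["deleted"])
stale_chunks += tracker.get_chunks_for_modified_files(changes["modified"])

# After re-chunking a file, record its new chunk IDs against it.
tracker.update_chunk_mappings("documents/example.txt", ["chunk-id-1", "chunk-id-2"])
```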
src/code_rag/rag.py | 228 lines (new file)
@@ -0,0 +1,228 @@
import os
import uuid
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import Ollama
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

from code_rag.doc_tracker import DocumentTracker


class RAG:
    def __init__(self, docs_dir, db_dir, tracker_file):
        self.docs_dir = docs_dir
        self.db_dir = db_dir
        self.tracker = DocumentTracker(tracker_file)

    def process_documents(self, files, text_splitter):
        """Process document files into chunks with tracking metadata"""
        all_chunks = []
        file_chunk_map = {}

        for file_path in files:
            # Load the document
            loader = TextLoader(file_path)
            documents = loader.load()

            # Add source metadata
            for doc in documents:
                doc.metadata["source"] = file_path
                doc.metadata["source_id"] = file_path  # For easier identification

            # Split the document
            chunks = text_splitter.split_documents(documents)

            # Generate and track chunk IDs
            chunk_ids = []
            for chunk in chunks:
                chunk_id = str(uuid.uuid4())
                chunk.metadata["chunk_id"] = chunk_id
                chunk_ids.append(chunk_id)

            # Store chunk mappings if a tracker is configured
            if self.tracker:
                self.tracker.update_chunk_mappings(file_path, chunk_ids)

            file_chunk_map[file_path] = chunks
            all_chunks.extend(chunks)

        return all_chunks, file_chunk_map

    def create_vector_db(self, extension=".txt", force_refresh=False):
        """
        Create or update a vector database, with complete handling of changes
        """
        # Text splitter configuration
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200
        )

        # Create embeddings
        embeddings = OllamaEmbeddings(model="nomic-embed-text")

        if force_refresh:
            print("Force refresh: Processing all documents")
            # Load all documents
            loader = DirectoryLoader(self.docs_dir, glob=f"**/*{extension}")
            all_documents = loader.load()

            # Add unique IDs to each document
            for doc in all_documents:
                doc.metadata["source"] = os.path.abspath(doc.metadata["source"])
                doc.metadata["source_id"] = doc.metadata["source"]

            # Split documents
            chunks = text_splitter.split_documents(all_documents)

            # Add chunk IDs and update tracker
            file_chunk_map = {}
            for chunk in chunks:
                chunk_id = str(uuid.uuid4())
                chunk.metadata["chunk_id"] = chunk_id

                source = chunk.metadata["source"]
                if source not in file_chunk_map:
                    file_chunk_map[source] = []

                file_chunk_map[source].append(chunk_id)

            # Update tracker with chunk mappings
            for file_path, chunk_ids in file_chunk_map.items():
                self.tracker.update_chunk_mappings(file_path, chunk_ids)

            print(
                f"Processing {len(all_documents)} documents with {len(chunks)} chunks"
            )

            # Create new vector store
            vectorstore = Chroma.from_documents(
                documents=chunks, embedding=embeddings, persist_directory=self.db_dir
            )
            vectorstore.persist()
            print(f"Created new vector database at {self.db_dir}")

            return vectorstore

        # Get changes since last update
        changed_files = self.tracker.get_changed_files(self.docs_dir)

        if not any(changed_files.values()):
            print("No document changes detected")
            # Load existing vector store if available
            if os.path.exists(self.db_dir):
                return Chroma(
                    persist_directory=self.db_dir, embedding_function=embeddings
                )
            else:
                print("No vector database exists. Creating from all documents...")
                return self.create_vector_db(force_refresh=True)

        # Process changes
        print(
            f"Changes detected - New: {len(changed_files['new'])}, Modified: {len(changed_files['modified'])}, Deleted: {len(changed_files['deleted'])}"
        )

        # Load existing vector store if it exists
        if os.path.exists(self.db_dir):
            vectorstore = Chroma(
                persist_directory=self.db_dir, embedding_function=embeddings
            )

            # 1. Handle deleted documents
            if changed_files["deleted"]:
                chunks_to_delete = self.tracker.get_chunks_to_delete(
                    changed_files["deleted"]
                )
                if chunks_to_delete:
                    print(
                        f"Removing {len(chunks_to_delete)} chunks from deleted documents"
                    )
                    # Delete the chunks from vector store
                    vectorstore._collection.delete(
                        where={"chunk_id": {"$in": chunks_to_delete}}
                    )

            # 2. Handle modified documents (delete old chunks first)
            chunks_to_delete_modified = self.tracker.get_chunks_for_modified_files(
                changed_files["modified"]
            )
            if chunks_to_delete_modified:
                print(
                    f"Removing {len(chunks_to_delete_modified)} chunks from modified documents"
                )
                vectorstore._collection.delete(
                    where={"chunk_id": {"$in": chunks_to_delete_modified}}
                )

            # 3. Process new and modified documents
            files_to_process = changed_files["new"] + changed_files["modified"]
            if files_to_process:
                # process_documents reads the tracker from self, so it is not passed in
                chunks, _ = self.process_documents(files_to_process, text_splitter)
                print(f"Adding {len(chunks)} new chunks to the vector store")
                vectorstore.add_documents(chunks)
        else:
            # If no existing DB, create from all documents
            print("No existing vector database. Creating from all documents...")
            return self.create_vector_db(force_refresh=True)

        # Persist changes
        vectorstore.persist()
        print(f"Vector database updated at {self.db_dir}")

        return vectorstore

    def setup_rag(self, model_name="llama3"):
        """
        Set up the RAG system with an existing vector database
        """
        # Load the embeddings
        embeddings = OllamaEmbeddings(model="nomic-embed-text")

        # Load the vector store
        vectorstore = Chroma(
            persist_directory=self.db_dir, embedding_function=embeddings
        )

        # Create a retriever
        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

        # Set up the LLM
        llm = Ollama(model=model_name)

        # Create a custom prompt template
        template = """
        Answer the question based on the context provided. If you don't know the answer,
        just say you don't know. Don't try to make up an answer.

        Context: {context}

        Question: {question}

        Answer:
        """

        prompt = PromptTemplate(
            input_variables=["context", "question"], template=template
        )

        # Create the RAG chain
        rag_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt},
        )

        return rag_chain

    def query_rag(self, rag_chain, query):
        """
        Query the RAG system
        """
        response = rag_chain.invoke({"query": query})
        return response["result"]
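The incremental path of `create_vector_db` (without `force_refresh`) is the part that leans on `DocumentTracker`: only chunks from deleted or modified files are removed from Chroma, and only new or modified files are re-chunked and re-embedded. A minimal sketch of that update cycle, again with placeholder paths:

```python
from code_rag.rag import RAG

rag = RAG("documents/", "vector_db/", "tracker.json")

# First run: embed everything and record chunk IDs per source file.
rag.create_vector_db(force_refresh=True)

# ... documents/ changes on disk (files added, edited, or removed) ...

# Later runs: stale chunks are deleted from the Chroma collection and only
# the changed files are re-chunked and re-embedded.
rag.create_vector_db()
```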
tests/__init__.py | 0 lines (new file, empty)
tests/__pycache__/__init__.cpython-313.pyc | BIN (new file, binary not shown)
tests/__pycache__/fixtures.cpython-313.pyc | BIN (new file, binary not shown)
tests/__pycache__/test_doc_tracker.cpython-313-pytest-8.3.5.pyc | BIN (new file, binary not shown)
tests/__pycache__/test_rag.cpython-313-pytest-8.3.5.pyc | BIN (new file, binary not shown)
tests/__pycache__/utility.cpython-313.pyc | BIN (new file, binary not shown)
tests/fixtures.py | 66 lines (new file)
@@ -0,0 +1,66 @@
import pytest
import shutil
import tempfile
import os


@pytest.fixture
def temp_dir():
    """Create a temporary directory for test files"""
    temp_dir = tempfile.mkdtemp()
    yield temp_dir
    shutil.rmtree(temp_dir)


@pytest.fixture
def docs_dir(temp_dir):
    """Create a temporary documents directory"""
    docs_dir = os.path.join(temp_dir, "documents")
    os.makedirs(docs_dir)
    yield docs_dir


@pytest.fixture
def db_dir(temp_dir):
    """Create a temporary vector database directory"""
    db_dir = os.path.join(temp_dir, "vector_db")
    os.makedirs(db_dir)
    yield db_dir


@pytest.fixture
def tracker_file(temp_dir):
    """Create a temporary tracker file"""
    tracker_path = os.path.join(temp_dir, "test_tracker.json")
    yield tracker_path
    # Clean up after tests
    if os.path.exists(tracker_path):
        os.remove(tracker_path)


@pytest.fixture
def sample_docs(docs_dir):
    """Create sample text documents for testing"""
    # Create a few sample documents
    doc1_path = os.path.join(docs_dir, "doc1.txt")
    doc2_path = os.path.join(docs_dir, "doc2.txt")
    doc3_path = os.path.join(docs_dir, "doc3.txt")

    with open(doc1_path, "w") as f:
        f.write("This is a sample document about artificial intelligence. " * 10)

    with open(doc2_path, "w") as f:
        f.write("This document discusses machine learning concepts. " * 10)

    with open(doc3_path, "w") as f:
        f.write("Natural language processing is a field of AI. " * 10)

    return [doc1_path, doc2_path, doc3_path]


@pytest.fixture
def text_splitter():
    """Create a mock text splitter"""
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    return RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
tests/test_doc_tracker.py | 157 lines (new file)
@@ -0,0 +1,157 @@
import os
import time
import pytest

from .fixtures import *
from code_rag.doc_tracker import DocMetaData, DocumentTracker, calculate_file_hash


def doc_infos_are_equal(left, right):
    """Check to see if two doc_infos are the same"""
    for k, v in left.items():
        try:
            print(v.to_dict(), right[k].to_dict(), v.to_dict() == right[k].to_dict())
            if v != right[k]:
                return False
        except KeyError:
            return False
    return True


@pytest.fixture
def document_tracker(tracker_file):
    """Create a DocumentTracker instance"""
    return DocumentTracker(tracking_file=tracker_file)


# Tests for DocumentTracker
def test_init_new_tracker(tracker_file):
    """Test creating a new tracker"""
    tracker = DocumentTracker(tracking_file=tracker_file)
    assert tracker.doc_info == {}
    assert not os.path.exists(tracker_file)


def test_save_and_load_tracking_data(document_tracker, tracker_file):
    """Test saving and loading tracking data"""
    # Add some data
    update_time = time.time()
    document_tracker.doc_info = {
        "test.txt": DocMetaData(123456, "abcdef", update_time, ["1", "2"])
    }
    document_tracker._save_tracking_data()

    # Check file exists
    assert os.path.exists(tracker_file)

    # Create a new tracker that should load the data
    new_tracker = DocumentTracker(tracking_file=tracker_file)
    assert doc_infos_are_equal(
        new_tracker.doc_info,
        {"test.txt": DocMetaData(123456, "abcdef", update_time, ["1", "2"])},
    )


def test_calculate_file_hash(document_tracker, sample_docs):
    """Test hash calculation for a file"""
    file_path = sample_docs[0]
    hash1 = calculate_file_hash(file_path)

    # Same content should yield same hash
    hash2 = calculate_file_hash(file_path)
    assert hash1 == hash2

    # Different content should yield different hash
    with open(file_path, "a") as f:
        f.write("Additional content")

    hash3 = calculate_file_hash(file_path)
    assert hash1 != hash3


def test_get_changed_files_new(document_tracker, docs_dir, sample_docs):
    """Test detecting new files"""
    changes = document_tracker.get_changed_files(docs_dir)
    assert set(changes["new"]) == set(sample_docs)
    assert changes["modified"] == []
    assert changes["deleted"] == []

    # Verify tracking was updated
    for file_path in sample_docs:
        assert file_path in document_tracker.doc_info
        assert document_tracker.doc_info[file_path].chunk_ids == []


def test_get_changed_files_modified(document_tracker, docs_dir, sample_docs):
    """Test detecting modified files"""
    # First scan to establish tracking
    document_tracker.get_changed_files(docs_dir)

    # Modify a file and wait to ensure timestamp difference
    time.sleep(0.1)
    with open(sample_docs[0], "a") as f:
        f.write("Modified content")

    # Detect changes
    changes = document_tracker.get_changed_files(docs_dir)
    assert changes["new"] == []
    assert changes["modified"] == [sample_docs[0]]
    assert changes["deleted"] == []


def test_get_changed_files_deleted(document_tracker, docs_dir, sample_docs):
    """Test detecting deleted files"""
    # First scan to establish tracking
    document_tracker.get_changed_files(docs_dir)

    # Delete a file
    os.remove(sample_docs[0])

    # Detect changes
    changes = document_tracker.get_changed_files(docs_dir)
    assert changes["new"] == []
    assert changes["modified"] == []
    assert changes["deleted"] == [sample_docs[0]]


def test_update_chunk_mappings(document_tracker, sample_docs):
    """Test updating chunk mappings"""
    file_path = sample_docs[0]
    chunk_ids = ["chunk1", "chunk2", "chunk3"]

    # First make sure the file is tracked
    document_tracker.doc_info[file_path] = DocMetaData(
        123,
        "abc",
        "2023-01-01",
        [],
    )

    # Update chunk mappings
    document_tracker.update_chunk_mappings(file_path, chunk_ids)
    assert document_tracker.doc_info[file_path].chunk_ids == chunk_ids


def test_get_chunks_to_delete(document_tracker):
    """Test getting chunks to delete for deleted files"""
    # Setup tracking data
    document_tracker.doc_info = {
        "file1.txt": DocMetaData(0, "abc", 0, ["chunk1", "chunk2"]),
        "file2.txt": DocMetaData(0, "abc", 0, ["chunk3", "chunk4"]),
        "file3.txt": DocMetaData(0, "abc", 0, ["chunk5"]),
    }

    # Test with one deleted file
    chunks = document_tracker.get_chunks_to_delete(["file1.txt"])
    assert set(chunks) == {"chunk1", "chunk2"}

    # Verify file was removed from tracking
    assert "file1.txt" not in document_tracker.doc_info

    # Test with multiple deleted files
    chunks = document_tracker.get_chunks_to_delete(["file2.txt", "file3.txt"])
    assert set(chunks) == {"chunk3", "chunk4", "chunk5"}

    # Verify files were removed from tracking
    assert "file2.txt" not in document_tracker.doc_info
    assert "file3.txt" not in document_tracker.doc_info
tests/test_rag.py | 124 lines (new file)
@@ -0,0 +1,124 @@
import os
import shutil
import pytest

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from .fixtures import *
from .utility import *
from code_rag.rag import RAG
from code_rag.doc_tracker import DocumentTracker


@pytest.fixture
def rag_pipeline(docs_dir, db_dir, tracker_file):
    """Create a RAG instance"""
    return RAG(docs_dir, db_dir, tracker_file)


# Tests for document processing
def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipeline):
    """Test processing documents into chunks with tracking"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    files = [
        os.path.join(rag_pipeline.docs_dir, "doc1.txt"),
        os.path.join(rag_pipeline.docs_dir, "doc2.txt"),
    ]

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    chunks, file_chunk_map = rag_pipeline.process_documents(files, text_splitter)

    # Verify chunks were created
    assert len(chunks) >= 2  # At least one chunk per document
    tracker = rag_pipeline.tracker
    # Verify chunk IDs were tracked
    for file_path in files:
        assert file_path in tracker.doc_info
        # doc_info values are DocMetaData objects, so chunk_ids is an attribute
        assert len(tracker.doc_info[file_path].chunk_ids) > 0

    # Verify metadata in chunks
    for chunk in chunks:
        assert "source" in chunk.metadata
        assert "chunk_id" in chunk.metadata
        assert chunk.metadata["source"] in files


@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
    """Test creating a vector database"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    # Create initial vector database
    vectorstore = rag_pipeline.create_vector_db(force_refresh=True)

    # Verify it was created
    assert os.path.exists(rag_pipeline.db_dir)
    assert vectorstore is not None
    # Check the database has content
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    loaded_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=embeddings
    )
    assert loaded_db._collection.count() > 0


@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_docs):
    """Test updating a vector database with document changes"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    # Create initial vector database
    rag_pipeline.create_vector_db(force_refresh=True)

    # Get initial count
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    initial_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=embeddings
    )
    initial_count = initial_db._collection.count()

    # Make changes to documents: add a new document
    create_test_document(
        docs_dir, "newdoc.txt", "This is a brand new document for testing."
    )

    # Update the vector database
    rag_pipeline.create_vector_db()

    # Check the database has been updated
    updated_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=embeddings
    )
    assert updated_db._collection.count() > initial_count


# Final integration test - full RAG pipeline
@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
    """Test the entire RAG pipeline from document processing to querying"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    # Create a specific document with known content
    test_content = "Python is a high-level programming language known for its readability and versatility."
    create_test_document(rag_pipeline.docs_dir, "python_info.txt", test_content)

    # Create vector database
    rag_pipeline.create_vector_db(force_refresh=True)

    # Set up RAG
    rag_chain = rag_pipeline.setup_rag(model_name="llama3.2")

    # Query the system
    query = "What is Python?"
    response = rag_pipeline.query_rag(rag_chain, query)

    # Check if response contains relevant information
    # This is a soft test since the exact response will depend on the LLM
    assert response.strip() != ""
    assert "programming" in response.lower() or "language" in response.lower()
tests/utility.py | 26 lines (new file)
@@ -0,0 +1,26 @@
import time
import os


def create_test_document(base_path, filename, content):
    """Create a test document with specified content"""
    file_path = os.path.join(base_path, filename)
    # Create directory if it doesn't exist (for subdirectories)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "w") as f:
        f.write(content)
    return file_path


def modify_test_document(file_path, new_content):
    """Modify an existing test document"""
    with open(file_path, "w") as f:
        f.write(new_content)
    # Ensure modification time changes
    time.sleep(0.1)


def delete_test_document(file_path):
    """Delete a test document"""
    if os.path.exists(file_path):
        os.remove(file_path)