Initial commit
commit e0d0962fc5

6 README.md Normal file
@@ -0,0 +1,6 @@
# code_rag

This is intended to be a fully local Retrieval-Augmented Generation (RAG) tool for software development.
This will be loadable as a Neovim plugin, where we retrieve the current workspace files and add them to a persistent Chroma vector store.
This tool will expose an Ollama API for querying locally running Ollama models with prompts augmented by relevant context from the current code.
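A minimal sketch of the intended workflow, based on the RAG class added in src/code_rag/rag.py later in this commit. The directory paths are placeholders, only *.txt files are picked up at this stage (the default extension), and it assumes a local Ollama server with the "nomic-embed-text" and "llama3" models pulled:

from code_rag.rag import RAG

# Placeholder paths: any existing docs directory and writable locations for
# the Chroma store and the JSON tracker file will do.
rag = RAG(docs_dir="docs", db_dir="vector_db", tracker_file="document_tracker.json")

# Index the documents. force_refresh=True rebuilds the store from scratch;
# later calls without it only process new, modified, and deleted files.
rag.create_vector_db(force_refresh=True)

# Build the retrieval chain against the persisted store and ask a question.
chain = rag.setup_rag(model_name="llama3")
print(rag.query_rag(chain, "What is this project about?"))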
97 document_tracker.json Normal file
@@ -0,0 +1,97 @@
{
  "/tmp/tmpje6soo0x/documents/python_info.txt": {
    "chunk_ids": [
      "c0b04224-f72e-4f6e-bef1-067f1cf5c716"
    ]
  },
  "/tmp/tmpy2siaefv/documents/python_info.txt": {
    "chunk_ids": [
      "27dc4179-c908-47dd-8e56-2e2b3c403faf"
    ]
  },
  "/tmp/tmp_fjaagdf/documents/python_info.txt": {
    "chunk_ids": [
      "3c147f7f-bd5a-456b-8bee-f9a43b6721c9"
    ]
  },
  "/tmp/tmpfbgc8zgm/documents/python_info.txt": {
    "chunk_ids": [
      "34eb0065-8a18-4144-890c-2e099d334b88"
    ]
  },
  "/tmp/tmp8u9_vv_c/documents/python_info.txt": {
    "chunk_ids": [
      "6cd7655c-8fa2-437c-b8fc-5690c928b89e"
    ]
  },
  "/tmp/tmpm_quzgav/documents/python_info.txt": {
    "chunk_ids": [
      "9e452f02-fe93-4b65-b883-f585df25dd4d"
    ]
  },
  "/tmp/tmpjrncrqpe/documents/doc1.txt": {
    "chunk_ids": [
      "81cea007-7a54-4d19-869d-3643fd8cf257"
    ]
  },
  "/tmp/tmpjrncrqpe/documents/doc2.txt": {
    "chunk_ids": [
      "eb1bd307-e065-49a9-8d73-007403d13bd0"
    ]
  },
  "/tmp/tmpjrncrqpe/documents/doc3.txt": {
    "chunk_ids": [
      "3adccb38-28df-496d-b5a1-fa4555ea1fb5"
    ]
  },
  "/tmp/tmptyee73q5/documents/doc1.txt": {
    "chunk_ids": [
      "319dc9c4-0a96-4c49-9306-c9c11e94b613"
    ]
  },
  "/tmp/tmptyee73q5/documents/doc2.txt": {
    "chunk_ids": [
      "d20230ce-0753-4303-a9a2-d14379bb3d91"
    ]
  },
  "/tmp/tmptyee73q5/documents/doc3.txt": {
    "chunk_ids": [
      "ceb9e3c7-c84f-49b4-8da4-ec96ac39f7e5"
    ]
  },
  "/tmp/tmpq7xsv0n6/documents/doc1.txt": {
    "chunk_ids": [
      "2a7d43bd-3f25-46b2-8f9c-b5108afa06df"
    ]
  },
  "/tmp/tmpq7xsv0n6/documents/doc2.txt": {
    "chunk_ids": [
      "6e38e121-14a6-4c07-9717-b7babb0fbb2f"
    ]
  },
  "/tmp/tmpq7xsv0n6/documents/doc3.txt": {
    "chunk_ids": [
      "cbc80fab-f4bc-47c4-a482-36afab9a649e"
    ]
  },
  "/tmp/tmpi80m0y6n/documents/doc1.txt": {
    "chunk_ids": [
      "b145e673-0a4c-4058-a6d1-e1237bd2d9af"
    ]
  },
  "/tmp/tmpi80m0y6n/documents/doc2.txt": {
    "chunk_ids": [
      "32082799-0089-4a3d-b4ad-ea3ff0084212"
    ]
  },
  "/tmp/tmpi80m0y6n/documents/doc3.txt": {
    "chunk_ids": [
      "0902fef4-ce95-4a73-9e0c-90c1581f80b2"
    ]
  },
  "/tmp/tmpi80m0y6n/documents/python_info.txt": {
    "chunk_ids": [
      "009da108-76ba-4ef7-9c6b-f862f577ce23"
    ]
  }
}
4834 poetry.lock (generated) Normal file
File diff suppressed because it is too large
19 pyproject.toml Normal file
@@ -0,0 +1,19 @@
[project]
name = "code-rag"
version = "0.1.0"
description = "Simple RAG implementation for use with neovim"
authors = [{ name = "Alex Selimov", email = "alex@alexselimov.com" }]
readme = "README.md"
requires-python = ">=3.9,<4.0"
dependencies = ["langchain (>=0.3.21,<0.4.0)", "ollama (>=0.4.7,<0.5.0)", "langchain-community (>=0.3.20,<0.4.0)", "langchain-ollama (>=0.2.3,<0.3.0)", "chromadb (>=0.6.3,<0.7.0)", "unstructured (>=0.17.2,<0.18.0)"]

[tool.poetry]
packages = [{ include = "code_rag", from = "src" }]


[tool.poetry.group.dev.dependencies]
pytest = "^8.3.5"

[build-system]
requires = ["poetry-core>=2.0.0,<3.0.0"]
build-backend = "poetry.core.masonry.api"
0 src/code_rag/__init__.py Normal file
BIN src/code_rag/__pycache__/__init__.cpython-313.pyc Normal file (binary file not shown)
BIN src/code_rag/__pycache__/doc_tracker.cpython-313.pyc Normal file (binary file not shown)
BIN src/code_rag/__pycache__/rag.cpython-313.pyc Normal file (binary file not shown)
174 src/code_rag/doc_tracker.py Normal file
@@ -0,0 +1,174 @@
import os
import hashlib
import json
from datetime import datetime
import time


class DocMetaData:
    """
    Class that stores document meta data. Using this so we know that the MetaData is always
    initialized and can avoid a lot of checking whether keys exist or not
    """

    def __init__(self, mod_time, hash, last_updated, chunk_ids):
        self.mod_time = mod_time
        self.hash = hash
        self.last_updated = last_updated
        self.chunk_ids = chunk_ids

    def __eq__(self, other):
        if isinstance(other, DocMetaData):
            return (
                self.mod_time == other.mod_time
                and self.hash == other.hash
                and self.last_updated == other.last_updated
                and self.chunk_ids == other.chunk_ids
            )
        return NotImplemented

    def update_chunks(self, chunks):
        self.chunk_ids = chunks

    def to_dict(self):
        return {
            "mod_time": self.mod_time,
            "hash": self.hash,
            "last_updated": self.last_updated,
            "chunk_ids": self.chunk_ids,
        }

    @classmethod
    def from_file(cls, file_path):
        return cls(
            os.path.getmtime(file_path), calculate_file_hash(file_path), time.time(), []
        )

    @classmethod
    def from_dict(cls, input_dict):
        return cls(
            input_dict["mod_time"],
            input_dict["hash"],
            input_dict["last_updated"],
            input_dict["chunk_ids"],
        )


class DocumentTracker:
    """
    Tracks document changes using file hashes, modification times, and chunk IDs
    """

    def __init__(self, tracking_file):
        self.tracking_file = tracking_file
        self.doc_info = self._load_tracking_data()

    def _load_tracking_data(self):
        """Load existing tracking data if available"""
        doc_info = dict()
        if os.path.exists(self.tracking_file):
            with open(self.tracking_file, "r") as f:
                serialized = json.load(f)
                for k, v in serialized.items():
                    doc_info[k] = DocMetaData.from_dict(v)

        return doc_info

    def _save_tracking_data(self):
        """Save tracking data to file"""
        output = dict()
        for k, v in self.doc_info.items():
            output[k] = v.to_dict()
        with open(self.tracking_file, "w") as f:
            json.dump(output, f, indent=2)

    def get_changed_files(self, directory, file_extension=".txt"):
        """
        Detect new, modified, and deleted files
        Returns: dict with 'new', 'modified', and 'deleted' lists
        """
        current_file_mod_times = {}
        for root, _, files in os.walk(directory):
            for file in files:
                if file.endswith(file_extension):
                    file_path = os.path.join(root, file)
                    mod_time = os.path.getmtime(file_path)
                    current_file_mod_times[file_path] = mod_time

        new_files = []
        modified_files = []

        # Check for new or modified files
        for file_path, mod_time in current_file_mod_times.items():
            if file_path not in self.doc_info:
                new_files.append(file_path)
            elif mod_time > self.doc_info[file_path].mod_time:
                # Check if content actually changed using hash
                current_hash = calculate_file_hash(file_path)
                if current_hash != self.doc_info[file_path].hash:
                    modified_files.append(file_path)
                    self.doc_info[file_path].hash = current_hash
                    self.doc_info[file_path].mod_time = mod_time

        # Check for deleted files
        deleted_files = [f for f in self.doc_info if f not in current_file_mod_times]

        # Update tracking information for new files
        for file_path in new_files:
            self.doc_info[file_path] = DocMetaData(
                current_file_mod_times[file_path],
                calculate_file_hash(file_path),
                time.time(),
                [],  # Will store IDs of chunks derived from this document
            )

        # Update last_updated for modified files
        for file_path in modified_files:
            self.doc_info[file_path].last_updated = datetime.now().isoformat()

        self._save_tracking_data()

        return {"new": new_files, "modified": modified_files, "deleted": deleted_files}

    def update_chunk_mappings(self, file_path, chunk_ids):
        """Store the chunk IDs associated with a document"""
        if file_path not in self.doc_info:
            self.doc_info[file_path] = DocMetaData.from_file(file_path)

        self.doc_info[file_path].chunk_ids = chunk_ids
        self._save_tracking_data()

    def get_chunks_to_delete(self, deleted_files):
        """Get all chunk IDs associated with deleted files"""
        chunks_to_delete = []
        for file_path in deleted_files:
            if file_path in self.doc_info and self.doc_info[file_path].chunk_ids:
                chunks_to_delete.extend(self.doc_info[file_path].chunk_ids)
                # Remove the file from tracking after processing
                del self.doc_info[file_path]

        self._save_tracking_data()
        return chunks_to_delete

    def get_chunks_for_modified_files(self, modified_files):
        """Get chunk IDs for modified files that need to be deleted before re-indexing"""
        chunks_to_delete = []
        for file_path in modified_files:
            if file_path in self.doc_info and self.doc_info[file_path].chunk_ids:
                chunks_to_delete.extend(self.doc_info[file_path].chunk_ids)
                # Clear the chunk IDs (will be updated with new ones)
                self.doc_info[file_path].chunk_ids = []

        self._save_tracking_data()
        return chunks_to_delete


def calculate_file_hash(file_path):
    """Calculate MD5 hash of file contents"""
    hasher = hashlib.md5()
    with open(file_path, "rb") as f:
        buf = f.read(65536)
        while len(buf) > 0:
            hasher.update(buf)
            buf = f.read(65536)
    return hasher.hexdigest()
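For reference, a small sketch of how DocumentTracker might be driven on its own, using only the API defined above; the paths are placeholders and assume a directory that already contains .txt files. This is the same flow that RAG.create_vector_db() in the next file relies on:

from code_rag.doc_tracker import DocumentTracker

tracker = DocumentTracker(tracking_file="/tmp/example_tracker.json")

# First scan: every matching file is reported as "new" and enters the tracker.
changes = tracker.get_changed_files("/tmp/example_docs", file_extension=".txt")
print(changes["new"], changes["modified"], changes["deleted"])

# After chunking a tracked file elsewhere, record which chunk IDs came from it
# so they can be purged from the vector store when the file changes or is removed.
tracker.update_chunk_mappings("/tmp/example_docs/notes.txt", ["chunk-a", "chunk-b"])

# On a later scan, deleted files yield the chunk IDs that should be purged;
# modified files are handled the same way via get_chunks_for_modified_files().
stale_chunk_ids = tracker.get_chunks_to_delete(changes["deleted"])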
228 src/code_rag/rag.py Normal file
@@ -0,0 +1,228 @@
import os
import uuid
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import Ollama
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

from code_rag.doc_tracker import DocumentTracker


class RAG:
    def __init__(self, docs_dir, db_dir, tracker_file):
        self.docs_dir = docs_dir
        self.db_dir = db_dir
        self.tracker = DocumentTracker(tracker_file)

    def process_documents(self, files, text_splitter):
        """Process document files into chunks with tracking metadata"""
        all_chunks = []
        file_chunk_map = {}

        for file_path in files:
            # Load the document
            loader = TextLoader(file_path)
            documents = loader.load()

            # Add source metadata
            for doc in documents:
                doc.metadata["source"] = file_path
                doc.metadata["source_id"] = file_path  # For easier identification

            # Split the document
            chunks = text_splitter.split_documents(documents)

            # Generate and track chunk IDs
            chunk_ids = []
            for chunk in chunks:
                chunk_id = str(uuid.uuid4())
                chunk.metadata["chunk_id"] = chunk_id
                chunk_ids.append(chunk_id)

            # Store chunk mappings if tracker is provided
            if self.tracker:
                self.tracker.update_chunk_mappings(file_path, chunk_ids)

            file_chunk_map[file_path] = chunks
            all_chunks.extend(chunks)

        return all_chunks, file_chunk_map

    def create_vector_db(self, extension=".txt", force_refresh=False):
        """
        Create or update a vector database, with complete handling of changes
        """
        # Text splitter configuration
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200
        )

        # Create embeddings
        embeddings = OllamaEmbeddings(model="nomic-embed-text")

        if force_refresh:
            print("Force refresh: Processing all documents")
            # Load all documents
            loader = DirectoryLoader(self.docs_dir, glob=f"**/*{extension}")
            all_documents = loader.load()

            # Add unique IDs to each document
            for doc in all_documents:
                doc.metadata["source"] = os.path.abspath(doc.metadata["source"])
                doc.metadata["source_id"] = doc.metadata["source"]

            # Split documents
            chunks = text_splitter.split_documents(all_documents)

            # Add chunk IDs and update tracker
            file_chunk_map = {}
            for chunk in chunks:
                chunk_id = str(uuid.uuid4())
                chunk.metadata["chunk_id"] = chunk_id

                source = chunk.metadata["source"]
                if source not in file_chunk_map:
                    file_chunk_map[source] = []

                file_chunk_map[source].append(chunk_id)

            # Update tracker with chunk mappings
            for file_path, chunk_ids in file_chunk_map.items():
                self.tracker.update_chunk_mappings(file_path, chunk_ids)

            print(
                f"Processing {len(all_documents)} documents with {len(chunks)} chunks"
            )

            # Create new vector store
            vectorstore = Chroma.from_documents(
                documents=chunks, embedding=embeddings, persist_directory=self.db_dir
            )
            vectorstore.persist()
            print(f"Created new vector database at {self.db_dir}")

            return vectorstore

        # Get changes since last update
        changed_files = self.tracker.get_changed_files(self.docs_dir)

        if not any(changed_files.values()):
            print("No document changes detected")
            # Load existing vector store if available
            if os.path.exists(self.db_dir):
                return Chroma(
                    persist_directory=self.db_dir, embedding_function=embeddings
                )
            else:
                print("No vector database exists. Creating from all documents...")
                return self.create_vector_db(force_refresh=True)

        # Process changes
        print(
            f"Changes detected - New: {len(changed_files['new'])}, Modified: {len(changed_files['modified'])}, Deleted: {len(changed_files['deleted'])}"
        )

        # Load existing vector store if it exists
        if os.path.exists(self.db_dir):
            vectorstore = Chroma(
                persist_directory=self.db_dir, embedding_function=embeddings
            )

            # 1. Handle deleted documents
            if changed_files["deleted"]:
                chunks_to_delete = self.tracker.get_chunks_to_delete(
                    changed_files["deleted"]
                )
                if chunks_to_delete:
                    print(
                        f"Removing {len(chunks_to_delete)} chunks from deleted documents"
                    )
                    # Delete the chunks from vector store
                    vectorstore._collection.delete(
                        where={"chunk_id": {"$in": chunks_to_delete}}
                    )

            # 2. Handle modified documents (delete old chunks first)
            chunks_to_delete_modified = self.tracker.get_chunks_for_modified_files(
                changed_files["modified"]
            )
            if chunks_to_delete_modified:
                print(
                    f"Removing {len(chunks_to_delete_modified)} chunks from modified documents"
                )
                vectorstore._collection.delete(
                    where={"chunk_id": {"$in": chunks_to_delete_modified}}
                )

            # 3. Process new and modified documents
            files_to_process = changed_files["new"] + changed_files["modified"]
            if files_to_process:
                chunks, _ = self.process_documents(files_to_process, text_splitter)
                print(f"Adding {len(chunks)} new chunks to the vector store")
                vectorstore.add_documents(chunks)
        else:
            # If no existing DB, create from all documents
            print("No existing vector database. Creating from all documents...")
            return self.create_vector_db(force_refresh=True)

        # Persist changes
        vectorstore.persist()
        print(f"Vector database updated at {self.db_dir}")

        return vectorstore

    def setup_rag(self, model_name="llama3"):
        """
        Set up the RAG system with an existing vector database
        """
        # Load the embeddings
        embeddings = OllamaEmbeddings(model="nomic-embed-text")

        # Load the vector store
        vectorstore = Chroma(
            persist_directory=self.db_dir, embedding_function=embeddings
        )

        # Create a retriever
        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

        # Set up the LLM
        llm = Ollama(model=model_name)

        # Create a custom prompt template
        template = """
        Answer the question based on the context provided. If you don't know the answer,
        just say you don't know. Don't try to make up an answer.

        Context: {context}

        Question: {question}

        Answer:
        """

        prompt = PromptTemplate(
            input_variables=["context", "question"], template=template
        )

        # Create the RAG chain
        rag_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt},
        )

        return rag_chain

    def query_rag(self, rag_chain, query):
        """
        Query the RAG system
        """
        response = rag_chain.invoke({"query": query})
        return response["result"]
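The incremental path of create_vector_db() is the core of this file: stale chunks are deleted from Chroma by their chunk_id metadata and only changed files are re-embedded. A rough sketch of that update cycle, with placeholder paths and assuming a running Ollama server:

from code_rag.rag import RAG

rag = RAG(docs_dir="docs", db_dir="vector_db", tracker_file="document_tracker.json")

# Initial build: chunk every *.txt under docs/ and record chunk IDs per file.
rag.create_vector_db(force_refresh=True)

# ... add, edit, or delete files under docs/ ...

# Second call without force_refresh: DocumentTracker reports the changes,
# chunks belonging to deleted or modified files are removed via
# vectorstore._collection.delete(where={"chunk_id": {"$in": [...]}}),
# and only the new and modified files are re-chunked and re-embedded.
rag.create_vector_db()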
0 tests/__init__.py Normal file
BIN tests/__pycache__/__init__.cpython-313.pyc Normal file (binary file not shown)
BIN tests/__pycache__/fixtures.cpython-313.pyc Normal file (binary file not shown)
BIN tests/__pycache__/test_doc_tracker.cpython-313-pytest-8.3.5.pyc Normal file (binary file not shown)
BIN tests/__pycache__/test_rag.cpython-313-pytest-8.3.5.pyc Normal file (binary file not shown)
BIN tests/__pycache__/utility.cpython-313.pyc Normal file (binary file not shown)
66 tests/fixtures.py Normal file
@@ -0,0 +1,66 @@
import pytest
import shutil
import tempfile
import os


@pytest.fixture
def temp_dir():
    """Create a temporary directory for test files"""
    temp_dir = tempfile.mkdtemp()
    yield temp_dir
    shutil.rmtree(temp_dir)


@pytest.fixture
def docs_dir(temp_dir):
    """Create a temporary documents directory"""
    docs_dir = os.path.join(temp_dir, "documents")
    os.makedirs(docs_dir)
    yield docs_dir


@pytest.fixture
def db_dir(temp_dir):
    """Create a temporary vector database directory"""
    db_dir = os.path.join(temp_dir, "vector_db")
    os.makedirs(db_dir)
    yield db_dir


@pytest.fixture
def tracker_file(temp_dir):
    """Create a temporary tracker file"""
    tracker_path = os.path.join(temp_dir, "test_tracker.json")
    yield tracker_path
    # Clean up after tests
    if os.path.exists(tracker_path):
        os.remove(tracker_path)


@pytest.fixture
def sample_docs(docs_dir):
    """Create sample text documents for testing"""
    # Create a few sample documents
    doc1_path = os.path.join(docs_dir, "doc1.txt")
    doc2_path = os.path.join(docs_dir, "doc2.txt")
    doc3_path = os.path.join(docs_dir, "doc3.txt")

    with open(doc1_path, "w") as f:
        f.write("This is a sample document about artificial intelligence. " * 10)

    with open(doc2_path, "w") as f:
        f.write("This document discusses machine learning concepts. " * 10)

    with open(doc3_path, "w") as f:
        f.write("Natural language processing is a field of AI. " * 10)

    return [doc1_path, doc2_path, doc3_path]


@pytest.fixture
def text_splitter():
    """Create a mock text splitter"""
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    return RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
157 tests/test_doc_tracker.py Normal file
@@ -0,0 +1,157 @@
import os
import time
import pytest

from .fixtures import *
from code_rag.doc_tracker import DocMetaData, DocumentTracker, calculate_file_hash


def doc_infos_are_equal(left, right):
    """Check to see if two doc_infos are the same"""
    for k, v in left.items():
        try:
            print(v.to_dict(), right[k].to_dict(), v.to_dict() == right[k].to_dict())
            if v != right[k]:
                return False
        except KeyError:
            return False
    return True


@pytest.fixture
def document_tracker(tracker_file):
    """Create a DocumentTracker instance"""
    return DocumentTracker(tracking_file=tracker_file)


# Tests for DocumentTracker
def test_init_new_tracker(tracker_file):
    """Test creating a new tracker"""
    tracker = DocumentTracker(tracking_file=tracker_file)
    assert tracker.doc_info == {}
    assert not os.path.exists(tracker_file)


def test_save_and_load_tracking_data(document_tracker, tracker_file):
    """Test saving and loading tracking data"""
    # Add some data
    update_time = time.time()
    document_tracker.doc_info = {
        "test.txt": DocMetaData(123456, "abcdef", update_time, ["1", "2"])
    }
    document_tracker._save_tracking_data()

    # Check file exists
    assert os.path.exists(tracker_file)

    # Create a new tracker that should load the data
    new_tracker = DocumentTracker(tracking_file=tracker_file)
    assert doc_infos_are_equal(
        new_tracker.doc_info,
        {"test.txt": DocMetaData(123456, "abcdef", update_time, ["1", "2"])},
    )


def test_calculate_file_hash(document_tracker, sample_docs):
    """Test hash calculation for a file"""
    file_path = sample_docs[0]
    hash1 = calculate_file_hash(file_path)

    # Same content should yield same hash
    hash2 = calculate_file_hash(file_path)
    assert hash1 == hash2

    # Different content should yield different hash
    with open(file_path, "a") as f:
        f.write("Additional content")

    hash3 = calculate_file_hash(file_path)
    assert hash1 != hash3


def test_get_changed_files_new(document_tracker, docs_dir, sample_docs):
    """Test detecting new files"""
    changes = document_tracker.get_changed_files(docs_dir)
    assert set(changes["new"]) == set(sample_docs)
    assert changes["modified"] == []
    assert changes["deleted"] == []

    # Verify tracking was updated
    for file_path in sample_docs:
        assert file_path in document_tracker.doc_info
        assert document_tracker.doc_info[file_path].chunk_ids == []


def test_get_changed_files_modified(document_tracker, docs_dir, sample_docs):
    """Test detecting modified files"""
    # First scan to establish tracking
    document_tracker.get_changed_files(docs_dir)

    # Modify a file and wait to ensure timestamp difference
    time.sleep(0.1)
    with open(sample_docs[0], "a") as f:
        f.write("Modified content")

    # Detect changes
    changes = document_tracker.get_changed_files(docs_dir)
    assert changes["new"] == []
    assert changes["modified"] == [sample_docs[0]]
    assert changes["deleted"] == []


def test_get_changed_files_deleted(document_tracker, docs_dir, sample_docs):
    """Test detecting deleted files"""
    # First scan to establish tracking
    document_tracker.get_changed_files(docs_dir)

    # Delete a file
    os.remove(sample_docs[0])

    # Detect changes
    changes = document_tracker.get_changed_files(docs_dir)
    assert changes["new"] == []
    assert changes["modified"] == []
    assert changes["deleted"] == [sample_docs[0]]


def test_update_chunk_mappings(document_tracker, sample_docs):
    """Test updating chunk mappings"""
    file_path = sample_docs[0]
    chunk_ids = ["chunk1", "chunk2", "chunk3"]

    # First make sure the file is tracked
    document_tracker.doc_info[file_path] = DocMetaData(
        123,
        "abc",
        "2023-01-01",
        [],
    )

    # Update chunk mappings
    document_tracker.update_chunk_mappings(file_path, chunk_ids)
    assert document_tracker.doc_info[file_path].chunk_ids == chunk_ids


def test_get_chunks_to_delete(document_tracker):
    """Test getting chunks to delete for deleted files"""
    # Setup tracking data
    document_tracker.doc_info = {
        "file1.txt": DocMetaData(0, "abc", 0, ["chunk1", "chunk2"]),
        "file2.txt": DocMetaData(0, "abc", 0, ["chunk3", "chunk4"]),
        "file3.txt": DocMetaData(0, "abc", 0, ["chunk5"]),
    }

    # Test with one deleted file
    chunks = document_tracker.get_chunks_to_delete(["file1.txt"])
    assert set(chunks) == {"chunk1", "chunk2"}

    # Verify file was removed from tracking
    assert "file1.txt" not in document_tracker.doc_info

    # Test with multiple deleted files
    chunks = document_tracker.get_chunks_to_delete(["file2.txt", "file3.txt"])
    assert set(chunks) == {"chunk3", "chunk4", "chunk5"}

    # Verify files were removed from tracking
    assert "file2.txt" not in document_tracker.doc_info
    assert "file3.txt" not in document_tracker.doc_info
124 tests/test_rag.py Normal file
@@ -0,0 +1,124 @@
import os
import shutil
import pytest

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from .fixtures import *
from .utility import *
from code_rag.rag import RAG
from code_rag.doc_tracker import DocumentTracker


@pytest.fixture
def rag_pipeline(docs_dir, db_dir, tracker_file):
    """Create a RAG instance"""
    return RAG(docs_dir, db_dir, tracker_file)


# Tests for document processing
def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipeline):
    """Test processing documents into chunks with tracking"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    files = [
        os.path.join(rag_pipeline.docs_dir, "doc1.txt"),
        os.path.join(rag_pipeline.docs_dir, "doc2.txt"),
    ]

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    chunks, file_chunk_map = rag_pipeline.process_documents(files, text_splitter)

    # Verify chunks were created
    assert len(chunks) >= 2  # At least one chunk per document
    tracker = rag_pipeline.tracker
    # Verify chunk IDs were tracked
    for file_path in files:
        assert file_path in tracker.doc_info
        # DocMetaData stores chunk IDs as an attribute, not a dict key
        assert len(tracker.doc_info[file_path].chunk_ids) > 0

    # Verify metadata in chunks
    for chunk in chunks:
        assert "source" in chunk.metadata
        assert "chunk_id" in chunk.metadata
        assert chunk.metadata["source"] in files


@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
    """Test creating a vector database"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    # Create initial vector database
    vectorstore = rag_pipeline.create_vector_db(force_refresh=True)

    # Verify it was created
    assert os.path.exists(rag_pipeline.db_dir)
    assert vectorstore is not None
    # Check the database has content
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    loaded_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=embeddings
    )
    assert loaded_db._collection.count() > 0


@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_docs):
    """Test updating a vector database with document changes"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    # Create initial vector database
    rag_pipeline.create_vector_db(force_refresh=True)

    # Get initial count
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    initial_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=embeddings
    )
    initial_count = initial_db._collection.count()

    # Make changes to documents
    # Add a new document
    create_test_document(
        docs_dir, "newdoc.txt", "This is a brand new document for testing."
    )

    # Update the vector database
    rag_pipeline.create_vector_db()

    # Check the database has been updated
    updated_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=embeddings
    )
    assert updated_db._collection.count() > initial_count


# Final integration test - full RAG pipeline
@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
    """Test the entire RAG pipeline from document processing to querying"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    # Create a specific document with known content
    test_content = "Python is a high-level programming language known for its readability and versatility."
    create_test_document(rag_pipeline.docs_dir, "python_info.txt", test_content)

    # Create vector database
    rag_pipeline.create_vector_db(force_refresh=True)

    # Set up RAG
    rag_chain = rag_pipeline.setup_rag(model_name="llama3.2")

    # Query the system
    query = "What is Python?"
    response = rag_pipeline.query_rag(rag_chain, query)

    # Check if response contains relevant information
    # This is a soft test since the exact response will depend on the LLM
    assert response.strip() != ""
    assert "programming" in response.lower() or "language" in response.lower()
26 tests/utility.py Normal file
@@ -0,0 +1,26 @@
import time
import os


def create_test_document(base_path, filename, content):
    """Create a test document with specified content"""
    file_path = os.path.join(base_path, filename)
    # Create directory if it doesn't exist (for subdirectories)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "w") as f:
        f.write(content)
    return file_path


def modify_test_document(file_path, new_content):
    """Modify an existing test document"""
    with open(file_path, "w") as f:
        f.write(new_content)
    # Ensure modification time changes
    time.sleep(0.1)


def delete_test_document(file_path):
    """Delete a test document"""
    if os.path.exists(file_path):
        os.remove(file_path)