Compare commits
No commits in common. "550e4b78e248519856c564fcae284784bd09aa99" and "f63dc7148ac6c62bef8218f5597696d255af9996" have entirely different histories.
550e4b78e2 ... f63dc7148a
poetry.lock (generated, 1368 lines)
File diff suppressed because it is too large
pyproject.toml
@@ -1,29 +1,27 @@
[project]
name = "code-rag"
version = "0.1.0"
description = "Simple RAG implementation"
authors = [
    { name = "Alex Selimov", email = "alex@alexselimov.com" }
]
description = "Simple RAG implementation for use with neovim"
authors = [{ name = "Alex Selimov", email = "alex@alexselimov.com" }]
readme = "README.md"
requires-python = "^3.11"
requires-python = ">=3.9,<4.0"
dependencies = [
    "langchain (>=0.3.21,<0.4.0)",
    "ollama (>=0.4.7,<0.5.0)",
    "langchain-community (>=0.3.20,<0.4.0)",
    "langchain-ollama (>=0.2.3,<0.3.0)",
    "chromadb (>=0.4.0,<0.6.0)",
    "unstructured (>=0.17.2,<0.18.0)",
    "langchain-chroma (>=0.1.0,<0.2.0)"
]

[tool.poetry]
packages = [{ include = "code_rag", from = "src" }]

[tool.poetry.dependencies]
python = "^3.11"
langchain = "*"
langchain-community = "*"
langchain-core = "*"
langchain-chroma = "*"
langchain-ollama = "*"
chromadb = "*"
ollama = "*"

[tool.poetry.group.dev.dependencies]
pytest = "^8.3.5"

[build-system]
requires = ["poetry-core"]
requires = ["poetry-core>=2.0.0,<3.0.0"]
build-backend = "poetry.core.masonry.api"
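Note on the requires-python change: Poetry's caret constraint "^3.11" is shorthand for ">=3.11,<4.0", so the PEP 440 range ">=3.9,<4.0" on the other side widens support to Python 3.9 and 3.10. A minimal sketch with the packaging library (an assumed tool here, not a dependency the project declares) shows the difference:

    from packaging.specifiers import SpecifierSet

    old = SpecifierSet(">=3.11,<4.0")  # what "^3.11" expands to
    new = SpecifierSet(">=3.9,<4.0")

    print(old.contains("3.10"))  # False: 3.10 was excluded under the caret constraint
    print(new.contains("3.10"))  # True: 3.10 is accepted under the explicit range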
src/code_rag/rag.py
@@ -1,12 +1,10 @@
import os
import uuid
import glob
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_core.documents import Document

from code_rag.doc_tracker import DocumentTracker
from code_rag.ollama_wrapper import OllamaWrapper
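Both Chroma import locations appear across the two histories: langchain_community.vectorstores.Chroma is the older, since-deprecated home of the class, and langchain_chroma.Chroma is its replacement. A minimal sketch of the newer form, assuming a local Ollama embedding model (the model name is illustrative, not taken from this diff):

    from langchain_chroma import Chroma
    from langchain_ollama import OllamaEmbeddings

    # Embeddings served by a local Ollama instance; "nomic-embed-text" is an assumption.
    embeddings = OllamaEmbeddings(model="nomic-embed-text")

    # Persisted store; the directory is created on first use.
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)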
@@ -63,21 +61,10 @@ class RAG:

        return all_chunks, file_chunk_map

    def create_vector_db(self, extensions=None, force_refresh=False):
    def create_vector_db(self, extension=".txt", force_refresh=False):
        """
        Create or update a vector database, with complete handling of changes.

        Args:
            extensions (list[str], optional): List of file extensions to include.
                If None, defaults to common programming languages.
            force_refresh (bool): Whether to force a complete refresh of the vector database.
        Create or update a vector database, with complete handling of changes
        """
        # Set default extensions for common programming languages if none provided
        if extensions is None:
            extensions = default_file_extensions()
        elif isinstance(extensions, str):
            extensions = [extensions]

        # Text splitter configuration
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200
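Both versions configure the same chunking: RecursiveCharacterTextSplitter with 1000-character chunks and 200 characters of overlap. A standalone sketch of what that configuration produces on toy input:

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    # split_text() yields plain strings; split_documents() does the same for Document objects.
    chunks = splitter.split_text("word " * 600)  # roughly 3000 characters of toy input
    print(len(chunks), [len(c) for c in chunks])  # a few chunks, each at most 1000 chars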
@@ -88,76 +75,121 @@ class RAG:
        embeddings = self.ollama.embeddings
        print("after embedding")

        # Load or create vector store
        if os.path.exists(self.db_dir) and not force_refresh:
            print("Loading existing vector store")
            vectorstore = Chroma(
                persist_directory=self.db_dir, embedding_function=embeddings
            )
        else:
            print("Creating new vector store")
            vectorstore = Chroma(
                persist_directory=self.db_dir, embedding_function=embeddings
            )

        # Find all files that match the extensions
        all_files = set()
        for ext in extensions:
            files = glob.glob(os.path.join(self.docs_dir, f"**/*{ext}"), recursive=True)
            all_files.update(files)

        if not all_files:
            print("No documents found to process")
            return vectorstore

        # Process all files
        all_documents = []
        for file_path in all_files:
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()
                doc = Document(
                    page_content=content,
                    metadata={
                        "source": os.path.abspath(file_path),
                        "source_id": os.path.abspath(file_path),
                    },
                )
                all_documents.append(doc)
                print(f"Successfully loaded {file_path}")
            except Exception as e:
                print(f"Error loading file {file_path}: {str(e)}")

        if not all_documents:
            print("No documents could be loaded")
            return vectorstore

        # Split documents
        chunks = text_splitter.split_documents(all_documents)

        if force_refresh:
            # Create new vector store from scratch
            print("Force refresh: Processing all documents")
            # Load all documents
            loader = DirectoryLoader(self.docs_dir, glob=f"**/*{extension}")
            all_documents = loader.load()

            if not all_documents:
                print("No documents found to process")
                # Create an empty vector store
                vectorstore = Chroma(
                    persist_directory=self.db_dir, embedding_function=embeddings
                )
                return vectorstore

            # Add unique IDs to each document
            for doc in all_documents:
                doc.metadata["source"] = os.path.abspath(doc.metadata["source"])
                doc.metadata["source_id"] = doc.metadata["source"]

            # Split documents
            chunks = text_splitter.split_documents(all_documents)

            # Add chunk IDs and update tracker
            file_chunk_map = {}
            for chunk in chunks:
                chunk_id = str(uuid.uuid4())
                chunk.metadata["chunk_id"] = chunk_id

                source = chunk.metadata["source"]
                if source not in file_chunk_map:
                    file_chunk_map[source] = []

                file_chunk_map[source].append(chunk_id)

            # Update tracker with chunk mappings
            for file_path, chunk_ids in file_chunk_map.items():
                self.tracker.update_chunk_mappings(file_path, chunk_ids)

            print(
                f"Processing {len(all_documents)} documents with {len(chunks)} chunks"
            )

            # Create new vector store
            vectorstore = Chroma.from_documents(
                documents=chunks, embedding=embeddings, persist_directory=self.db_dir
            )
            return vectorstore

        # Get changes since last update
        changed_files = self.tracker.get_changed_files(self.docs_dir)

        if not any(changed_files.values()):
            print("No document changes detected")
            # Load existing vector store if available
            if os.path.exists(self.db_dir):
                return Chroma(
                    persist_directory=self.db_dir,
                    embedding_function=self.ollama.embeddings,
                )
            else:
                print("No vector database exists. Creating from all documents...")
                return self.create_vector_db(force_refresh=True)

        # Process changes
        print(
            f"Changes detected - New: {len(changed_files['new'])}, Modified: {len(changed_files['modified'])}, Deleted: {len(changed_files['deleted'])}"
        )

        # Load existing vector store if it exists
        if os.path.exists(self.db_dir):
            vectorstore = Chroma(
                persist_directory=self.db_dir, embedding_function=self.ollama.embeddings
            )

            # 1. Handle deleted documents
            if changed_files["deleted"]:
                chunks_to_delete = self.tracker.get_chunks_to_delete(
                    changed_files["deleted"]
                )
                if chunks_to_delete:
                    print(
                        f"Removing {len(chunks_to_delete)} chunks from deleted documents"
                    )
                    # Delete the chunks from vector store
                    vectorstore._collection.delete(
                        where={"chunk_id": {"$in": chunks_to_delete}}
                    )

            # 2. Handle modified documents (delete old chunks first)
            chunks_to_delete_modified = self.tracker.get_chunks_for_modified_files(
                changed_files["modified"]
            )
            if chunks_to_delete_modified:
                print(
                    f"Removing {len(chunks_to_delete_modified)} chunks from modified documents"
                )
                vectorstore._collection.delete(
                    where={"chunk_id": {"$in": chunks_to_delete_modified}}
                )

            # 3. Process new and modified documents
            files_to_process = changed_files["new"] + changed_files["modified"]
            if files_to_process:
                chunks, _ = self.process_documents(files_to_process, text_splitter)
                print(f"Adding {len(chunks)} new chunks to the vector store")
                vectorstore.add_documents(chunks)
        else:
            # Update existing vector store
            # First, get all existing document IDs
            existing_docs = vectorstore._collection.get()
            existing_sources = {m["source"] for m in existing_docs["metadatas"]}
            # If no existing DB, create from all documents
            print("No existing vector database. Creating from all documents...")
            return self.create_vector_db(force_refresh=True)

        # Find new documents
        new_chunks = [
            chunk
            for chunk in chunks
            if chunk.metadata["source"] not in existing_sources
        ]
        # Persist changes
        vectorstore.persist()
        print(f"Vector database updated at {self.db_dir}")

        if new_chunks:
            print(f"Adding {len(new_chunks)} new chunks to vector store")
            vectorstore.add_documents(new_chunks)

        print(f"Vector store updated with {len(chunks)} total chunks")
        return vectorstore

    def setup_rag(self):
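The incremental path above removes stale chunks through Chroma's underlying collection with a metadata where-filter. A standalone sketch of the same pattern against a bare chromadb client (names, IDs, and vectors are illustrative; toy embeddings keep it runnable offline):

    import chromadb

    # In-memory client; chromadb.PersistentClient(path=...) is the on-disk variant.
    client = chromadb.Client()
    collection = client.create_collection("docs")

    # Toy vectors stand in for real embeddings.
    collection.add(
        ids=["a", "b", "c"],
        embeddings=[[0.1, 0.0], [0.0, 0.2], [0.3, 0.3]],
        metadatas=[{"chunk_id": "a"}, {"chunk_id": "b"}, {"chunk_id": "c"}],
    )

    # Same shape as vectorstore._collection.delete above: a $in where-filter on metadata.
    collection.delete(where={"chunk_id": {"$in": ["a", "c"]}})
    print(collection.count())  # 1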
@@ -210,136 +242,3 @@ class RAG:
        """
        response = rag_chain.invoke({"query": query})
        return response["result"]


def default_file_extensions():
    """Return the default file extensions representing common plain text and code files"""
    return [
        # Python
        ".py",
        ".pyi",
        ".pyx",
        ".pyc",
        ".pyd",
        ".pyw",
        # C/C++
        ".c",
        ".cpp",
        ".cc",
        ".cxx",
        ".h",
        ".hpp",
        ".hxx",
        ".inc",
        ".inl",
        ".ipp",
        # Rust
        ".rs",
        ".rlib",
        ".rmeta",
        # Java
        ".java",
        ".jsp",
        ".jav",
        ".jar",
        ".class",
        ".kt",
        ".kts",
        ".groovy",
        # Web
        ".html",
        ".htm",
        ".css",
        ".scss",
        ".sass",
        ".less",
        ".js",
        ".jsx",
        ".ts",
        ".tsx",
        ".vue",
        ".svelte",
        # Fortran
        ".f",
        ".for",
        ".f90",
        ".f95",
        ".f03",
        ".f08",
        # Go
        ".go",
        ".mod",
        # Ruby
        ".rb",
        ".rbw",
        ".rake",
        ".gemspec",
        # PHP
        ".php",
        ".phtml",
        ".php3",
        ".php4",
        ".php5",
        ".phps",
        # C#
        ".cs",
        ".csx",
        ".vb",
        # Swift
        ".swift",
        ".swiftmodule",
        # Shell/Scripts
        ".sh",
        ".bash",
        ".zsh",
        ".fish",
        ".ps1",
        ".bat",
        ".cmd",
        # Scala
        ".scala",
        ".sc",
        # Haskell
        ".hs",
        ".lhs",
        ".hsc",
        # Lua
        ".lua",
        ".luac",
        # R
        ".r",
        ".rmd",
        ".rds",
        # Perl
        ".pl",
        ".pm",
        ".t",
        # Documentation
        ".txt",
        ".md",
        ".rst",
        ".adoc",
        ".wiki",
        # Build/Config
        ".toml",
        ".yaml",
        ".yml",
        ".json",
        ".xml",
        ".ini",
        ".conf",
        ".cfg",
        # SQL
        ".sql",
        ".mysql",
        ".pgsql",
        ".sqlite",
        # Lisp family
        ".lisp",
        ".cl",
        ".el",
        ".clj",
        ".cljc",
        ".cljs",
        ".edn",
    ]
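The default_file_extensions list, present in only one of the two histories, pairs with the recursive glob discovery seen in create_vector_db above. A condensed sketch of that pattern (the path and the extension slice are illustrative):

    import glob
    import os

    docs_dir = "docs"  # illustrative path
    extensions = [".py", ".rs", ".md"]  # a small slice of default_file_extensions()

    all_files = set()
    for ext in extensions:
        # "**" plus recursive=True walks every subdirectory under docs_dir.
        all_files.update(glob.glob(os.path.join(docs_dir, f"**/*{ext}"), recursive=True))

    print(sorted(all_files))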
@@ -1,10 +1,11 @@
import os
import shutil
import pytest
from langchain_chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
import shutil

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from .fixtures import *
from .utility import *
from code_rag.rag import RAG
from code_rag.doc_tracker import DocumentTracker
from code_rag.ollama_wrapper import OllamaWrapper
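The test module pulls its fixtures in with a star import from .fixtures; the docs_dir, db_dir, tracker_file, and sample_docs arguments used below come from there. A guess at their shape as tmp_path-backed pytest fixtures (only the names appear in the diff; the bodies are assumptions):

    import pytest

    @pytest.fixture
    def docs_dir(tmp_path):
        # Documents live in a throwaway directory, fresh per test.
        d = tmp_path / "docs"
        d.mkdir()
        return str(d)

    @pytest.fixture
    def db_dir(tmp_path):
        return str(tmp_path / "chroma_db")

    @pytest.fixture
    def tracker_file(tmp_path):
        return str(tmp_path / "tracker.json")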
@@ -51,40 +52,18 @@ def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipe
def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
    """Test creating a vector database"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)

    # Create files with different extensions
    files = {
        "test.py": "def hello():\n print('Hello, World!')",
        "main.cpp": "#include <iostream>\nint main() { std::cout << 'Hello'; return 0; }",
        "lib.rs": "fn main() { println!('Hello from Rust!'); }",
        "config.toml": "[package]\nname = 'test'",
        "doc.md": "# Documentation\nThis is a test file."
    }

    for filename, content in files.items():
        filepath = os.path.join(docs_dir, filename)
        with open(filepath, "w") as f:
            f.write(content)

    # Create vector database with default extensions (should include all file types)
    # Create initial vector database
    vectorstore = rag_pipeline.create_vector_db(force_refresh=True)

    # Verify it was created
    assert os.path.exists(rag_pipeline.db_dir)
    assert vectorstore is not None

    # Check the database has content from all file types
    # Check the database has content
    loaded_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
        persist_directory=rag_pipeline.db_dir,
        embedding_function=rag_pipeline.ollama.embeddings,
    )
    # Should have content from all files
    assert loaded_db._collection.count() > 0

    # Verify each file type is included
    docs = loaded_db._collection.get()
    sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
    for filename in files.keys():
        assert filename in sources, f"File {filename} not found in vector store"


@pytest.mark.skipif(
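The assertions above inspect Chroma's raw collection directly: _collection.count() for size and _collection.get() for a dict carrying ids, documents, and metadatas. A standalone sketch of that inspection pattern (toy vectors and illustrative metadata):

    import chromadb

    collection = chromadb.Client().create_collection("inspect-demo")
    collection.add(
        ids=["1", "2"],
        embeddings=[[0.1, 0.2], [0.3, 0.4]],
        metadatas=[{"source": "/abs/a.py"}, {"source": "/abs/b.md"}],
    )

    print(collection.count())  # 2
    result = collection.get()  # dict with "ids", "metadatas", "documents", ...
    sources = {m["source"] for m in result["metadatas"]}
    print(sources)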
@@ -93,66 +72,45 @@ def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_docs):
    """Test updating a vector database with document changes"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    # Create initial vector database with only Python files
    vectorstore = rag_pipeline.create_vector_db(extensions=['.py'], force_refresh=True)
    # Create initial vector database
    rag_pipeline.create_vector_db(force_refresh=True)

    # Get initial count
    initial_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
        persist_directory=rag_pipeline.db_dir,
        embedding_function=rag_pipeline.ollama.embeddings,
    )
    initial_count = initial_db._collection.count()

    # Make changes to documents
    # Add files of different types
    new_files = {
        "newdoc.cpp": "#include <iostream>\nint main() { return 0; }",
        "lib.rs": "fn main() { println!('Hello'); }",
        "config.toml": "[package]\nname = 'test'"
    }

    for filename, content in new_files.items():
        filepath = os.path.join(docs_dir, filename)
        with open(filepath, "w") as f:
            f.write(content)
    # Add a new document
    create_test_document(
        docs_dir, "newdoc.txt", "This is a brand new document for testing."
    )

    # Update the vector database to include all supported extensions
    rag_pipeline.create_vector_db()  # Use default extensions
    # Update the vector database
    rag_pipeline.create_vector_db()

    # Check the database has been updated
    updated_db = Chroma(
        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
        persist_directory=rag_pipeline.db_dir,
        embedding_function=rag_pipeline.ollama.embeddings,
    )
    assert updated_db._collection.count() > initial_count

    # Verify new files are included
    docs = updated_db._collection.get()
    sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
    for filename in new_files.keys():
        assert filename in sources, f"File {filename} not found in vector store"


# Final integration test - full RAG pipeline
@pytest.mark.skipif(
    not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
)
def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
    """Test the entire RAG pipeline from document processing to querying"""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)

    # Create documents with mixed content types
    test_files = {
        "python_info.py": """# Python Information
def describe_python():
    \"\"\"Python is a high-level programming language known for its readability and versatility.\"\"\"
    pass""",
        "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation."
    }

    for filename, content in test_files.items():
        filepath = os.path.join(rag_pipeline.docs_dir, filename)
        with open(filepath, "w") as f:
            f.write(content)
    # Create a specific document with known content
    test_content = "Python is a high-level programming language known for its readability and versatility."
    create_test_document(rag_pipeline.docs_dir, "python_info.txt", test_content)

    # Create vector database with all default extensions
    # Create vector database
    rag_pipeline.create_vector_db(force_refresh=True)

    # Set up RAG
@@ -166,3 +124,4 @@ def describe_python():
    # This is a soft test since the exact response will depend on the LLM
    assert response.strip() != ""
    assert "programming" in response.lower() or "language" in response.lower()
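Both integration tests guard on a locally installed Ollama binary via shutil.which, the pattern visible in the skipif decorators above. A minimal self-contained sketch:

    import shutil

    import pytest

    @pytest.mark.skipif(
        not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
    )
    def test_needs_ollama():
        # shutil.which returns the executable's path, or None when absent;
        # pytest skips this test instead of failing on machines without Ollama.
        assert shutil.which("ollama") is not None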