Add more file extensions to accepted inputs

Alex Selimov 2025-03-22 01:00:16 -04:00
parent f63dc7148a
commit 3e7f4e42f7
4 changed files with 267 additions and 1416 deletions
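For context, a minimal usage sketch of the entry point this commit changes; the directory paths below are placeholders, not part of the commit:

    from code_rag.rag import RAG

    rag = RAG("docs", "chroma_db", "tracker.json")    # docs dir, vector store dir, tracker file
    rag.create_vector_db(force_refresh=True)          # index using the new default extension list
    rag.create_vector_db(extensions=[".py", ".rs"])   # or restrict indexing to specific file types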

poetry.lock (generated): 1364 lines changed
File diff suppressed because it is too large

pyproject.toml

@@ -1,27 +1,29 @@
 [project]
 name = "code-rag"
 version = "0.1.0"
-description = "Simple RAG implementation for use with neovim"
-authors = [{ name = "Alex Selimov", email = "alex@alexselimov.com" }]
-readme = "README.md"
-requires-python = ">=3.9,<4.0"
-dependencies = [
-    "langchain (>=0.3.21,<0.4.0)",
-    "ollama (>=0.4.7,<0.5.0)",
-    "langchain-community (>=0.3.20,<0.4.0)",
-    "langchain-ollama (>=0.2.3,<0.3.0)",
-    "chromadb (>=0.4.0,<0.6.0)",
-    "unstructured (>=0.17.2,<0.18.0)",
-    "langchain-chroma (>=0.1.0,<0.2.0)"
+description = "Simple RAG implementation"
+authors = [
+    { name = "Alex Selimov", email = "alex@alexselimov.com" }
 ]
+readme = "README.md"
+requires-python = "^3.11"
 
 [tool.poetry]
 packages = [{ include = "code_rag", from = "src" }]
 
+[tool.poetry.dependencies]
+python = "^3.11"
+langchain = "*"
+langchain-community = "*"
+langchain-core = "*"
+langchain-chroma = "*"
+langchain-ollama = "*"
+chromadb = "*"
+ollama = "*"
+
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.3.5"
 
 [build-system]
-requires = ["poetry-core>=2.0.0,<3.0.0"]
+requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"

src/code_rag/rag.py

@@ -1,10 +1,12 @@
 import os
 import uuid
+import glob
 from langchain_community.document_loaders import DirectoryLoader, TextLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import Chroma
+from langchain_chroma import Chroma
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
+from langchain_core.documents import Document
 
 from code_rag.doc_tracker import DocumentTracker
 from code_rag.ollama_wrapper import OllamaWrapper
@@ -61,10 +63,32 @@ class RAG:
         return all_chunks, file_chunk_map
 
-    def create_vector_db(self, extension=".txt", force_refresh=False):
+    def create_vector_db(self, extensions=None, force_refresh=False):
         """
-        Create or update a vector database, with complete handling of changes
+        Create or update a vector database, with complete handling of changes.
+
+        Args:
+            extensions (list[str], optional): List of file extensions to include.
+                If None, defaults to common programming languages.
+            force_refresh (bool): Whether to force a complete refresh of the vector database.
         """
+        # Set default extensions for common programming languages if none provided
+        if extensions is None:
+            extensions = [
+                # Python
+                '.py', '.pyi', '.pyx',
+                # C/C++
+                '.c', '.cpp', '.cc', '.cxx', '.h', '.hpp', '.hxx',
+                # Rust
+                '.rs',
+                # Documentation
+                '.txt', '.md',
+                # Build/Config
+                '.toml', '.yaml', '.json'
+            ]
+        elif isinstance(extensions, str):
+            extensions = [extensions]
+
         # Text splitter configuration
         text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=1000, chunk_overlap=200
@@ -75,121 +99,79 @@ class RAG:
         embeddings = self.ollama.embeddings
         print("after embedding")
 
-        if force_refresh:
-            print("Force refresh: Processing all documents")
-            # Load all documents
-            loader = DirectoryLoader(self.docs_dir, glob=f"**/*{extension}")
-            all_documents = loader.load()
-
-            if not all_documents:
-                print("No documents found to process")
-                # Create an empty vector store
-                vectorstore = Chroma(
-                    persist_directory=self.db_dir, embedding_function=embeddings
-                )
-                return vectorstore
-
-            # Add unique IDs to each document
-            for doc in all_documents:
-                doc.metadata["source"] = os.path.abspath(doc.metadata["source"])
-                doc.metadata["source_id"] = doc.metadata["source"]
-
-            # Split documents
-            chunks = text_splitter.split_documents(all_documents)
-
-            # Add chunk IDs and update tracker
-            file_chunk_map = {}
-            for chunk in chunks:
-                chunk_id = str(uuid.uuid4())
-                chunk.metadata["chunk_id"] = chunk_id
-                source = chunk.metadata["source"]
-                if source not in file_chunk_map:
-                    file_chunk_map[source] = []
-                file_chunk_map[source].append(chunk_id)
-
-            # Update tracker with chunk mappings
-            for file_path, chunk_ids in file_chunk_map.items():
-                self.tracker.update_chunk_mappings(file_path, chunk_ids)
-
-            print(
-                f"Processing {len(all_documents)} documents with {len(chunks)} chunks"
-            )
-
-            # Create new vector store
-            vectorstore = Chroma.from_documents(
-                documents=chunks, embedding=embeddings, persist_directory=self.db_dir
-            )
-            return vectorstore
-
-        # Get changes since last update
-        changed_files = self.tracker.get_changed_files(self.docs_dir)
-
-        if not any(changed_files.values()):
-            print("No document changes detected")
-            # Load existing vector store if available
-            if os.path.exists(self.db_dir):
-                return Chroma(
-                    persist_directory=self.db_dir,
-                    embedding_function=self.ollama.embeddings,
-                )
-            else:
-                print("No vector database exists. Creating from all documents...")
-                return self.create_vector_db(force_refresh=True)
-
-        # Process changes
-        print(
-            f"Changes detected - New: {len(changed_files['new'])}, Modified: {len(changed_files['modified'])}, Deleted: {len(changed_files['deleted'])}"
-        )
-
-        # Load existing vector store if it exists
-        if os.path.exists(self.db_dir):
-            vectorstore = Chroma(
-                persist_directory=self.db_dir, embedding_function=self.ollama.embeddings
-            )
-
-            # 1. Handle deleted documents
-            if changed_files["deleted"]:
-                chunks_to_delete = self.tracker.get_chunks_to_delete(
-                    changed_files["deleted"]
-                )
-                if chunks_to_delete:
-                    print(
-                        f"Removing {len(chunks_to_delete)} chunks from deleted documents"
-                    )
-                    # Delete the chunks from vector store
-                    vectorstore._collection.delete(
-                        where={"chunk_id": {"$in": chunks_to_delete}}
-                    )
-
-            # 2. Handle modified documents (delete old chunks first)
-            chunks_to_delete_modified = self.tracker.get_chunks_for_modified_files(
-                changed_files["modified"]
-            )
-            if chunks_to_delete_modified:
-                print(
-                    f"Removing {len(chunks_to_delete_modified)} chunks from modified documents"
-                )
-                vectorstore._collection.delete(
-                    where={"chunk_id": {"$in": chunks_to_delete_modified}}
-                )
-
-            # 3. Process new and modified documents
-            files_to_process = changed_files["new"] + changed_files["modified"]
-            if files_to_process:
-                chunks, _ = self.process_documents(files_to_process, text_splitter)
-                print(f"Adding {len(chunks)} new chunks to the vector store")
-                vectorstore.add_documents(chunks)
-        else:
-            # If no existing DB, create from all documents
-            print("No existing vector database. Creating from all documents...")
-            return self.create_vector_db(force_refresh=True)
-
-        # Persist changes
-        vectorstore.persist()
-        print(f"Vector database updated at {self.db_dir}")
+        # Load or create vector store
+        if os.path.exists(self.db_dir) and not force_refresh:
+            print("Loading existing vector store")
+            vectorstore = Chroma(
+                persist_directory=self.db_dir,
+                embedding_function=embeddings
+            )
+        else:
+            print("Creating new vector store")
+            vectorstore = Chroma(
+                persist_directory=self.db_dir,
+                embedding_function=embeddings
+            )
+
+        # Find all files that match the extensions
+        all_files = set()
+        for ext in extensions:
+            files = glob.glob(os.path.join(self.docs_dir, f"**/*{ext}"), recursive=True)
+            all_files.update(files)
+
+        if not all_files:
+            print("No documents found to process")
+            return vectorstore
+
+        # Process all files
+        all_documents = []
+        for file_path in all_files:
+            try:
+                with open(file_path, 'r', encoding='utf-8') as f:
+                    content = f.read()
+                doc = Document(
+                    page_content=content,
+                    metadata={
+                        "source": os.path.abspath(file_path),
+                        "source_id": os.path.abspath(file_path)
+                    }
+                )
+                all_documents.append(doc)
+                print(f"Successfully loaded {file_path}")
+            except Exception as e:
+                print(f"Error loading file {file_path}: {str(e)}")
+
+        if not all_documents:
+            print("No documents could be loaded")
+            return vectorstore
+
+        # Split documents
+        chunks = text_splitter.split_documents(all_documents)
+
+        if force_refresh:
+            # Create new vector store from scratch
+            vectorstore = Chroma.from_documents(
+                documents=chunks,
+                embedding=embeddings,
+                persist_directory=self.db_dir
+            )
+        else:
+            # Update existing vector store
+            # First, get all existing document IDs
+            existing_docs = vectorstore._collection.get()
+            existing_sources = {m["source"] for m in existing_docs["metadatas"]}
+
+            # Find new documents
+            new_chunks = [
+                chunk for chunk in chunks
+                if chunk.metadata["source"] not in existing_sources
+            ]
+
+            if new_chunks:
+                print(f"Adding {len(new_chunks)} new chunks to vector store")
+                vectorstore.add_documents(new_chunks)
+
+        print(f"Vector store updated with {len(chunks)} total chunks")
 
         return vectorstore
 
     def setup_rag(self):
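The core of this change is swapping the single DirectoryLoader glob, which accepted one extension, for one recursive glob per extension. A standalone sketch of that collection step follows; the helper name collect_files and the example arguments are illustrative, not part of the module:

    import glob
    import os

    def collect_files(docs_dir, extensions):
        """Return every file under docs_dir whose extension is in the given list."""
        matched = set()
        for ext in extensions:
            # recursive=True lets the "**" pattern descend into subdirectories
            matched.update(glob.glob(os.path.join(docs_dir, f"**/*{ext}"), recursive=True))
        return sorted(matched)

    print(collect_files("docs", [".py", ".rs", ".md"]))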

tests (RAG pipeline test module)

@@ -1,11 +1,10 @@
 import os
-import pytest
 import shutil
+import pytest
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_chroma import Chroma
-from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from .fixtures import *
-from .utility import *
 from code_rag.rag import RAG
 from code_rag.doc_tracker import DocumentTracker
 from code_rag.ollama_wrapper import OllamaWrapper
@@ -52,18 +51,40 @@ def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipe
 def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
     """Test creating a vector database"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
-    # Create initial vector database
+
+    # Create files with different extensions
+    files = {
+        "test.py": "def hello():\n    print('Hello, World!')",
+        "main.cpp": "#include <iostream>\nint main() { std::cout << 'Hello'; return 0; }",
+        "lib.rs": "fn main() { println!('Hello from Rust!'); }",
+        "config.toml": "[package]\nname = 'test'",
+        "doc.md": "# Documentation\nThis is a test file."
+    }
+    for filename, content in files.items():
+        filepath = os.path.join(docs_dir, filename)
+        with open(filepath, "w") as f:
+            f.write(content)
+
+    # Create vector database with default extensions (should include all file types)
     vectorstore = rag_pipeline.create_vector_db(force_refresh=True)
 
     # Verify it was created
     assert os.path.exists(rag_pipeline.db_dir)
     assert vectorstore is not None
 
-    # Check the database has content
+    # Check the database has content from all file types
     loaded_db = Chroma(
-        persist_directory=rag_pipeline.db_dir,
-        embedding_function=rag_pipeline.ollama.embeddings,
+        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
     )
+
+    # Should have content from all files
     assert loaded_db._collection.count() > 0
+
+    # Verify each file type is included
+    docs = loaded_db._collection.get()
+    sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
+    for filename in files.keys():
+        assert filename in sources, f"File {filename} not found in vector store"
 
 
 @pytest.mark.skipif(
@@ -72,45 +93,66 @@ def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
 def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_docs):
     """Test updating a vector database with document changes"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
-    # Create initial vector database
-    rag_pipeline.create_vector_db(force_refresh=True)
+
+    # Create initial vector database with only Python files
+    vectorstore = rag_pipeline.create_vector_db(extensions=['.py'], force_refresh=True)
 
     # Get initial count
     initial_db = Chroma(
-        persist_directory=rag_pipeline.db_dir,
-        embedding_function=rag_pipeline.ollama.embeddings,
+        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
     )
     initial_count = initial_db._collection.count()
 
     # Make changes to documents
-    # Add a new document
-    create_test_document(
-        docs_dir, "newdoc.txt", "This is a brand new document for testing."
-    )
+    # Add files of different types
+    new_files = {
+        "newdoc.cpp": "#include <iostream>\nint main() { return 0; }",
+        "lib.rs": "fn main() { println!('Hello'); }",
+        "config.toml": "[package]\nname = 'test'"
+    }
+    for filename, content in new_files.items():
+        filepath = os.path.join(docs_dir, filename)
+        with open(filepath, "w") as f:
+            f.write(content)
 
-    # Update the vector database
-    rag_pipeline.create_vector_db()
+    # Update the vector database to include all supported extensions
+    rag_pipeline.create_vector_db()  # Use default extensions
 
     # Check the database has been updated
     updated_db = Chroma(
-        persist_directory=rag_pipeline.db_dir,
-        embedding_function=rag_pipeline.ollama.embeddings,
+        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
     )
     assert updated_db._collection.count() > initial_count
 
+    # Verify new files are included
+    docs = updated_db._collection.get()
+    sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
+    for filename in new_files.keys():
+        assert filename in sources, f"File {filename} not found in vector store"
 
+
+# Final integration test - full RAG pipeline
 @pytest.mark.skipif(
     not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
 )
 def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
     """Test the entire RAG pipeline from document processing to querying"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
-    # Create a specific document with known content
-    test_content = "Python is a high-level programming language known for its readability and versatility."
-    create_test_document(rag_pipeline.docs_dir, "python_info.txt", test_content)
 
-    # Create vector database
+    # Create documents with mixed content types
+    test_files = {
+        "python_info.py": """# Python Information
+def describe_python():
+    \"\"\"Python is a high-level programming language known for its readability and versatility.\"\"\"
+    pass""",
+        "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation."
+    }
+    for filename, content in test_files.items():
+        filepath = os.path.join(rag_pipeline.docs_dir, filename)
+        with open(filepath, "w") as f:
+            f.write(content)
+
+    # Create vector database with all default extensions
     rag_pipeline.create_vector_db(force_refresh=True)
 
     # Set up RAG
@@ -124,4 +166,3 @@ def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
     # This is a soft test since the exact response will depend on the LLM
     assert response.strip() != ""
     assert "programming" in response.lower() or "language" in response.lower()
-