Compare commits: 550e4b78e2...f63dc7148a

No commits in common. "550e4b78e248519856c564fcae284784bd09aa99" and "f63dc7148ac6c62bef8218f5597696d255af9996" have entirely different histories.
poetry.lock (generated, 1368 lines): diff suppressed because it is too large.
pyproject.toml
@@ -1,29 +1,27 @@
 [project]
 name = "code-rag"
 version = "0.1.0"
-description = "Simple RAG implementation"
-authors = [
-    { name = "Alex Selimov", email = "alex@alexselimov.com" }
-]
+description = "Simple RAG implementation for use with neovim"
+authors = [{ name = "Alex Selimov", email = "alex@alexselimov.com" }]
 readme = "README.md"
-requires-python = "^3.11"
+requires-python = ">=3.9,<4.0"
+dependencies = [
+    "langchain (>=0.3.21,<0.4.0)",
+    "ollama (>=0.4.7,<0.5.0)",
+    "langchain-community (>=0.3.20,<0.4.0)",
+    "langchain-ollama (>=0.2.3,<0.3.0)",
+    "chromadb (>=0.4.0,<0.6.0)",
+    "unstructured (>=0.17.2,<0.18.0)",
+    "langchain-chroma (>=0.1.0,<0.2.0)"
+]
 
 [tool.poetry]
 packages = [{ include = "code_rag", from = "src" }]
 
-[tool.poetry.dependencies]
-python = "^3.11"
-langchain = "*"
-langchain-community = "*"
-langchain-core = "*"
-langchain-chroma = "*"
-langchain-ollama = "*"
-chromadb = "*"
-ollama = "*"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.3.5"
 
 [build-system]
-requires = ["poetry-core"]
+requires = ["poetry-core>=2.0.0,<3.0.0"]
 build-backend = "poetry.core.masonry.api"
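
Note on the change above: the dependency table moves from Poetry's own [tool.poetry.dependencies] section to the standard PEP 621 [project] table (the parenthesized constraints are how Poetry 2.x writes version markers). A minimal sketch of reading the new table back, assuming Python 3.11+ for the stdlib tomllib:

    import tomllib  # stdlib in Python 3.11+

    with open("pyproject.toml", "rb") as f:
        project = tomllib.load(f)["project"]

    print(project["requires-python"])  # ">=3.9,<4.0"
    print(project["dependencies"][0])  # "langchain (>=0.3.21,<0.4.0)"
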
src/code_rag/rag.py
@@ -1,12 +1,10 @@
 import os
 import uuid
-import glob
 from langchain_community.document_loaders import DirectoryLoader, TextLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_chroma import Chroma
+from langchain_community.vectorstores import Chroma
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
-from langchain_core.documents import Document
 
 from code_rag.doc_tracker import DocumentTracker
 from code_rag.ollama_wrapper import OllamaWrapper
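
The two Chroma imports above name the same vector store integration: langchain-chroma is the newer standalone package, while langchain_community.vectorstores is the older location, deprecated in recent LangChain releases. A hedged compatibility sketch:

    try:
        from langchain_chroma import Chroma  # maintained package
    except ImportError:
        # older location, kept for environments without langchain-chroma
        from langchain_community.vectorstores import Chroma
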
@@ -63,21 +61,10 @@ class RAG:
 
         return all_chunks, file_chunk_map
 
-    def create_vector_db(self, extensions=None, force_refresh=False):
+    def create_vector_db(self, extension=".txt", force_refresh=False):
         """
-        Create or update a vector database, with complete handling of changes.
-
-        Args:
-            extensions (list[str], optional): List of file extensions to include.
-                If None, defaults to common programming languages.
-            force_refresh (bool): Whether to force a complete refresh of the vector database.
+        Create or update a vector database, with complete handling of changes
         """
-        # Set default extensions for common programming languages if none provided
-        if extensions is None:
-            extensions = default_file_extensions()
-        elif isinstance(extensions, str):
-            extensions = [extensions]
-
         # Text splitter configuration
         text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=1000, chunk_overlap=200
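
Both sides keep the same splitter settings: 1000-character chunks with 200 characters of overlap between neighbors. A self-contained sketch of what that configuration does (the sample document is hypothetical):

    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain_core.documents import Document

    # 1000-character chunks, 200-character overlap between consecutive chunks
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    doc = Document(page_content="word " * 600, metadata={"source": "example.py"})
    chunks = splitter.split_documents([doc])
    print(len(chunks))          # several chunks for a ~3000-character document
    print(chunks[0].metadata)   # source metadata is copied onto every chunk
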
@@ -88,76 +75,121 @@ class RAG:
         embeddings = self.ollama.embeddings
         print("after embedding")
 
-        # Load or create vector store
-        if os.path.exists(self.db_dir) and not force_refresh:
-            print("Loading existing vector store")
-            vectorstore = Chroma(
-                persist_directory=self.db_dir, embedding_function=embeddings
-            )
-        else:
-            print("Creating new vector store")
-            vectorstore = Chroma(
-                persist_directory=self.db_dir, embedding_function=embeddings
-            )
-
-        # Find all files that match the extensions
-        all_files = set()
-        for ext in extensions:
-            files = glob.glob(os.path.join(self.docs_dir, f"**/*{ext}"), recursive=True)
-            all_files.update(files)
-
-        if not all_files:
-            print("No documents found to process")
-            return vectorstore
-
-        # Process all files
-        all_documents = []
-        for file_path in all_files:
-            try:
-                with open(file_path, "r", encoding="utf-8") as f:
-                    content = f.read()
-                doc = Document(
-                    page_content=content,
-                    metadata={
-                        "source": os.path.abspath(file_path),
-                        "source_id": os.path.abspath(file_path),
-                    },
-                )
-                all_documents.append(doc)
-                print(f"Successfully loaded {file_path}")
-            except Exception as e:
-                print(f"Error loading file {file_path}: {str(e)}")
-
-        if not all_documents:
-            print("No documents could be loaded")
-            return vectorstore
-
-        # Split documents
-        chunks = text_splitter.split_documents(all_documents)
-
         if force_refresh:
-            # Create new vector store from scratch
+            print("Force refresh: Processing all documents")
+            # Load all documents
+            loader = DirectoryLoader(self.docs_dir, glob=f"**/*{extension}")
+            all_documents = loader.load()
+
+            if not all_documents:
+                print("No documents found to process")
+                # Create an empty vector store
+                vectorstore = Chroma(
+                    persist_directory=self.db_dir, embedding_function=embeddings
+                )
+                return vectorstore
+
+            # Add unique IDs to each document
+            for doc in all_documents:
+                doc.metadata["source"] = os.path.abspath(doc.metadata["source"])
+                doc.metadata["source_id"] = doc.metadata["source"]
+
+            # Split documents
+            chunks = text_splitter.split_documents(all_documents)
+
+            # Add chunk IDs and update tracker
+            file_chunk_map = {}
+            for chunk in chunks:
+                chunk_id = str(uuid.uuid4())
+                chunk.metadata["chunk_id"] = chunk_id
+
+                source = chunk.metadata["source"]
+                if source not in file_chunk_map:
+                    file_chunk_map[source] = []
+
+                file_chunk_map[source].append(chunk_id)
+
+            # Update tracker with chunk mappings
+            for file_path, chunk_ids in file_chunk_map.items():
+                self.tracker.update_chunk_mappings(file_path, chunk_ids)
+
+            print(
+                f"Processing {len(all_documents)} documents with {len(chunks)} chunks"
+            )
+
+            # Create new vector store
             vectorstore = Chroma.from_documents(
                 documents=chunks, embedding=embeddings, persist_directory=self.db_dir
             )
+            return vectorstore
+
+        # Get changes since last update
+        changed_files = self.tracker.get_changed_files(self.docs_dir)
+
+        if not any(changed_files.values()):
+            print("No document changes detected")
+            # Load existing vector store if available
+            if os.path.exists(self.db_dir):
+                return Chroma(
+                    persist_directory=self.db_dir,
+                    embedding_function=self.ollama.embeddings,
+                )
+            else:
+                print("No vector database exists. Creating from all documents...")
+                return self.create_vector_db(force_refresh=True)
+
+        # Process changes
+        print(
+            f"Changes detected - New: {len(changed_files['new'])}, Modified: {len(changed_files['modified'])}, Deleted: {len(changed_files['deleted'])}"
+        )
+
+        # Load existing vector store if it exists
+        if os.path.exists(self.db_dir):
+            vectorstore = Chroma(
+                persist_directory=self.db_dir, embedding_function=self.ollama.embeddings
+            )
+
+            # 1. Handle deleted documents
+            if changed_files["deleted"]:
+                chunks_to_delete = self.tracker.get_chunks_to_delete(
+                    changed_files["deleted"]
+                )
+                if chunks_to_delete:
+                    print(
+                        f"Removing {len(chunks_to_delete)} chunks from deleted documents"
+                    )
+                    # Delete the chunks from vector store
+                    vectorstore._collection.delete(
+                        where={"chunk_id": {"$in": chunks_to_delete}}
+                    )
+
+            # 2. Handle modified documents (delete old chunks first)
+            chunks_to_delete_modified = self.tracker.get_chunks_for_modified_files(
+                changed_files["modified"]
+            )
+            if chunks_to_delete_modified:
+                print(
+                    f"Removing {len(chunks_to_delete_modified)} chunks from modified documents"
+                )
+                vectorstore._collection.delete(
+                    where={"chunk_id": {"$in": chunks_to_delete_modified}}
+                )
+
+            # 3. Process new and modified documents
+            files_to_process = changed_files["new"] + changed_files["modified"]
+            if files_to_process:
+                chunks, _ = self.process_documents(files_to_process, text_splitter)
+                print(f"Adding {len(chunks)} new chunks to the vector store")
+                vectorstore.add_documents(chunks)
         else:
-            # Update existing vector store
-            # First, get all existing document IDs
-            existing_docs = vectorstore._collection.get()
-            existing_sources = {m["source"] for m in existing_docs["metadatas"]}
+            # If no existing DB, create from all documents
+            print("No existing vector database. Creating from all documents...")
+            return self.create_vector_db(force_refresh=True)
 
-            # Find new documents
-            new_chunks = [
-                chunk
-                for chunk in chunks
-                if chunk.metadata["source"] not in existing_sources
-            ]
-
-            if new_chunks:
-                print(f"Adding {len(new_chunks)} new chunks to vector store")
-                vectorstore.add_documents(new_chunks)
-
-            print(f"Vector store updated with {len(chunks)} total chunks")
+        # Persist changes
+        vectorstore.persist()
+        print(f"Vector database updated at {self.db_dir}")
 
         return vectorstore
 
     def setup_rag(self):
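
The incremental path above tags every chunk with a uuid4 chunk_id at indexing time and later removes stale chunks through Chroma's metadata filter. A condensed sketch of that pattern; `vectorstore` and `chunks` are assumed to exist as in the diff, and _collection is a private Chroma attribute that may change between versions:

    import uuid

    # Tag chunks when they are indexed so they can be targeted individually later
    for chunk in chunks:
        chunk.metadata["chunk_id"] = str(uuid.uuid4())
    vectorstore.add_documents(chunks)

    # Delete stale chunks by metadata filter ($in matches any listed id)
    stale_ids = [c.metadata["chunk_id"] for c in chunks[:2]]  # hypothetical stale set
    vectorstore._collection.delete(where={"chunk_id": {"$in": stale_ids}})
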
@@ -210,136 +242,3 @@ class RAG:
         """
         response = rag_chain.invoke({"query": query})
         return response["result"]
-
-
-def default_file_extensions():
-    """Return the default file extensions representing common plain text and code files"""
-    return [
-        # Python
-        ".py",
-        ".pyi",
-        ".pyx",
-        ".pyc",
-        ".pyd",
-        ".pyw",
-        # C/C++
-        ".c",
-        ".cpp",
-        ".cc",
-        ".cxx",
-        ".h",
-        ".hpp",
-        ".hxx",
-        ".inc",
-        ".inl",
-        ".ipp",
-        # Rust
-        ".rs",
-        ".rlib",
-        ".rmeta",
-        # Java
-        ".java",
-        ".jsp",
-        ".jav",
-        ".jar",
-        ".class",
-        ".kt",
-        ".kts",
-        ".groovy",
-        # Web
-        ".html",
-        ".htm",
-        ".css",
-        ".scss",
-        ".sass",
-        ".less",
-        ".js",
-        ".jsx",
-        ".ts",
-        ".tsx",
-        ".vue",
-        ".svelte",
-        # Fortran
-        ".f",
-        ".for",
-        ".f90",
-        ".f95",
-        ".f03",
-        ".f08",
-        # Go
-        ".go",
-        ".mod",
-        # Ruby
-        ".rb",
-        ".rbw",
-        ".rake",
-        ".gemspec",
-        # PHP
-        ".php",
-        ".phtml",
-        ".php3",
-        ".php4",
-        ".php5",
-        ".phps",
-        # C#
-        ".cs",
-        ".csx",
-        ".vb",
-        # Swift
-        ".swift",
-        ".swiftmodule",
-        # Shell/Scripts
-        ".sh",
-        ".bash",
-        ".zsh",
-        ".fish",
-        ".ps1",
-        ".bat",
-        ".cmd",
-        # Scala
-        ".scala",
-        ".sc",
-        # Haskell
-        ".hs",
-        ".lhs",
-        ".hsc",
-        # Lua
-        ".lua",
-        ".luac",
-        # R
-        ".r",
-        ".rmd",
-        ".rds",
-        # Perl
-        ".pl",
-        ".pm",
-        ".t",
-        # Documentation
-        ".txt",
-        ".md",
-        ".rst",
-        ".adoc",
-        ".wiki",
-        # Build/Config
-        ".toml",
-        ".yaml",
-        ".yml",
-        ".json",
-        ".xml",
-        ".ini",
-        ".conf",
-        ".cfg",
-        # SQL
-        ".sql",
-        ".mysql",
-        ".pgsql",
-        ".sqlite",
-        # Lisp family
-        ".lisp",
-        ".cl",
-        ".el",
-        ".clj",
-        ".cljc",
-        ".cljs",
-        ".edn",
-    ]
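
The deleted helper fed these extensions into recursive globbing. For reference, a standalone sketch of that matching (the directory name is hypothetical):

    import glob
    import os

    docs_dir = "docs"  # hypothetical directory
    matched = set()
    for ext in [".py", ".md", ".toml"]:
        # "**" combined with recursive=True descends into subdirectories
        matched.update(glob.glob(os.path.join(docs_dir, f"**/*{ext}"), recursive=True))
    print(sorted(matched))
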
tests/test_rag.py
@@ -1,10 +1,11 @@
 import os
-import shutil
 import pytest
-from langchain_chroma import Chroma
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+import shutil
 
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_chroma import Chroma
 from .fixtures import *
+from .utility import *
 from code_rag.rag import RAG
 from code_rag.doc_tracker import DocumentTracker
 from code_rag.ollama_wrapper import OllamaWrapper
@@ -51,40 +52,18 @@ def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipe
 def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
     """Test creating a vector database"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
-    # Create files with different extensions
-    files = {
-        "test.py": "def hello():\n print('Hello, World!')",
-        "main.cpp": "#include <iostream>\nint main() { std::cout << 'Hello'; return 0; }",
-        "lib.rs": "fn main() { println!('Hello from Rust!'); }",
-        "config.toml": "[package]\nname = 'test'",
-        "doc.md": "# Documentation\nThis is a test file."
-    }
-
-    for filename, content in files.items():
-        filepath = os.path.join(docs_dir, filename)
-        with open(filepath, "w") as f:
-            f.write(content)
-
-    # Create vector database with default extensions (should include all file types)
+    # Create initial vector database
     vectorstore = rag_pipeline.create_vector_db(force_refresh=True)
 
     # Verify it was created
     assert os.path.exists(rag_pipeline.db_dir)
     assert vectorstore is not None
-    # Check the database has content from all file types
+    # Check the database has content
     loaded_db = Chroma(
-        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
+        persist_directory=rag_pipeline.db_dir,
+        embedding_function=rag_pipeline.ollama.embeddings,
     )
-    # Should have content from all files
     assert loaded_db._collection.count() > 0
 
-    # Verify each file type is included
-    docs = loaded_db._collection.get()
-    sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
-    for filename in files.keys():
-        assert filename in sources, f"File {filename} not found in vector store"
-
 
 @pytest.mark.skipif(
@@ -93,66 +72,45 @@ def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
 def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_docs):
     """Test updating a vector database with document changes"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
-    # Create initial vector database with only Python files
-    vectorstore = rag_pipeline.create_vector_db(extensions=['.py'], force_refresh=True)
+    # Create initial vector database
+    rag_pipeline.create_vector_db(force_refresh=True)
 
     # Get initial count
     initial_db = Chroma(
-        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
+        persist_directory=rag_pipeline.db_dir,
+        embedding_function=rag_pipeline.ollama.embeddings,
     )
     initial_count = initial_db._collection.count()
 
     # Make changes to documents
-    # Add files of different types
-    new_files = {
-        "newdoc.cpp": "#include <iostream>\nint main() { return 0; }",
-        "lib.rs": "fn main() { println!('Hello'); }",
-        "config.toml": "[package]\nname = 'test'"
-    }
-
-    for filename, content in new_files.items():
-        filepath = os.path.join(docs_dir, filename)
-        with open(filepath, "w") as f:
-            f.write(content)
-
-    # Update the vector database to include all supported extensions
-    rag_pipeline.create_vector_db()  # Use default extensions
+    # Add a new document
+    create_test_document(
+        docs_dir, "newdoc.txt", "This is a brand new document for testing."
+    )
+
+    # Update the vector database
+    rag_pipeline.create_vector_db()
 
     # Check the database has been updated
     updated_db = Chroma(
-        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
+        persist_directory=rag_pipeline.db_dir,
+        embedding_function=rag_pipeline.ollama.embeddings,
     )
     assert updated_db._collection.count() > initial_count
 
-    # Verify new files are included
-    docs = updated_db._collection.get()
-    sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
-    for filename in new_files.keys():
-        assert filename in sources, f"File {filename} not found in vector store"
-
 
+# Final integration test - full RAG pipeline
 @pytest.mark.skipif(
     not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
 )
 def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
     """Test the entire RAG pipeline from document processing to querying"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
-    # Create documents with mixed content types
-    test_files = {
-        "python_info.py": """# Python Information
-def describe_python():
-    \"\"\"Python is a high-level programming language known for its readability and versatility.\"\"\"
-    pass""",
-        "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation."
-    }
-
-    for filename, content in test_files.items():
-        filepath = os.path.join(rag_pipeline.docs_dir, filename)
-        with open(filepath, "w") as f:
-            f.write(content)
-
-    # Create vector database with all default extensions
+    # Create a specific document with known content
+    test_content = "Python is a high-level programming language known for its readability and versatility."
+    create_test_document(rag_pipeline.docs_dir, "python_info.txt", test_content)
+
+    # Create vector database
    rag_pipeline.create_vector_db(force_refresh=True)
 
     # Set up RAG
@@ -166,3 +124,4 @@ def describe_python():
     # This is a soft test since the exact response will depend on the LLM
     assert response.strip() != ""
     assert "programming" in response.lower() or "language" in response.lower()
+
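
For context, a hedged end-to-end sketch of the flow these tests exercise; the constructor paths are hypothetical and setup_rag() is assumed to return the RetrievalQA chain whose invoke() call appears in the diff above:

    from code_rag.rag import RAG

    rag = RAG("docs", "chroma_db", "tracker.json")  # hypothetical paths
    rag.create_vector_db(force_refresh=True)
    rag_chain = rag.setup_rag()  # assumed to return the RetrievalQA chain
    answer = rag_chain.invoke({"query": "What is Python?"})["result"]
    print(answer)
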