Compare commits: f63dc7148a ... 550e4b78e2
2 commits: 550e4b78e2, 3e7f4e42f7

poetry.lock (generated, 1364 lines)
File diff suppressed because it is too large.
pyproject.toml
@@ -1,27 +1,29 @@
 [project]
 name = "code-rag"
 version = "0.1.0"
-description = "Simple RAG implementation for use with neovim"
-authors = [{ name = "Alex Selimov", email = "alex@alexselimov.com" }]
-readme = "README.md"
-requires-python = ">=3.9,<4.0"
-dependencies = [
-    "langchain (>=0.3.21,<0.4.0)",
-    "ollama (>=0.4.7,<0.5.0)",
-    "langchain-community (>=0.3.20,<0.4.0)",
-    "langchain-ollama (>=0.2.3,<0.3.0)",
-    "chromadb (>=0.4.0,<0.6.0)",
-    "unstructured (>=0.17.2,<0.18.0)",
-    "langchain-chroma (>=0.1.0,<0.2.0)"
-]
+description = "Simple RAG implementation"
+authors = [
+    { name = "Alex Selimov", email = "alex@alexselimov.com" }
+]
+readme = "README.md"
+requires-python = "^3.11"
 
 [tool.poetry]
 packages = [{ include = "code_rag", from = "src" }]
 
+[tool.poetry.dependencies]
+python = "^3.11"
+langchain = "*"
+langchain-community = "*"
+langchain-core = "*"
+langchain-chroma = "*"
+langchain-ollama = "*"
+chromadb = "*"
+ollama = "*"
+
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.3.5"
 
 [build-system]
-requires = ["poetry-core>=2.0.0,<3.0.0"]
+requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
src/code_rag/rag.py
@@ -1,10 +1,12 @@
 import os
 import uuid
+import glob
 from langchain_community.document_loaders import DirectoryLoader, TextLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import Chroma
+from langchain_chroma import Chroma
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
+from langchain_core.documents import Document
 
 from code_rag.doc_tracker import DocumentTracker
 from code_rag.ollama_wrapper import OllamaWrapper
@@ -61,10 +63,21 @@ class RAG:
 
         return all_chunks, file_chunk_map
 
-    def create_vector_db(self, extension=".txt", force_refresh=False):
+    def create_vector_db(self, extensions=None, force_refresh=False):
         """
-        Create or update a vector database, with complete handling of changes
+        Create or update a vector database, with complete handling of changes.
+
+        Args:
+            extensions (list[str], optional): List of file extensions to include.
+                If None, defaults to common programming languages.
+            force_refresh (bool): Whether to force a complete refresh of the vector database.
         """
+        # Set default extensions for common programming languages if none provided
+        if extensions is None:
+            extensions = default_file_extensions()
+        elif isinstance(extensions, str):
+            extensions = [extensions]
+
         # Text splitter configuration
         text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=1000, chunk_overlap=200
@@ -75,121 +88,76 @@ class RAG:
         embeddings = self.ollama.embeddings
         print("after embedding")
 
-        if force_refresh:
-            print("Force refresh: Processing all documents")
-            # Load all documents
-            loader = DirectoryLoader(self.docs_dir, glob=f"**/*{extension}")
-            all_documents = loader.load()
-
-            if not all_documents:
-                print("No documents found to process")
-                # Create an empty vector store
-                vectorstore = Chroma(
-                    persist_directory=self.db_dir, embedding_function=embeddings
-                )
-                return vectorstore
-
-            # Add unique IDs to each document
-            for doc in all_documents:
-                doc.metadata["source"] = os.path.abspath(doc.metadata["source"])
-                doc.metadata["source_id"] = doc.metadata["source"]
-
-            # Split documents
-            chunks = text_splitter.split_documents(all_documents)
-
-            # Add chunk IDs and update tracker
-            file_chunk_map = {}
-            for chunk in chunks:
-                chunk_id = str(uuid.uuid4())
-                chunk.metadata["chunk_id"] = chunk_id
-
-                source = chunk.metadata["source"]
-                if source not in file_chunk_map:
-                    file_chunk_map[source] = []
-
-                file_chunk_map[source].append(chunk_id)
-
-            # Update tracker with chunk mappings
-            for file_path, chunk_ids in file_chunk_map.items():
-                self.tracker.update_chunk_mappings(file_path, chunk_ids)
-
-            print(
-                f"Processing {len(all_documents)} documents with {len(chunks)} chunks"
-            )
-
-            # Create new vector store
-            vectorstore = Chroma.from_documents(
-                documents=chunks, embedding=embeddings, persist_directory=self.db_dir
-            )
-            return vectorstore
-
-        # Get changes since last update
-        changed_files = self.tracker.get_changed_files(self.docs_dir)
-
-        if not any(changed_files.values()):
-            print("No document changes detected")
-            # Load existing vector store if available
-            if os.path.exists(self.db_dir):
-                return Chroma(
-                    persist_directory=self.db_dir,
-                    embedding_function=self.ollama.embeddings,
-                )
-            else:
-                print("No vector database exists. Creating from all documents...")
-                return self.create_vector_db(force_refresh=True)
-
-        # Process changes
-        print(
-            f"Changes detected - New: {len(changed_files['new'])}, Modified: {len(changed_files['modified'])}, Deleted: {len(changed_files['deleted'])}"
-        )
-
-        # Load existing vector store if it exists
-        if os.path.exists(self.db_dir):
-            vectorstore = Chroma(
-                persist_directory=self.db_dir, embedding_function=self.ollama.embeddings
-            )
-
-            # 1. Handle deleted documents
-            if changed_files["deleted"]:
-                chunks_to_delete = self.tracker.get_chunks_to_delete(
-                    changed_files["deleted"]
-                )
-                if chunks_to_delete:
-                    print(
-                        f"Removing {len(chunks_to_delete)} chunks from deleted documents"
-                    )
-                    # Delete the chunks from vector store
-                    vectorstore._collection.delete(
-                        where={"chunk_id": {"$in": chunks_to_delete}}
-                    )
-
-            # 2. Handle modified documents (delete old chunks first)
-            chunks_to_delete_modified = self.tracker.get_chunks_for_modified_files(
-                changed_files["modified"]
-            )
-            if chunks_to_delete_modified:
-                print(
-                    f"Removing {len(chunks_to_delete_modified)} chunks from modified documents"
-                )
-                vectorstore._collection.delete(
-                    where={"chunk_id": {"$in": chunks_to_delete_modified}}
-                )
-
-            # 3. Process new and modified documents
-            files_to_process = changed_files["new"] + changed_files["modified"]
-            if files_to_process:
-                chunks, _ = self.process_documents(files_to_process, text_splitter)
-                print(f"Adding {len(chunks)} new chunks to the vector store")
-                vectorstore.add_documents(chunks)
-        else:
-            # If no existing DB, create from all documents
-            print("No existing vector database. Creating from all documents...")
-            return self.create_vector_db(force_refresh=True)
-
-        # Persist changes
-        vectorstore.persist()
-        print(f"Vector database updated at {self.db_dir}")
+        # Load or create vector store
+        if os.path.exists(self.db_dir) and not force_refresh:
+            print("Loading existing vector store")
+            vectorstore = Chroma(
+                persist_directory=self.db_dir, embedding_function=embeddings
+            )
+        else:
+            print("Creating new vector store")
+            vectorstore = Chroma(
+                persist_directory=self.db_dir, embedding_function=embeddings
+            )
+
+        # Find all files that match the extensions
+        all_files = set()
+        for ext in extensions:
+            files = glob.glob(os.path.join(self.docs_dir, f"**/*{ext}"), recursive=True)
+            all_files.update(files)
+
+        if not all_files:
+            print("No documents found to process")
+            return vectorstore
+
+        # Process all files
+        all_documents = []
+        for file_path in all_files:
+            try:
+                with open(file_path, "r", encoding="utf-8") as f:
+                    content = f.read()
+                doc = Document(
+                    page_content=content,
+                    metadata={
+                        "source": os.path.abspath(file_path),
+                        "source_id": os.path.abspath(file_path),
+                    },
+                )
+                all_documents.append(doc)
+                print(f"Successfully loaded {file_path}")
+            except Exception as e:
+                print(f"Error loading file {file_path}: {str(e)}")
+
+        if not all_documents:
+            print("No documents could be loaded")
+            return vectorstore
+
+        # Split documents
+        chunks = text_splitter.split_documents(all_documents)
+
+        if force_refresh:
+            # Create new vector store from scratch
+            vectorstore = Chroma.from_documents(
+                documents=chunks, embedding=embeddings, persist_directory=self.db_dir
+            )
+        else:
+            # Update existing vector store
+            # First, get all existing document IDs
+            existing_docs = vectorstore._collection.get()
+            existing_sources = {m["source"] for m in existing_docs["metadatas"]}
+
+            # Find new documents
+            new_chunks = [
+                chunk
+                for chunk in chunks
+                if chunk.metadata["source"] not in existing_sources
+            ]
+
+            if new_chunks:
+                print(f"Adding {len(new_chunks)} new chunks to vector store")
+                vectorstore.add_documents(new_chunks)
+
+            print(f"Vector store updated with {len(chunks)} total chunks")
 
         return vectorstore
 
     def setup_rag(self):
@@ -242,3 +210,136 @@ class RAG:
         """
         response = rag_chain.invoke({"query": query})
         return response["result"]
+
+
+def default_file_extensions():
+    """Return the default file extensions representing common plain text and code files"""
+    return [
+        # Python
+        ".py",
+        ".pyi",
+        ".pyx",
+        ".pyc",
+        ".pyd",
+        ".pyw",
+        # C/C++
+        ".c",
+        ".cpp",
+        ".cc",
+        ".cxx",
+        ".h",
+        ".hpp",
+        ".hxx",
+        ".inc",
+        ".inl",
+        ".ipp",
+        # Rust
+        ".rs",
+        ".rlib",
+        ".rmeta",
+        # Java
+        ".java",
+        ".jsp",
+        ".jav",
+        ".jar",
+        ".class",
+        ".kt",
+        ".kts",
+        ".groovy",
+        # Web
+        ".html",
+        ".htm",
+        ".css",
+        ".scss",
+        ".sass",
+        ".less",
+        ".js",
+        ".jsx",
+        ".ts",
+        ".tsx",
+        ".vue",
+        ".svelte",
+        # Fortran
+        ".f",
+        ".for",
+        ".f90",
+        ".f95",
+        ".f03",
+        ".f08",
+        # Go
+        ".go",
+        ".mod",
+        # Ruby
+        ".rb",
+        ".rbw",
+        ".rake",
+        ".gemspec",
+        # PHP
+        ".php",
+        ".phtml",
+        ".php3",
+        ".php4",
+        ".php5",
+        ".phps",
+        # C#
+        ".cs",
+        ".csx",
+        ".vb",
+        # Swift
+        ".swift",
+        ".swiftmodule",
+        # Shell/Scripts
+        ".sh",
+        ".bash",
+        ".zsh",
+        ".fish",
+        ".ps1",
+        ".bat",
+        ".cmd",
+        # Scala
+        ".scala",
+        ".sc",
+        # Haskell
+        ".hs",
+        ".lhs",
+        ".hsc",
+        # Lua
+        ".lua",
+        ".luac",
+        # R
+        ".r",
+        ".rmd",
+        ".rds",
+        # Perl
+        ".pl",
+        ".pm",
+        ".t",
+        # Documentation
+        ".txt",
+        ".md",
+        ".rst",
+        ".adoc",
+        ".wiki",
+        # Build/Config
+        ".toml",
+        ".yaml",
+        ".yml",
+        ".json",
+        ".xml",
+        ".ini",
+        ".conf",
+        ".cfg",
+        # SQL
+        ".sql",
+        ".mysql",
+        ".pgsql",
+        ".sqlite",
+        # Lisp family
+        ".lisp",
+        ".cl",
+        ".el",
+        ".clj",
+        ".cljc",
+        ".cljs",
+        ".edn",
+    ]
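For orientation, a minimal usage sketch of the reworked create_vector_db follows. Only the class, function, and parameter names come from the diff above; the directory paths, tracker filename, and example extension lists are illustrative placeholders.

# Minimal usage sketch (assumes the package layout declared in pyproject.toml above;
# the paths below are placeholders, not values from this change).
from code_rag.rag import RAG, default_file_extensions

rag = RAG("./docs", "./chroma_db", "./tracker.json")

# Index only Python and Markdown sources
rag.create_vector_db(extensions=[".py", ".md"], force_refresh=True)

# A bare string is also accepted and wrapped into a single-item list
rag.create_vector_db(extensions=".rs")

# With extensions=None, the defaults added in this commit are used
print(f"{len(default_file_extensions())} extensions indexed by default")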
@@ -1,11 +1,10 @@
 import os
-import pytest
 import shutil
+import pytest
-from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_chroma import Chroma
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from .fixtures import *
-from .utility import *
 from code_rag.rag import RAG
 from code_rag.doc_tracker import DocumentTracker
 from code_rag.ollama_wrapper import OllamaWrapper
@@ -52,19 +51,41 @@ def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipe
 def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
     """Test creating a vector database"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
-    # Create initial vector database
+
+    # Create files with different extensions
+    files = {
+        "test.py": "def hello():\n print('Hello, World!')",
+        "main.cpp": "#include <iostream>\nint main() { std::cout << 'Hello'; return 0; }",
+        "lib.rs": "fn main() { println!('Hello from Rust!'); }",
+        "config.toml": "[package]\nname = 'test'",
+        "doc.md": "# Documentation\nThis is a test file."
+    }
+
+    for filename, content in files.items():
+        filepath = os.path.join(docs_dir, filename)
+        with open(filepath, "w") as f:
+            f.write(content)
+
+    # Create vector database with default extensions (should include all file types)
     vectorstore = rag_pipeline.create_vector_db(force_refresh=True)
+
     # Verify it was created
     assert os.path.exists(rag_pipeline.db_dir)
     assert vectorstore is not None
-    # Check the database has content
+
+    # Check the database has content from all file types
     loaded_db = Chroma(
-        persist_directory=rag_pipeline.db_dir,
-        embedding_function=rag_pipeline.ollama.embeddings,
+        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
     )
+    # Should have content from all files
     assert loaded_db._collection.count() > 0
+
+    # Verify each file type is included
+    docs = loaded_db._collection.get()
+    sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
+    for filename in files.keys():
+        assert filename in sources, f"File {filename} not found in vector store"
 
 
 @pytest.mark.skipif(
     not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
@@ -72,45 +93,66 @@ def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
 def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_docs):
     """Test updating a vector database with document changes"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
-    # Create initial vector database
-    rag_pipeline.create_vector_db(force_refresh=True)
+    # Create initial vector database with only Python files
+    vectorstore = rag_pipeline.create_vector_db(extensions=['.py'], force_refresh=True)
 
     # Get initial count
     initial_db = Chroma(
-        persist_directory=rag_pipeline.db_dir,
-        embedding_function=rag_pipeline.ollama.embeddings,
+        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
     )
     initial_count = initial_db._collection.count()
 
     # Make changes to documents
-    # Add a new document
-    create_test_document(
-        docs_dir, "newdoc.txt", "This is a brand new document for testing."
-    )
+    # Add files of different types
+    new_files = {
+        "newdoc.cpp": "#include <iostream>\nint main() { return 0; }",
+        "lib.rs": "fn main() { println!('Hello'); }",
+        "config.toml": "[package]\nname = 'test'"
+    }
 
-    # Update the vector database
-    rag_pipeline.create_vector_db()
+    for filename, content in new_files.items():
+        filepath = os.path.join(docs_dir, filename)
+        with open(filepath, "w") as f:
+            f.write(content)
+
+    # Update the vector database to include all supported extensions
+    rag_pipeline.create_vector_db()  # Use default extensions
 
     # Check the database has been updated
     updated_db = Chroma(
-        persist_directory=rag_pipeline.db_dir,
-        embedding_function=rag_pipeline.ollama.embeddings,
+        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
     )
     assert updated_db._collection.count() > initial_count
 
+    # Verify new files are included
+    docs = updated_db._collection.get()
+    sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
+    for filename in new_files.keys():
+        assert filename in sources, f"File {filename} not found in vector store"
+
 
-# Final integration test - full RAG pipeline
 @pytest.mark.skipif(
     not shutil.which("ollama"), reason="Ollama not installed or not in PATH"
 )
 def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
     """Test the entire RAG pipeline from document processing to querying"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
-    # Create a specific document with known content
-    test_content = "Python is a high-level programming language known for its readability and versatility."
-    create_test_document(rag_pipeline.docs_dir, "python_info.txt", test_content)
 
-    # Create vector database
+    # Create documents with mixed content types
+    test_files = {
+        "python_info.py": """# Python Information
+def describe_python():
+    \"\"\"Python is a high-level programming language known for its readability and versatility.\"\"\"
+    pass""",
+        "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation."
+    }
+
+    for filename, content in test_files.items():
+        filepath = os.path.join(rag_pipeline.docs_dir, filename)
+        with open(filepath, "w") as f:
+            f.write(content)
+
+    # Create vector database with all default extensions
     rag_pipeline.create_vector_db(force_refresh=True)
 
     # Set up RAG
@@ -124,4 +166,3 @@ def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
     # This is a soft test since the exact response will depend on the LLM
     assert response.strip() != ""
     assert "programming" in response.lower() or "language" in response.lower()
-