features/cli #2
pyproject.toml
@@ -24,6 +24,9 @@ ollama = "*"
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.3.5"
 
+[tool.poetry.scripts]
+code-rag = "code_rag.cli:main"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
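For reference, the console script that Poetry generates from this entry point is a thin wrapper that imports and calls code_rag.cli:main; a minimal sketch of the equivalent invocation (the argv values below are purely illustrative):

# Rough equivalent of running the installed `code-rag` command.
# The project path and question are hypothetical examples.
import sys
from code_rag.cli import main

sys.argv = ["code-rag", "./my_project", "How is the vector database built?"]
main()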
src/code_rag/cli.py (new file, 165 lines)
@@ -0,0 +1,165 @@
#!/usr/bin/env python3

import os
import sys
import argparse
from code_rag.rag import RAG


def stream_output(response_iter):
    """Stream the response to console, handling each token"""
    try:
        for chunk in response_iter:
            print(chunk, end="", flush=True)
        print()  # New line at end
    except KeyboardInterrupt:
        print("\nStreaming interrupted by user")


def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
    """Run an interactive chat session"""
    print(
        "\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n"
    )

    try:
        while True:
            try:
                query = input("\nQuestion: ").strip()

                if query.lower() in ["exit", "quit", ""]:
                    print("\nEnding chat session.")
                    break

                print("\nResponse:")
                if no_stream:
                    response = rag_pipeline.query_rag(rag_chain, query, stream=False)
                    print(response)
                else:
                    response_iter = rag_pipeline.query_rag(
                        rag_chain, query, stream=True
                    )
                    stream_output(response_iter)

            except KeyboardInterrupt:
                print("\n\nEnding chat session.")
                break
            except EOFError:
                break
    except KeyboardInterrupt:
        print("\nChat session interrupted.")


def main():
    parser = argparse.ArgumentParser(
        description="Code RAG - Query your codebase using natural language"
    )
    parser.add_argument(
        "docs_dir", help="Directory containing the documents to process", type=str
    )
    parser.add_argument(
        "query",
        help="Initial query about your codebase (optional in interactive mode)",
        nargs="?",
        default=None,
        type=str,
    )
    parser.add_argument(
        "--db-dir",
        help="Directory to store the vector database (default: .code_rag_db)",
        default=os.path.expanduser("~/.code_rag_db"),
        type=str,
    )
    parser.add_argument(
        "--tracker-file",
        help="File to track document changes (default: .code_rag_tracker.json)",
        default=os.path.expanduser("~/.code_rag_tracker.json"),
        type=str,
    )
    parser.add_argument(
        "--force-refresh",
        help="Force refresh of the vector database",
        action="store_true",
    )
    parser.add_argument(
        "--ollama-url",
        help="URL for the Ollama server (default: 127.0.0.1)",
        default="127.0.0.1",
        type=str,
    )
    parser.add_argument(
        "--embedding-model",
        help="Model to use for embeddings (default: nomic-embed-text)",
        default="nomic-embed-text",
        type=str,
    )
    parser.add_argument(
        "--llm-model",
        help="Model to use for text generation (default: llama3.2)",
        default="llama3.2",
        type=str,
    )
    parser.add_argument(
        "--no-stream", help="Disable streaming output", action="store_true"
    )
    parser.add_argument(
        "--no-interactive",
        help="Run in non-interactive mode (answer single query and exit)",
        action="store_true",
    )

    args = parser.parse_args()

    if args.no_interactive and not args.query:
        parser.error("Query is required in non-interactive mode")

    # Create RAG pipeline
    rag_pipeline = RAG(
        docs_dir=args.docs_dir,
        db_dir=args.db_dir,
        tracker_file=args.tracker_file,
        ollama_url=args.ollama_url,
        embedding_model=args.embedding_model,
        llm_model=args.llm_model,
    )

    # Create or update vector database
    print(f"Processing documents in {args.docs_dir}...")
    print(f"Using models: embedding={args.embedding_model}, llm={args.llm_model}")
    print(f"Ollama server: {args.ollama_url}")
    vectorstore = rag_pipeline.create_vector_db(force_refresh=args.force_refresh)

    # Set up RAG chain
    print("Setting up RAG pipeline...")
    rag_chain = rag_pipeline.setup_rag()

    if args.no_interactive:
        # Single query mode
        print(f"\nQuery: {args.query}")
        print("\nResponse:")
        if args.no_stream:
            response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
            stream_output(response)
        else:
            response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
            stream_output(response_iter)
    else:
        # Interactive mode
        if args.query:
            # Handle initial query if provided
            print(f"\nQuery: {args.query}")
            print("\nResponse:")
            if args.no_stream:
                response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
                stream_output(response)
            else:
                response_iter = rag_pipeline.query_rag(
                    rag_chain, args.query, stream=True
                )
                stream_output(response_iter)

        # Start interactive chat
        interactive_chat(rag_pipeline, rag_chain, args.no_stream)


if __name__ == "__main__":
    main()
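For orientation, a non-interactive run of this CLI boils down to the RAG calls below. This is a rough sketch only, with a hypothetical project path and question, and it assumes an Ollama server is reachable at the default URL:

# Roughly what `code-rag ./my_project "What does setup_rag return?" --no-interactive`
# ends up doing; the path and question are illustrative.
from code_rag.rag import RAG

rag_pipeline = RAG(
    docs_dir="./my_project",
    db_dir="./.code_rag_db",
    tracker_file="./.code_rag_tracker.json",
    ollama_url="127.0.0.1",
    embedding_model="nomic-embed-text",
    llm_model="llama3.2",
)
rag_pipeline.create_vector_db(force_refresh=False)
rag_chain = rag_pipeline.setup_rag()
for chunk in rag_pipeline.query_rag(rag_chain, "What does setup_rag return?", stream=True):
    print(chunk, end="", flush=True)
print()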
src/code_rag/rag.py
@@ -1,15 +1,24 @@
 import os
-import uuid
 import glob
-from langchain_community.document_loaders import DirectoryLoader, TextLoader
+import uuid
+from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import TextLoader
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_community.chat_message_histories import ChatMessageHistory
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_chroma import Chroma
-from langchain.chains import RetrievalQA
-from langchain.prompts import PromptTemplate
-from langchain_core.documents import Document
+import time
 
-from code_rag.doc_tracker import DocumentTracker
-from code_rag.ollama_wrapper import OllamaWrapper
+from .doc_tracker import DocumentTracker
+from .ollama_wrapper import OllamaWrapper
 
 
 class RAG:
@@ -28,6 +37,7 @@ class RAG:
         self.ollama = OllamaWrapper(
             ollama_url, embedding_model=embedding_model, llm_model=llm_model
         )
+        self.session_id = time.time()
 
     def process_documents(self, files, text_splitter):
         """Process document files into chunks with tracking metadata"""
@@ -84,9 +94,7 @@ class RAG:
         )
 
         # Create embeddings
-        print("Before embedding")
         embeddings = self.ollama.embeddings
-        print("after embedding")
 
         # Load or create vector store
         if os.path.exists(self.db_dir) and not force_refresh:
@@ -161,55 +169,117 @@ class RAG:
         return vectorstore
 
     def setup_rag(self):
-        """
-        Set up the RAG system with an existing vector database
-        """
-        # Load the embeddings
-        embeddings = self.ollama.embeddings
-
-        # Load the vector store
+        """Set up the RAG pipeline"""
+        # Create vector store
         vectorstore = Chroma(
-            persist_directory=self.db_dir, embedding_function=embeddings
+            persist_directory=self.db_dir, embedding_function=self.ollama.embeddings
         )
 
-        # Create a retriever
-        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
-
-        # Set up the LLM
-        llm = self.ollama.llm
-
-        # Create a custom prompt template
-        template = """
-        Answer the question based on the context provided. If you don't know the answer,
-        just say you don't know. Don't try to make up an answer.
-
-        Context: {context}
-
-        Question: {question}
-
-        Answer:
-        """
-
-        prompt = PromptTemplate(
-            input_variables=["context", "question"], template=template
-        )
-
-        # Create the RAG chain
-        rag_chain = RetrievalQA.from_chain_type(
-            llm=llm,
-            chain_type="stuff",
-            retriever=retriever,
-            chain_type_kwargs={"prompt": prompt},
-        )
-
-        return rag_chain
-
-    def query_rag(self, rag_chain, query):
-        """
-        Query the RAG system
-        """
-        response = rag_chain.invoke({"query": query})
-        return response["result"]
+        # Create retriever
+        retriever = vectorstore.as_retriever(
+            search_type="similarity", search_kwargs={"k": 4}
+        )
+
+        # Create chat history buffer
+        self.chat_history = []
+
+        ### Contextualize question ###
+        contextualize_q_system_prompt = (
+            "Given a chat history and the latest user question "
+            "which might reference context in the chat history, "
+            "formulate a standalone question which can be understood "
+            "without the chat history. Do NOT answer the question, "
+            "just reformulate it if needed and otherwise return it as is."
+        )
+        contextualize_q_prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", contextualize_q_system_prompt),
+                MessagesPlaceholder("chat_history"),
+                ("human", "{input}"),
+            ]
+        )
+        history_aware_retriever = create_history_aware_retriever(
+            self.ollama.llm, retriever, contextualize_q_prompt
+        )
+
+        ### Answer question ###
+        system_prompt = (
+            "You are an expert at analyzing code and documentation. "
+            "Use the following pieces of context to answer the question at the end."
+            "If you don't know the answer, just say that you don't know, "
+            "don't try to make up an answer"
+            "\n\n"
+            "{context}\n\n"
+            "Answer in a clear and concise manner. "
+            "If you're referring to code, use markdown formatting."
+        )
+        qa_prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", system_prompt),
+                MessagesPlaceholder("chat_history"),
+                ("human", "{input}"),
+            ]
+        )
+        question_answer_chain = create_stuff_documents_chain(self.ollama.llm, qa_prompt)
+
+        rag_chain = create_retrieval_chain(
+            history_aware_retriever, question_answer_chain
+        )
+
+        ### Statefully manage chat history ###
+        self.store = {}
+
+        def get_session_history(session_id: str) -> BaseChatMessageHistory:
+            if session_id not in self.store:
+                self.store[session_id] = ChatMessageHistory()
+            return self.store[session_id]
+
+        return RunnableWithMessageHistory(
+            rag_chain,
+            get_session_history,
+            input_messages_key="input",
+            history_messages_key="chat_history",
+            output_messages_key="answer",
+        )
+
+    def query_rag(self, rag_chain, query, stream=False):
+        """
+        Query the RAG system
+
+        Args:
+            rag_chain: The RAG chain to use
+            query: Query string
+            stream: If True, stream the response
+        """
+        if stream:
+            response = rag_chain.stream(
+                {"input": query},
+                config={"configurable": {"session_id": self.session_id}},
+            )
+            # Store in chat history after getting full response
+            full_response = ""
+            for chunk in response:
+                # Extract only the LLM answer chunk
+                if "answer" in chunk:
+                    chunk_text = (
+                        chunk["answer"].content
+                        if hasattr(chunk["answer"], "content")
+                        else str(chunk["answer"])
+                    )
+                    full_response += chunk_text
+                    yield chunk_text
+        else:
+            response = rag_chain.invoke(
+                {"input": query},
+                config={"configurable": {"session_id": self.session_id}},
+            )
+            # Extract only the LLM answer from the response
+            result = (
+                response["answer"].content
+                if hasattr(response["answer"], "content")
+                else str(response["answer"])
+            )
+            return result
 
 
 def default_file_extensions():
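Because setup_rag() now returns a RunnableWithMessageHistory, a caller can hold a multi-turn conversation by reusing one session_id: the history-aware retriever reformulates follow-up questions against the stored chat history before retrieval. A rough usage sketch, assuming the vector store has already been built; the paths, session id and questions are illustrative only:

# Hypothetical two-turn session; "it" in the second question is resolved
# from the chat history kept under the same session_id.
from code_rag.rag import RAG

rag = RAG(docs_dir="./my_project", db_dir="./.code_rag_db", tracker_file="./.code_rag_tracker.json")
rag.create_vector_db(force_refresh=False)
chain = rag.setup_rag()

cfg = {"configurable": {"session_id": "demo-session"}}
first = chain.invoke({"input": "What does create_vector_db do?"}, config=cfg)
print(first["answer"])
second = chain.invoke({"input": "Where is it called from?"}, config=cfg)
print(second["answer"])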
@@ -217,11 +287,6 @@ def default_file_extensions():
     return [
         # Python
         ".py",
-        ".pyi",
-        ".pyx",
-        ".pyc",
-        ".pyd",
-        ".pyw",
         # C/C++
         ".c",
         ".cpp",
@@ -235,14 +300,8 @@ def default_file_extensions():
         ".ipp",
         # Rust
         ".rs",
-        ".rlib",
-        ".rmeta",
         # Java
         ".java",
-        ".jsp",
-        ".jav",
-        ".jar",
-        ".class",
         ".kt",
         ".kts",
         ".groovy",
@@ -250,9 +309,6 @@ def default_file_extensions():
         ".html",
         ".htm",
         ".css",
-        ".scss",
-        ".sass",
-        ".less",
         ".js",
         ".jsx",
         ".ts",
tests/test_rag.py
@@ -51,16 +51,16 @@ def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipe
 def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
     """Test creating a vector database"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
 
     # Create files with different extensions
     files = {
         "test.py": "def hello():\n print('Hello, World!')",
         "main.cpp": "#include <iostream>\nint main() { std::cout << 'Hello'; return 0; }",
         "lib.rs": "fn main() { println!('Hello from Rust!'); }",
         "config.toml": "[package]\nname = 'test'",
-        "doc.md": "# Documentation\nThis is a test file."
+        "doc.md": "# Documentation\nThis is a test file.",
     }
 
     for filename, content in files.items():
         filepath = os.path.join(docs_dir, filename)
         with open(filepath, "w") as f:
@@ -72,14 +72,15 @@ def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
     # Verify it was created
     assert os.path.exists(rag_pipeline.db_dir)
     assert vectorstore is not None
 
     # Check the database has content from all file types
     loaded_db = Chroma(
-        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
+        persist_directory=rag_pipeline.db_dir,
+        embedding_function=rag_pipeline.ollama.embeddings,
     )
     # Should have content from all files
     assert loaded_db._collection.count() > 0
 
     # Verify each file type is included
     docs = loaded_db._collection.get()
     sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
@@ -94,11 +95,12 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
     """Test updating a vector database with document changes"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
     # Create initial vector database with only Python files
-    vectorstore = rag_pipeline.create_vector_db(extensions=['.py'], force_refresh=True)
+    vectorstore = rag_pipeline.create_vector_db(extensions=[".py"], force_refresh=True)
 
     # Get initial count
     initial_db = Chroma(
-        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
+        persist_directory=rag_pipeline.db_dir,
+        embedding_function=rag_pipeline.ollama.embeddings,
     )
     initial_count = initial_db._collection.count()
 
@@ -107,9 +109,9 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
     new_files = {
         "newdoc.cpp": "#include <iostream>\nint main() { return 0; }",
         "lib.rs": "fn main() { println!('Hello'); }",
-        "config.toml": "[package]\nname = 'test'"
+        "config.toml": "[package]\nname = 'test'",
     }
 
     for filename, content in new_files.items():
         filepath = os.path.join(docs_dir, filename)
         with open(filepath, "w") as f:
@@ -120,10 +122,11 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
 
     # Check the database has been updated
     updated_db = Chroma(
-        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
+        persist_directory=rag_pipeline.db_dir,
+        embedding_function=rag_pipeline.ollama.embeddings,
     )
     assert updated_db._collection.count() > initial_count
 
     # Verify new files are included
     docs = updated_db._collection.get()
     sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
@@ -137,16 +140,16 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
 def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
     """Test the entire RAG pipeline from document processing to querying"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
 
     # Create documents with mixed content types
     test_files = {
         "python_info.py": """# Python Information
 def describe_python():
     \"\"\"Python is a high-level programming language known for its readability and versatility.\"\"\"
     pass""",
-        "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation."
+        "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation.",
     }
 
     for filename, content in test_files.items():
         filepath = os.path.join(rag_pipeline.docs_dir, filename)
         with open(filepath, "w") as f:
@@ -160,7 +163,10 @@ def describe_python():
 
     # Query the system
     query = "What is Python?"
-    response = rag_pipeline.query_rag(rag_chain, query)
+    try:
+        response = next(rag_pipeline.query_rag(rag_chain, query))
+    except StopIteration as ex:
+        response = ex.value
 
     # Check if response contains relevant information
     # This is a soft test since the exact response will depend on the LLM
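The next()/StopIteration dance in the updated test is needed because query_rag() now contains a yield, which makes it a generator function even when stream=False: the non-streaming return value only surfaces as StopIteration.value. A minimal, dependency-free sketch of that mechanic (toy_query merely stands in for query_rag):

# toy_query mimics query_rag's shape: it yields when streaming and
# returns a plain value otherwise, so the caller must drive it with next().
def toy_query(stream=False):
    if stream:
        for token in ["a", "b", "c"]:
            yield token
    else:
        return "full answer"  # surfaces as StopIteration.value

gen = toy_query(stream=False)  # still a generator object
try:
    next(gen)
except StopIteration as ex:
    answer = ex.value  # "full answer", mirroring test_full_rag_pipeline
print(answer)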