features/cli #2

Merged
aselimov merged 3 commits from features/cli into master 2025-03-24 00:07:27 +00:00
4 changed files with 307 additions and 77 deletions

pyproject.toml

@ -24,6 +24,9 @@ ollama = "*"
[tool.poetry.group.dev.dependencies]
pytest = "^8.3.5"
[tool.poetry.scripts]
code-rag = "code_rag.cli:main"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
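
The new [tool.poetry.scripts] entry exposes the CLI as a `code-rag` console command pointing at `code_rag.cli:main`. As a minimal sketch (assuming the package is installed, e.g. via `poetry install`), the generated script is roughly equivalent to:

```python
# Rough equivalent of the installed `code-rag` console script (illustrative only):
# the entry point imports code_rag.cli and calls its main() function.
from code_rag.cli import main

if __name__ == "__main__":
    main()  # argparse then reads sys.argv, e.g. `code-rag ./src "How is the vector DB built?"`
```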

src/code_rag/cli.py (new file, 165 lines)

@ -0,0 +1,165 @@
#!/usr/bin/env python3
import os
import sys
import argparse
from code_rag.rag import RAG
def stream_output(response_iter):
"""Stream the response to console, handling each token"""
try:
for chunk in response_iter:
print(chunk, end="", flush=True)
print() # New line at end
except KeyboardInterrupt:
print("\nStreaming interrupted by user")
def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
"""Run an interactive chat session"""
print(
"\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n"
)
try:
while True:
try:
query = input("\nQuestion: ").strip()
if query.lower() in ["exit", "quit", ""]:
print("\nEnding chat session.")
break
print("\nResponse:")
if no_stream:
                    # query_rag is a generator even when stream=False; the answer
                    # is returned as the StopIteration value (see the tests)
                    try:
                        response = next(rag_pipeline.query_rag(rag_chain, query, stream=False))
                    except StopIteration as ex:
                        response = ex.value
                    print(response)
else:
response_iter = rag_pipeline.query_rag(
rag_chain, query, stream=True
)
stream_output(response_iter)
except KeyboardInterrupt:
print("\n\nEnding chat session.")
break
except EOFError:
break
except KeyboardInterrupt:
print("\nChat session interrupted.")
def main():
parser = argparse.ArgumentParser(
description="Code RAG - Query your codebase using natural language"
)
parser.add_argument(
"docs_dir", help="Directory containing the documents to process", type=str
)
parser.add_argument(
"query",
help="Initial query about your codebase (optional in interactive mode)",
nargs="?",
default=None,
type=str,
)
parser.add_argument(
"--db-dir",
help="Directory to store the vector database (default: .code_rag_db)",
default=os.path.expanduser("~/.code_rag_db"),
type=str,
)
parser.add_argument(
"--tracker-file",
help="File to track document changes (default: .code_rag_tracker.json)",
default=os.path.expanduser("~/.code_rag_tracker.json"),
type=str,
)
parser.add_argument(
"--force-refresh",
help="Force refresh of the vector database",
action="store_true",
)
parser.add_argument(
"--ollama-url",
help="URL for the Ollama server (default: 127.0.0.1)",
default="127.0.0.1",
type=str,
)
parser.add_argument(
"--embedding-model",
help="Model to use for embeddings (default: nomic-embed-text)",
default="nomic-embed-text",
type=str,
)
parser.add_argument(
"--llm-model",
help="Model to use for text generation (default: llama3.2)",
default="llama3.2",
type=str,
)
parser.add_argument(
"--no-stream", help="Disable streaming output", action="store_true"
)
parser.add_argument(
"--no-interactive",
help="Run in non-interactive mode (answer single query and exit)",
action="store_true",
)
args = parser.parse_args()
if args.no_interactive and not args.query:
parser.error("Query is required in non-interactive mode")
# Create RAG pipeline
rag_pipeline = RAG(
docs_dir=args.docs_dir,
db_dir=args.db_dir,
tracker_file=args.tracker_file,
ollama_url=args.ollama_url,
embedding_model=args.embedding_model,
llm_model=args.llm_model,
)
# Create or update vector database
print(f"Processing documents in {args.docs_dir}...")
print(f"Using models: embedding={args.embedding_model}, llm={args.llm_model}")
print(f"Ollama server: {args.ollama_url}")
vectorstore = rag_pipeline.create_vector_db(force_refresh=args.force_refresh)
# Set up RAG chain
print("Setting up RAG pipeline...")
rag_chain = rag_pipeline.setup_rag()
if args.no_interactive:
# Single query mode
print(f"\nQuery: {args.query}")
print("\nResponse:")
if args.no_stream:
            # Non-streaming: consume the generator to obtain the full answer, then print it
            try:
                response = next(rag_pipeline.query_rag(rag_chain, args.query, stream=False))
            except StopIteration as ex:
                response = ex.value
            print(response)
else:
response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
stream_output(response_iter)
else:
# Interactive mode
if args.query:
# Handle initial query if provided
print(f"\nQuery: {args.query}")
print("\nResponse:")
if args.no_stream:
                # Non-streaming: consume the generator to obtain the full answer, then print it
                try:
                    response = next(rag_pipeline.query_rag(rag_chain, args.query, stream=False))
                except StopIteration as ex:
                    response = ex.value
                print(response)
else:
response_iter = rag_pipeline.query_rag(
rag_chain, args.query, stream=True
)
stream_output(response_iter)
# Start interactive chat
interactive_chat(rag_pipeline, rag_chain, args.no_stream)
if __name__ == "__main__":
main()
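
Not part of the diff: a minimal sketch of driving the same pipeline programmatically, mirroring what main() above does. The directory and query here are illustrative placeholders; the other values match the argparse defaults in this file.

```python
# Illustrative use of the pipeline that cli.py wraps (placeholder path and query).
import os
from code_rag.rag import RAG

rag_pipeline = RAG(
    docs_dir="./my_project",  # directory to index (placeholder)
    db_dir=os.path.expanduser("~/.code_rag_db"),
    tracker_file=os.path.expanduser("~/.code_rag_tracker.json"),
    ollama_url="127.0.0.1",
    embedding_model="nomic-embed-text",
    llm_model="llama3.2",
)
rag_pipeline.create_vector_db(force_refresh=False)  # build or update the Chroma store
rag_chain = rag_pipeline.setup_rag()                # history-aware retrieval chain

# Streaming: query_rag(..., stream=True) yields answer chunks as they are generated.
for chunk in rag_pipeline.query_rag(rag_chain, "How are documents tracked?", stream=True):
    print(chunk, end="", flush=True)
print()
```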

src/code_rag/rag.py

@ -1,15 +1,24 @@
import os
import uuid
import glob
from langchain_community.document_loaders import DirectoryLoader, TextLoader
import uuid
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_core.documents import Document
import time
from code_rag.doc_tracker import DocumentTracker
from code_rag.ollama_wrapper import OllamaWrapper
from .doc_tracker import DocumentTracker
from .ollama_wrapper import OllamaWrapper
class RAG:
@ -28,6 +37,7 @@ class RAG:
self.ollama = OllamaWrapper(
ollama_url, embedding_model=embedding_model, llm_model=llm_model
)
self.session_id = time.time()
def process_documents(self, files, text_splitter):
"""Process document files into chunks with tracking metadata"""
@ -84,9 +94,7 @@ class RAG:
)
# Create embeddings
print("Before embedding")
embeddings = self.ollama.embeddings
print("after embedding")
# Load or create vector store
if os.path.exists(self.db_dir) and not force_refresh:
@ -161,55 +169,117 @@ class RAG:
return vectorstore
def setup_rag(self):
"""
Set up the RAG system with an existing vector database
"""
# Load the embeddings
embeddings = self.ollama.embeddings
# Load the vector store
"""Set up the RAG pipeline"""
# Create vector store
vectorstore = Chroma(
persist_directory=self.db_dir, embedding_function=embeddings
persist_directory=self.db_dir, embedding_function=self.ollama.embeddings
)
# Create a retriever
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
# Set up the LLM
llm = self.ollama.llm
# Create a custom prompt template
template = """
Answer the question based on the context provided. If you don't know the answer,
just say you don't know. Don't try to make up an answer.
Context: {context}
Question: {question}
Answer:
"""
prompt = PromptTemplate(
input_variables=["context", "question"], template=template
# Create retriever
retriever = vectorstore.as_retriever(
search_type="similarity", search_kwargs={"k": 4}
)
# Create the RAG chain
rag_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=retriever,
chain_type_kwargs={"prompt": prompt},
# Create chat history buffer
self.chat_history = []
### Contextualize question ###
contextualize_q_system_prompt = (
"Given a chat history and the latest user question "
"which might reference context in the chat history, "
"formulate a standalone question which can be understood "
"without the chat history. Do NOT answer the question, "
"just reformulate it if needed and otherwise return it as is."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
[
("system", contextualize_q_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
history_aware_retriever = create_history_aware_retriever(
self.ollama.llm, retriever, contextualize_q_prompt
)
return rag_chain
### Answer question ###
system_prompt = (
"You are an expert at analyzing code and documentation. "
"Use the following pieces of context to answer the question at the end."
"If you don't know the answer, just say that you don't know, "
"don't try to make up an answer"
"\n\n"
"{context}\n\n"
"Answer in a clear and concise manner. "
"If you're referring to code, use markdown formatting."
)
qa_prompt = ChatPromptTemplate.from_messages(
[
("system", system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
question_answer_chain = create_stuff_documents_chain(self.ollama.llm, qa_prompt)
def query_rag(self, rag_chain, query):
rag_chain = create_retrieval_chain(
history_aware_retriever, question_answer_chain
)
### Statefully manage chat history ###
self.store = {}
def get_session_history(session_id: str) -> BaseChatMessageHistory:
if session_id not in self.store:
self.store[session_id] = ChatMessageHistory()
return self.store[session_id]
return RunnableWithMessageHistory(
rag_chain,
get_session_history,
input_messages_key="input",
history_messages_key="chat_history",
output_messages_key="answer",
)
def query_rag(self, rag_chain, query, stream=False):
"""
Query the RAG system
Args:
rag_chain: The RAG chain to use
query: Query string
            stream: If True, stream the response

        Note:
            Because this method contains a ``yield``, it always returns a
            generator, even when stream=False; in that case the final answer
            is delivered as the generator's StopIteration value (the CLI and
            tests consume it with next()/except StopIteration).
        """
response = rag_chain.invoke({"query": query})
return response["result"]
if stream:
response = rag_chain.stream(
{"input": query},
config={"configurable": {"session_id": self.session_id}},
)
            # Accumulate the full response while yielding chunks; chat history itself is managed by RunnableWithMessageHistory
full_response = ""
for chunk in response:
# Extract only the LLM answer chunk
if "answer" in chunk:
chunk_text = (
chunk["answer"].content
if hasattr(chunk["answer"], "content")
else str(chunk["answer"])
)
full_response += chunk_text
yield chunk_text
else:
response = rag_chain.invoke(
{"input": query},
config={"configurable": {"session_id": self.session_id}},
)
# Extract only the LLM answer from the response
result = (
response["answer"].content
if hasattr(response["answer"], "content")
else str(response["answer"])
)
return result
def default_file_extensions():
@ -217,11 +287,6 @@ def default_file_extensions():
return [
# Python
".py",
".pyi",
".pyx",
".pyc",
".pyd",
".pyw",
# C/C++
".c",
".cpp",
@ -235,14 +300,8 @@ def default_file_extensions():
".ipp",
# Rust
".rs",
".rlib",
".rmeta",
# Java
".java",
".jsp",
".jav",
".jar",
".class",
".kt",
".kts",
".groovy",
@ -250,9 +309,6 @@ def default_file_extensions():
".html",
".htm",
".css",
".scss",
".sass",
".less",
".js",
".jsx",
".ts",

View File

@ -51,16 +51,16 @@ def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipe
def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
"""Test creating a vector database"""
rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
# Create files with different extensions
files = {
"test.py": "def hello():\n print('Hello, World!')",
"main.cpp": "#include <iostream>\nint main() { std::cout << 'Hello'; return 0; }",
"lib.rs": "fn main() { println!('Hello from Rust!'); }",
"config.toml": "[package]\nname = 'test'",
"doc.md": "# Documentation\nThis is a test file."
"doc.md": "# Documentation\nThis is a test file.",
}
for filename, content in files.items():
filepath = os.path.join(docs_dir, filename)
with open(filepath, "w") as f:
@ -72,14 +72,15 @@ def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
# Verify it was created
assert os.path.exists(rag_pipeline.db_dir)
assert vectorstore is not None
# Check the database has content from all file types
loaded_db = Chroma(
persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
persist_directory=rag_pipeline.db_dir,
embedding_function=rag_pipeline.ollama.embeddings,
)
# Should have content from all files
assert loaded_db._collection.count() > 0
# Verify each file type is included
docs = loaded_db._collection.get()
sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
@ -94,11 +95,12 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
"""Test updating a vector database with document changes"""
rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
# Create initial vector database with only Python files
vectorstore = rag_pipeline.create_vector_db(extensions=['.py'], force_refresh=True)
vectorstore = rag_pipeline.create_vector_db(extensions=[".py"], force_refresh=True)
# Get initial count
initial_db = Chroma(
persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
persist_directory=rag_pipeline.db_dir,
embedding_function=rag_pipeline.ollama.embeddings,
)
initial_count = initial_db._collection.count()
@ -107,9 +109,9 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
new_files = {
"newdoc.cpp": "#include <iostream>\nint main() { return 0; }",
"lib.rs": "fn main() { println!('Hello'); }",
"config.toml": "[package]\nname = 'test'"
"config.toml": "[package]\nname = 'test'",
}
for filename, content in new_files.items():
filepath = os.path.join(docs_dir, filename)
with open(filepath, "w") as f:
@ -120,10 +122,11 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
# Check the database has been updated
updated_db = Chroma(
persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
persist_directory=rag_pipeline.db_dir,
embedding_function=rag_pipeline.ollama.embeddings,
)
assert updated_db._collection.count() > initial_count
# Verify new files are included
docs = updated_db._collection.get()
sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
@ -137,16 +140,16 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
"""Test the entire RAG pipeline from document processing to querying"""
rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
# Create documents with mixed content types
test_files = {
"python_info.py": """# Python Information
def describe_python():
\"\"\"Python is a high-level programming language known for its readability and versatility.\"\"\"
pass""",
"readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation."
"readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation.",
}
for filename, content in test_files.items():
filepath = os.path.join(rag_pipeline.docs_dir, filename)
with open(filepath, "w") as f:
@ -160,7 +163,10 @@ def describe_python():
# Query the system
query = "What is Python?"
response = rag_pipeline.query_rag(rag_chain, query)
try:
response = next(rag_pipeline.query_rag(rag_chain, query))
except StopIteration as ex:
response = ex.value
# Check if response contains relevant information
# This is a soft test since the exact response will depend on the LLM
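
Not part of the diff: a sketch of a companion test for the new streaming path, reusing the fixtures above and the same generator-consuming idiom. Like the other end-to-end tests here, it assumes a reachable Ollama server.

```python
def test_query_rag_streaming(docs_dir, db_dir, tracker_file, sample_docs):
    """Streaming queries should yield string chunks that join into a non-empty answer."""
    rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
    rag_pipeline.create_vector_db(force_refresh=True)
    rag_chain = rag_pipeline.setup_rag()

    chunks = list(rag_pipeline.query_rag(rag_chain, "What is Python?", stream=True))

    # Every chunk is text and together they form a non-empty response
    assert all(isinstance(c, str) for c in chunks)
    assert len("".join(chunks).strip()) > 0
```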