Add a CLI, update RAG to keep chat history and stream output

Alex Selimov 2025-03-22 22:16:14 -04:00
parent d497a62f7f
commit dda45e7155
3 changed files with 235 additions and 58 deletions

pyproject.toml

@@ -24,6 +24,9 @@ ollama = "*"
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.3.5"
 
+[tool.poetry.scripts]
+code-rag = "code_rag.cli:main"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
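
The new [tool.poetry.scripts] entry registers a code-rag console command once the package is installed (for example via poetry install). The generated launcher is roughly the sketch below; the wrapper shape is an assumption about how Poetry-built entry points resolve the string "code_rag.cli:main", not part of this commit:

import sys

from code_rag.cli import main

if __name__ == "__main__":
    # Approximate wrapper: import the module named before the colon
    # and call the attribute named after it.
    sys.exit(main())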

src/code_rag/cli.py (new file)

@@ -0,0 +1,160 @@
#!/usr/bin/env python3
import os
import sys
import argparse
from code_rag.rag import RAG
def stream_output(response_iter):
"""Stream the response to console, handling each token"""
try:
for chunk in response_iter:
print(chunk, end="", flush=True)
print() # New line at end
except KeyboardInterrupt:
print("\nStreaming interrupted by user")
def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
"""Run an interactive chat session"""
print("\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n")
try:
while True:
try:
query = input("\nQuestion: ").strip()
if query.lower() in ['exit', 'quit', '']:
print("\nEnding chat session.")
break
print("\nResponse:")
if no_stream:
response = rag_pipeline.query_rag(rag_chain, query, stream=False)
print(response)
else:
response_iter = rag_pipeline.query_rag(rag_chain, query, stream=True)
stream_output(response_iter)
except KeyboardInterrupt:
print("\n\nEnding chat session.")
break
except EOFError:
break
except KeyboardInterrupt:
print("\nChat session interrupted.")
def main():
parser = argparse.ArgumentParser(
description="Code RAG - Query your codebase using natural language"
)
parser.add_argument(
"docs_dir",
help="Directory containing the documents to process",
type=str
)
parser.add_argument(
"query",
help="Initial query about your codebase (optional in interactive mode)",
nargs="?",
default=None,
type=str
)
parser.add_argument(
"--db-dir",
help="Directory to store the vector database (default: .code_rag_db)",
default=os.path.expanduser("~/.code_rag_db"),
type=str
)
parser.add_argument(
"--tracker-file",
help="File to track document changes (default: .code_rag_tracker.json)",
default=os.path.expanduser("~/.code_rag_tracker.json"),
type=str
)
parser.add_argument(
"--force-refresh",
help="Force refresh of the vector database",
action="store_true"
)
parser.add_argument(
"--ollama-url",
help="URL for the Ollama server (default: 127.0.0.1)",
default="127.0.0.1",
type=str
)
parser.add_argument(
"--embedding-model",
help="Model to use for embeddings (default: nomic-embed-text)",
default="nomic-embed-text",
type=str
)
parser.add_argument(
"--llm-model",
help="Model to use for text generation (default: llama3.2)",
default="llama3.2",
type=str
)
parser.add_argument(
"--no-stream",
help="Disable streaming output",
action="store_true"
)
parser.add_argument(
"--no-interactive",
help="Run in non-interactive mode (answer single query and exit)",
action="store_true"
)
args = parser.parse_args()
if args.no_interactive and not args.query:
parser.error("Query is required in non-interactive mode")
# Create RAG pipeline
rag_pipeline = RAG(
docs_dir=args.docs_dir,
db_dir=args.db_dir,
tracker_file=args.tracker_file,
ollama_url=args.ollama_url,
embedding_model=args.embedding_model,
llm_model=args.llm_model
)
# Create or update vector database
print(f"Processing documents in {args.docs_dir}...")
print(f"Using models: embedding={args.embedding_model}, llm={args.llm_model}")
print(f"Ollama server: {args.ollama_url}")
vectorstore = rag_pipeline.create_vector_db(force_refresh=args.force_refresh)
# Set up RAG chain
print("Setting up RAG pipeline...")
rag_chain = rag_pipeline.setup_rag()
if args.no_interactive:
# Single query mode
print(f"\nQuery: {args.query}")
print("\nResponse:")
if args.no_stream:
response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
print(response)
else:
response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
stream_output(response_iter)
else:
# Interactive mode
if args.query:
# Handle initial query if provided
print(f"\nQuery: {args.query}")
print("\nResponse:")
if args.no_stream:
response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
print(response)
else:
response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
stream_output(response_iter)
# Start interactive chat
interactive_chat(rag_pipeline, rag_chain, args.no_stream)
if __name__ == "__main__":
main()
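
As a quick reference, the same entry point can be exercised without installing the console script by filling in sys.argv before calling main(). This is a sketch; the codebase path and question are placeholders, and the flags mirror the argparse options defined above:

import sys

from code_rag.cli import main

# Hypothetical one-shot run: answer a single question and exit (placeholders throughout).
sys.argv = [
    "code-rag",
    "/path/to/your/codebase",             # docs_dir (positional)
    "How are document changes tracked?",  # optional initial query
    "--no-interactive",                   # answer the single query and exit
    "--llm-model", "llama3.2",
    "--embedding-model", "nomic-embed-text",
]
main()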

src/code_rag/rag.py

@@ -1,15 +1,17 @@
 import os
-import uuid
 import glob
-from langchain_community.document_loaders import DirectoryLoader, TextLoader
-from langchain.schema import Document
+import uuid
+from typing import List, Optional
 from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import TextLoader
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import RunnablePassthrough
+from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_chroma import Chroma
-from langchain.chains import RetrievalQA
-from langchain.prompts import PromptTemplate
+from langchain_core.documents import Document
 
-from code_rag.doc_tracker import DocumentTracker
-from code_rag.ollama_wrapper import OllamaWrapper
+from .doc_tracker import DocumentTracker
+from .ollama_wrapper import OllamaWrapper
 
 
 class RAG:
@@ -161,55 +163,81 @@ class RAG:
         return vectorstore
 
     def setup_rag(self):
-        """
-        Set up the RAG system with an existing vector database
-        """
-        # Load the embeddings
-        embeddings = self.ollama.embeddings
-        # Load the vector store
+        """Set up the RAG pipeline"""
+        # Create vector store
         vectorstore = Chroma(
-            persist_directory=self.db_dir, embedding_function=embeddings
+            persist_directory=self.db_dir, embedding_function=self.ollama.embeddings
         )
-        # Create a retriever
-        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
-        # Set up the LLM
-        llm = self.ollama.llm
-        # Create a custom prompt template
-        template = """
-        Answer the question based on the context provided. If you don't know the answer,
-        just say you don't know. Don't try to make up an answer.
-        Context: {context}
-        Question: {question}
-        Answer:
-        """
-        prompt = PromptTemplate(
-            input_variables=["context", "question"], template=template
-        )
-        # Create the RAG chain
-        rag_chain = RetrievalQA.from_chain_type(
-            llm=llm,
-            chain_type="stuff",
-            retriever=retriever,
-            chain_type_kwargs={"prompt": prompt},
-        )
+        # Create retriever
+        retriever = vectorstore.as_retriever(
+            search_type="similarity", search_kwargs={"k": 4}
+        )
+
+        # Create chat history buffer
+        self.chat_history = []
+
+        # Create RAG chain with chat history
+        system_template = """You are an expert at analyzing code and documentation. Use the following pieces of context to answer the question at the end.
+        If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+        Context:
+        {context}
+
+        Chat History:
+        {chat_history}
+
+        Answer in a clear and concise manner. If you're referring to code, use markdown formatting."""
+
+        human_template = "{query}"
+
+        prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", system_template),
+                ("human", human_template),
+            ]
+        )
+
+        def format_docs(docs):
+            return "\n\n".join(doc.page_content for doc in docs)
+
+        rag_chain = (
+            {
+                "context": lambda x: format_docs(retriever.invoke(x["query"])),
+                "chat_history": lambda x: "\n".join(self.chat_history),
+                "query": lambda x: x["query"],
+            }
+            | prompt
+            | self.ollama.llm
+        )
+
         return rag_chain
 
-    def query_rag(self, rag_chain, query):
+    def query_rag(self, rag_chain, query, stream=False):
         """
         Query the RAG system
+
+        Args:
+            rag_chain: The RAG chain to use
+            query: Query string
+            stream: If True, stream the response
         """
-        response = rag_chain.invoke({"query": query})
-        return response["result"]
+        if stream:
+            response = rag_chain.stream({"query": query})
+            # Store in chat history after getting full response
+            full_response = ""
+            for chunk in response:
+                chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
+                full_response += chunk_text
+                yield chunk_text
+            self.chat_history.append(f"Human: {query}")
+            self.chat_history.append(f"Assistant: {full_response}")
+        else:
+            response = rag_chain({"query": query})
+            result = response.content if hasattr(response, "content") else str(response)
+            self.chat_history.append(f"Human: {query}")
+            self.chat_history.append(f"Assistant: {result}")
+            return result
 
 
 def default_file_extensions():
@@ -217,11 +245,6 @@ def default_file_extensions():
     return [
         # Python
        ".py",
-        ".pyi",
-        ".pyx",
-        ".pyc",
-        ".pyd",
-        ".pyw",
         # C/C++
         ".c",
         ".cpp",
@@ -235,14 +258,8 @@ def default_file_extensions():
         ".ipp",
         # Rust
         ".rs",
-        ".rlib",
-        ".rmeta",
         # Java
         ".java",
-        ".jsp",
-        ".jav",
-        ".jar",
-        ".class",
         ".kt",
         ".kts",
         ".groovy",
@@ -250,9 +267,6 @@ def default_file_extensions():
         ".html",
         ".htm",
         ".css",
-        ".scss",
-        ".sass",
-        ".less",
         ".js",
         ".jsx",
         ".ts",