Add a CLI, update RAG to keep chat history and stream output

Alex Selimov 2025-03-22 22:16:14 -04:00
parent d497a62f7f
commit dda45e7155
3 changed files with 235 additions and 58 deletions

pyproject.toml

@@ -24,6 +24,9 @@ ollama = "*"
[tool.poetry.group.dev.dependencies]
pytest = "^8.3.5"
[tool.poetry.scripts]
code-rag = "code_rag.cli:main"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
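
The added [tool.poetry.scripts] entry makes code_rag.cli:main available as a code-rag console command once the package is installed with Poetry. A rough sketch of the equivalent manual entry point, using only what this diff declares:

# Approximate equivalent of the console script Poetry generates for `code-rag`
import sys

from code_rag.cli import main

if __name__ == "__main__":
    sys.exit(main())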

src/code_rag/cli.py (new file, 160 lines)

@@ -0,0 +1,160 @@
#!/usr/bin/env python3
import os
import sys
import argparse
from code_rag.rag import RAG
def stream_output(response_iter):
"""Stream the response to console, handling each token"""
try:
for chunk in response_iter:
print(chunk, end="", flush=True)
print() # New line at end
except KeyboardInterrupt:
print("\nStreaming interrupted by user")
def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
"""Run an interactive chat session"""
print("\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n")
try:
while True:
try:
query = input("\nQuestion: ").strip()
if query.lower() in ['exit', 'quit', '']:
print("\nEnding chat session.")
break
print("\nResponse:")
if no_stream:
response = rag_pipeline.query_rag(rag_chain, query, stream=False)
print(response)
else:
response_iter = rag_pipeline.query_rag(rag_chain, query, stream=True)
stream_output(response_iter)
except KeyboardInterrupt:
print("\n\nEnding chat session.")
break
except EOFError:
break
except KeyboardInterrupt:
print("\nChat session interrupted.")
def main():
parser = argparse.ArgumentParser(
description="Code RAG - Query your codebase using natural language"
)
parser.add_argument(
"docs_dir",
help="Directory containing the documents to process",
type=str
)
parser.add_argument(
"query",
help="Initial query about your codebase (optional in interactive mode)",
nargs="?",
default=None,
type=str
)
parser.add_argument(
"--db-dir",
help="Directory to store the vector database (default: .code_rag_db)",
default=os.path.expanduser("~/.code_rag_db"),
type=str
)
parser.add_argument(
"--tracker-file",
help="File to track document changes (default: .code_rag_tracker.json)",
default=os.path.expanduser("~/.code_rag_tracker.json"),
type=str
)
parser.add_argument(
"--force-refresh",
help="Force refresh of the vector database",
action="store_true"
)
parser.add_argument(
"--ollama-url",
help="URL for the Ollama server (default: 127.0.0.1)",
default="127.0.0.1",
type=str
)
parser.add_argument(
"--embedding-model",
help="Model to use for embeddings (default: nomic-embed-text)",
default="nomic-embed-text",
type=str
)
parser.add_argument(
"--llm-model",
help="Model to use for text generation (default: llama3.2)",
default="llama3.2",
type=str
)
parser.add_argument(
"--no-stream",
help="Disable streaming output",
action="store_true"
)
parser.add_argument(
"--no-interactive",
help="Run in non-interactive mode (answer single query and exit)",
action="store_true"
)
args = parser.parse_args()
if args.no_interactive and not args.query:
parser.error("Query is required in non-interactive mode")
# Create RAG pipeline
rag_pipeline = RAG(
docs_dir=args.docs_dir,
db_dir=args.db_dir,
tracker_file=args.tracker_file,
ollama_url=args.ollama_url,
embedding_model=args.embedding_model,
llm_model=args.llm_model
)
# Create or update vector database
print(f"Processing documents in {args.docs_dir}...")
print(f"Using models: embedding={args.embedding_model}, llm={args.llm_model}")
print(f"Ollama server: {args.ollama_url}")
vectorstore = rag_pipeline.create_vector_db(force_refresh=args.force_refresh)
# Set up RAG chain
print("Setting up RAG pipeline...")
rag_chain = rag_pipeline.setup_rag()
if args.no_interactive:
# Single query mode
print(f"\nQuery: {args.query}")
print("\nResponse:")
if args.no_stream:
response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
print(response)
else:
response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
stream_output(response_iter)
else:
# Interactive mode
if args.query:
# Handle initial query if provided
print(f"\nQuery: {args.query}")
print("\nResponse:")
if args.no_stream:
response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
print(response)
else:
response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
stream_output(response_iter)
# Start interactive chat
interactive_chat(rag_pipeline, rag_chain, args.no_stream)
if __name__ == "__main__":
main()
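
The command-line interface above can also be exercised programmatically by populating sys.argv before calling main(); a small sketch using only the arguments defined in this file (the project path and question are placeholders):

# Hypothetical programmatic invocation, roughly equivalent to:
#   code-rag ./my_project "How is the RAG chain assembled?" --no-interactive --no-stream
import sys

from code_rag.cli import main

sys.argv = [
    "code-rag",
    "./my_project",                        # docs_dir (placeholder path)
    "How is the RAG chain assembled?",     # optional initial query (placeholder)
    "--no-interactive",                    # answer the single query and exit
    "--no-stream",                         # print the full response instead of streaming
]
main()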

src/code_rag/rag.py

@@ -1,15 +1,17 @@
import os
import uuid
import glob
from langchain_community.document_loaders import DirectoryLoader, TextLoader
import uuid
from typing import List, Optional
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_chroma import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_core.documents import Document
from code_rag.doc_tracker import DocumentTracker
from code_rag.ollama_wrapper import OllamaWrapper
from .doc_tracker import DocumentTracker
from .ollama_wrapper import OllamaWrapper
class RAG:
@@ -161,55 +163,81 @@ class RAG:
return vectorstore
def setup_rag(self):
"""
Set up the RAG system with an existing vector database
"""
# Load the embeddings
embeddings = self.ollama.embeddings
# Load the vector store
"""Set up the RAG pipeline"""
# Create vector store
vectorstore = Chroma(
persist_directory=self.db_dir, embedding_function=embeddings
persist_directory=self.db_dir, embedding_function=self.ollama.embeddings
)
# Create a retriever
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
# Set up the LLM
llm = self.ollama.llm
# Create a custom prompt template
template = """
Answer the question based on the context provided. If you don't know the answer,
just say you don't know. Don't try to make up an answer.
Context: {context}
Question: {question}
Answer:
"""
prompt = PromptTemplate(
input_variables=["context", "question"], template=template
# Create retriever
retriever = vectorstore.as_retriever(
search_type="similarity", search_kwargs={"k": 4}
)
# Create the RAG chain
rag_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=retriever,
chain_type_kwargs={"prompt": prompt},
# Create chat history buffer
self.chat_history = []
# Create RAG chain with chat history
system_template = """You are an expert at analyzing code and documentation. Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context:
{context}
Chat History:
{chat_history}
Answer in a clear and concise manner. If you're referring to code, use markdown formatting."""
human_template = "{query}"
prompt = ChatPromptTemplate.from_messages(
[
("system", system_template),
("human", human_template),
]
)
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
rag_chain = (
{
"context": lambda x: format_docs(retriever.invoke(x["query"])),
"chat_history": lambda x: "\n".join(self.chat_history),
"query": lambda x: x["query"],
}
| prompt
| self.ollama.llm
)
return rag_chain
def query_rag(self, rag_chain, query):
def query_rag(self, rag_chain, query, stream=False):
"""
Query the RAG system
Args:
rag_chain: The RAG chain to use
query: Query string
stream: If True, stream the response
"""
response = rag_chain.invoke({"query": query})
return response["result"]
if stream:
response = rag_chain.stream({"query": query})
# Store in chat history after getting full response
full_response = ""
for chunk in response:
chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
full_response += chunk_text
yield chunk_text
self.chat_history.append(f"Human: {query}")
self.chat_history.append(f"Assistant: {full_response}")
else:
response = rag_chain.invoke({"query": query})
result = response.content if hasattr(response, "content") else str(response)
self.chat_history.append(f"Human: {query}")
self.chat_history.append(f"Assistant: {result}")
return result
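
Used as a library rather than through the CLI, the updated pipeline would be driven roughly as follows; the directory, model names, and question are placeholders, and the keyword arguments mirror the CLI defaults shown earlier in this commit:

# Minimal sketch: build the vector DB, set up the chain, and stream one answer
from code_rag.rag import RAG

rag_pipeline = RAG(
    docs_dir="./my_project",               # placeholder project path
    db_dir="./.code_rag_db",
    tracker_file="./.code_rag_tracker.json",
    ollama_url="127.0.0.1",
    embedding_model="nomic-embed-text",
    llm_model="llama3.2",
)
rag_pipeline.create_vector_db(force_refresh=False)
rag_chain = rag_pipeline.setup_rag()

# stream=True yields response chunks; chat history is recorded on the RAG instance
for chunk in rag_pipeline.query_rag(rag_chain, "What does setup_rag build?", stream=True):
    print(chunk, end="", flush=True)
print()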
def default_file_extensions():
@@ -217,11 +245,6 @@ def default_file_extensions():
return [
# Python
".py",
".pyi",
".pyx",
".pyc",
".pyd",
".pyw",
# C/C++
".c",
".cpp",
@@ -235,14 +258,8 @@ def default_file_extensions():
".ipp",
# Rust
".rs",
".rlib",
".rmeta",
# Java
".java",
".jsp",
".jav",
".jar",
".class",
".kt",
".kts",
".groovy",
@@ -250,9 +267,6 @@ def default_file_extensions():
".html",
".htm",
".css",
".scss",
".sass",
".less",
".js",
".jsx",
".ts",