Add a cli, update RAG to keep chat history and stream output

commit dda45e7155
parent d497a62f7f
pyproject.toml
@@ -24,6 +24,9 @@ ollama = "*"
[tool.poetry.group.dev.dependencies]
pytest = "^8.3.5"

[tool.poetry.scripts]
code-rag = "code_rag.cli:main"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
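
The new [tool.poetry.scripts] entry asks Poetry to generate a code-rag console command that imports and calls code_rag.cli:main. Roughly, the generated wrapper behaves like the following sketch (illustrative only, not part of the commit):

# Approximate behaviour of the console script Poetry generates for
# code-rag = "code_rag.cli:main" (sketch, not part of the commit).
import sys

from code_rag.cli import main

if __name__ == "__main__":
    sys.exit(main())
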
src/code_rag/cli.py (new file, 160 lines)
@@ -0,0 +1,160 @@
#!/usr/bin/env python3

import os
import sys
import argparse
from code_rag.rag import RAG

def stream_output(response_iter):
    """Stream the response to console, handling each token"""
    try:
        for chunk in response_iter:
            print(chunk, end="", flush=True)
        print()  # New line at end
    except KeyboardInterrupt:
        print("\nStreaming interrupted by user")

def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
    """Run an interactive chat session"""
    print("\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n")

    try:
        while True:
            try:
                query = input("\nQuestion: ").strip()

                if query.lower() in ['exit', 'quit', '']:
                    print("\nEnding chat session.")
                    break

                print("\nResponse:")
                if no_stream:
                    response = rag_pipeline.query_rag(rag_chain, query, stream=False)
                    print(response)
                else:
                    response_iter = rag_pipeline.query_rag(rag_chain, query, stream=True)
                    stream_output(response_iter)

            except KeyboardInterrupt:
                print("\n\nEnding chat session.")
                break
            except EOFError:
                break
    except KeyboardInterrupt:
        print("\nChat session interrupted.")

def main():
    parser = argparse.ArgumentParser(
        description="Code RAG - Query your codebase using natural language"
    )
    parser.add_argument(
        "docs_dir",
        help="Directory containing the documents to process",
        type=str
    )
    parser.add_argument(
        "query",
        help="Initial query about your codebase (optional in interactive mode)",
        nargs="?",
        default=None,
        type=str
    )
    parser.add_argument(
        "--db-dir",
        help="Directory to store the vector database (default: .code_rag_db)",
        default=os.path.expanduser("~/.code_rag_db"),
        type=str
    )
    parser.add_argument(
        "--tracker-file",
        help="File to track document changes (default: .code_rag_tracker.json)",
        default=os.path.expanduser("~/.code_rag_tracker.json"),
        type=str
    )
    parser.add_argument(
        "--force-refresh",
        help="Force refresh of the vector database",
        action="store_true"
    )
    parser.add_argument(
        "--ollama-url",
        help="URL for the Ollama server (default: 127.0.0.1)",
        default="127.0.0.1",
        type=str
    )
    parser.add_argument(
        "--embedding-model",
        help="Model to use for embeddings (default: nomic-embed-text)",
        default="nomic-embed-text",
        type=str
    )
    parser.add_argument(
        "--llm-model",
        help="Model to use for text generation (default: llama3.2)",
        default="llama3.2",
        type=str
    )
    parser.add_argument(
        "--no-stream",
        help="Disable streaming output",
        action="store_true"
    )
    parser.add_argument(
        "--no-interactive",
        help="Run in non-interactive mode (answer single query and exit)",
        action="store_true"
    )

    args = parser.parse_args()

    if args.no_interactive and not args.query:
        parser.error("Query is required in non-interactive mode")

    # Create RAG pipeline
    rag_pipeline = RAG(
        docs_dir=args.docs_dir,
        db_dir=args.db_dir,
        tracker_file=args.tracker_file,
        ollama_url=args.ollama_url,
        embedding_model=args.embedding_model,
        llm_model=args.llm_model
    )

    # Create or update vector database
    print(f"Processing documents in {args.docs_dir}...")
    print(f"Using models: embedding={args.embedding_model}, llm={args.llm_model}")
    print(f"Ollama server: {args.ollama_url}")
    vectorstore = rag_pipeline.create_vector_db(force_refresh=args.force_refresh)

    # Set up RAG chain
    print("Setting up RAG pipeline...")
    rag_chain = rag_pipeline.setup_rag()

    if args.no_interactive:
        # Single query mode
        print(f"\nQuery: {args.query}")
        print("\nResponse:")
        if args.no_stream:
            response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
            print(response)
        else:
            response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
            stream_output(response_iter)
    else:
        # Interactive mode
        if args.query:
            # Handle initial query if provided
            print(f"\nQuery: {args.query}")
            print("\nResponse:")
            if args.no_stream:
                response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
                print(response)
            else:
                response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
                stream_output(response_iter)

        # Start interactive chat
        interactive_chat(rag_pipeline, rag_chain, args.no_stream)

if __name__ == "__main__":
    main()
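
For reference, the same pipeline the CLI wires up can be driven directly from Python. This is a minimal sketch mirroring the calls main() makes above; the docs_dir value and the example question are placeholders:

# Sketch: programmatic equivalent of the CLI flow (placeholder path and question).
import os

from code_rag.rag import RAG

rag_pipeline = RAG(
    docs_dir="path/to/your/codebase",  # placeholder
    db_dir=os.path.expanduser("~/.code_rag_db"),
    tracker_file=os.path.expanduser("~/.code_rag_tracker.json"),
    ollama_url="127.0.0.1",
    embedding_model="nomic-embed-text",
    llm_model="llama3.2",
)
rag_pipeline.create_vector_db(force_refresh=False)
rag_chain = rag_pipeline.setup_rag()

# Streaming mode yields tokens one at a time; print them as they arrive.
for token in rag_pipeline.query_rag(rag_chain, "Where is the CLI entry point?", stream=True):
    print(token, end="", flush=True)
print()
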
src/code_rag/rag.py
@@ -1,15 +1,17 @@
import os
import uuid
import glob
from langchain_community.document_loaders import DirectoryLoader, TextLoader
import uuid
from typing import List, Optional
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_chroma import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_core.documents import Document

from code_rag.doc_tracker import DocumentTracker
from code_rag.ollama_wrapper import OllamaWrapper
from .doc_tracker import DocumentTracker
from .ollama_wrapper import OllamaWrapper


class RAG:
@@ -161,55 +163,81 @@ class RAG:
        return vectorstore

    def setup_rag(self):
        """
        Set up the RAG system with an existing vector database
        """
        # Load the embeddings
        embeddings = self.ollama.embeddings

        # Load the vector store
        """Set up the RAG pipeline"""
        # Create vector store
        vectorstore = Chroma(
            persist_directory=self.db_dir, embedding_function=embeddings
            persist_directory=self.db_dir, embedding_function=self.ollama.embeddings
        )

        # Create a retriever
        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

        # Set up the LLM
        llm = self.ollama.llm

        # Create a custom prompt template
        template = """
        Answer the question based on the context provided. If you don't know the answer,
        just say you don't know. Don't try to make up an answer.

        Context: {context}

        Question: {question}

        Answer:
        """

        prompt = PromptTemplate(
            input_variables=["context", "question"], template=template
        # Create retriever
        retriever = vectorstore.as_retriever(
            search_type="similarity", search_kwargs={"k": 4}
        )

        # Create the RAG chain
        rag_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt},
        # Create chat history buffer
        self.chat_history = []

        # Create RAG chain with chat history
        system_template = """You are an expert at analyzing code and documentation. Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context:
{context}

Chat History:
{chat_history}

Answer in a clear and concise manner. If you're referring to code, use markdown formatting."""

        human_template = "{query}"

        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                ("human", human_template),
            ]
        )

        def format_docs(docs):
            return "\n\n".join(doc.page_content for doc in docs)

        rag_chain = (
            {
                "context": lambda x: format_docs(retriever.invoke(x["query"])),
                "chat_history": lambda x: "\n".join(self.chat_history),
                "query": lambda x: x["query"],
            }
            | prompt
            | self.ollama.llm
        )

        return rag_chain

    def query_rag(self, rag_chain, query):
    def query_rag(self, rag_chain, query, stream=False):
        """
        Query the RAG system

        Args:
            rag_chain: The RAG chain to use
            query: Query string
            stream: If True, stream the response
        """
        response = rag_chain.invoke({"query": query})
        return response["result"]
        if stream:
            response = rag_chain.stream({"query": query})
            # Store in chat history after getting full response
            full_response = ""
            for chunk in response:
                chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
                full_response += chunk_text
                yield chunk_text
            self.chat_history.append(f"Human: {query}")
            self.chat_history.append(f"Assistant: {full_response}")
        else:
            response = rag_chain.invoke({"query": query})
            result = response.content if hasattr(response, "content") else str(response)
            self.chat_history.append(f"Human: {query}")
            self.chat_history.append(f"Assistant: {result}")
            return result


def default_file_extensions():
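
A caveat on the committed query_rag: because the method now contains a yield, Python treats it as a generator function, so even the stream=False branch hands the CLI a generator rather than a string. One possible adjustment is to split the streaming path into a helper; this is only a sketch, not part of the commit, the helper name _stream_query is made up here, and both methods would live on the RAG class:

# Sketch (not in the commit): keep query_rag eager for stream=False by moving
# the generator into a hypothetical helper, _stream_query.
def query_rag(self, rag_chain, query, stream=False):
    """Query the RAG chain; yield chunks when streaming, else return a string."""
    if stream:
        return self._stream_query(rag_chain, query)
    response = rag_chain.invoke({"query": query})  # LCEL runnables are invoked, not called
    result = response.content if hasattr(response, "content") else str(response)
    self.chat_history.append(f"Human: {query}")
    self.chat_history.append(f"Assistant: {result}")
    return result

def _stream_query(self, rag_chain, query):
    """Yield response chunks, then record the full exchange in chat history."""
    full_response = ""
    for chunk in rag_chain.stream({"query": query}):
        chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
        full_response += chunk_text
        yield chunk_text
    self.chat_history.append(f"Human: {query}")
    self.chat_history.append(f"Assistant: {full_response}")
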
@@ -217,11 +245,6 @@ def default_file_extensions():
    return [
        # Python
        ".py",
        ".pyi",
        ".pyx",
        ".pyc",
        ".pyd",
        ".pyw",
        # C/C++
        ".c",
        ".cpp",
@@ -235,14 +258,8 @@ def default_file_extensions():
        ".ipp",
        # Rust
        ".rs",
        ".rlib",
        ".rmeta",
        # Java
        ".java",
        ".jsp",
        ".jav",
        ".jar",
        ".class",
        ".kt",
        ".kts",
        ".groovy",
@@ -250,9 +267,6 @@ def default_file_extensions():
        ".html",
        ".htm",
        ".css",
        ".scss",
        ".sass",
        ".less",
        ".js",
        ".jsx",
        ".ts",