diff --git a/pyproject.toml b/pyproject.toml
index f81eb07..59f79db 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,9 @@ ollama = "*"
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.3.5"
 
+[tool.poetry.scripts]
+code-rag = "code_rag.cli:main"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
diff --git a/src/code_rag/cli.py b/src/code_rag/cli.py
new file mode 100644
index 0000000..1355c06
--- /dev/null
+++ b/src/code_rag/cli.py
@@ -0,0 +1,160 @@
+#!/usr/bin/env python3
+
+import os
+import sys
+import argparse
+from code_rag.rag import RAG
+
+def stream_output(response_iter):
+    """Stream the response to console, handling each token"""
+    try:
+        for chunk in response_iter:
+            print(chunk, end="", flush=True)
+        print()  # New line at end
+    except KeyboardInterrupt:
+        print("\nStreaming interrupted by user")
+
+def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
+    """Run an interactive chat session"""
+    print("\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n")
+
+    try:
+        while True:
+            try:
+                query = input("\nQuestion: ").strip()
+
+                if query.lower() in ['exit', 'quit', '']:
+                    print("\nEnding chat session.")
+                    break
+
+                print("\nResponse:")
+                if no_stream:
+                    response = rag_pipeline.query_rag(rag_chain, query, stream=False)
+                    print(response)
+                else:
+                    response_iter = rag_pipeline.query_rag(rag_chain, query, stream=True)
+                    stream_output(response_iter)
+
+            except KeyboardInterrupt:
+                print("\n\nEnding chat session.")
+                break
+            except EOFError:
+                break
+    except KeyboardInterrupt:
+        print("\nChat session interrupted.")
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Code RAG - Query your codebase using natural language"
+    )
+    parser.add_argument(
+        "docs_dir",
+        help="Directory containing the documents to process",
+        type=str
+    )
+    parser.add_argument(
+        "query",
+        help="Initial query about your codebase (optional in interactive mode)",
+        nargs="?",
+        default=None,
+        type=str
+    )
+    parser.add_argument(
+        "--db-dir",
+        help="Directory to store the vector database (default: ~/.code_rag_db)",
+        default=os.path.expanduser("~/.code_rag_db"),
+        type=str
+    )
+    parser.add_argument(
+        "--tracker-file",
+        help="File to track document changes (default: ~/.code_rag_tracker.json)",
+        default=os.path.expanduser("~/.code_rag_tracker.json"),
+        type=str
+    )
+    parser.add_argument(
+        "--force-refresh",
+        help="Force refresh of the vector database",
+        action="store_true"
+    )
+    parser.add_argument(
+        "--ollama-url",
+        help="URL for the Ollama server (default: 127.0.0.1)",
+        default="127.0.0.1",
+        type=str
+    )
+    parser.add_argument(
+        "--embedding-model",
+        help="Model to use for embeddings (default: nomic-embed-text)",
+        default="nomic-embed-text",
+        type=str
+    )
+    parser.add_argument(
+        "--llm-model",
+        help="Model to use for text generation (default: llama3.2)",
+        default="llama3.2",
+        type=str
+    )
+    parser.add_argument(
+        "--no-stream",
+        help="Disable streaming output",
+        action="store_true"
+    )
+    parser.add_argument(
+        "--no-interactive",
+        help="Run in non-interactive mode (answer single query and exit)",
+        action="store_true"
+    )
+
+    args = parser.parse_args()
+
+    if args.no_interactive and not args.query:
+        parser.error("Query is required in non-interactive mode")
+
+    # Create RAG pipeline
+    rag_pipeline = RAG(
+        docs_dir=args.docs_dir,
+        db_dir=args.db_dir,
+        tracker_file=args.tracker_file,
+        ollama_url=args.ollama_url,
+        embedding_model=args.embedding_model,
+        llm_model=args.llm_model
+    )
+
+    # Create or update vector database
+    print(f"Processing documents in {args.docs_dir}...")
+    print(f"Using models: embedding={args.embedding_model}, llm={args.llm_model}")
+    print(f"Ollama server: {args.ollama_url}")
+    vectorstore = rag_pipeline.create_vector_db(force_refresh=args.force_refresh)
+
+    # Set up RAG chain
+    print("Setting up RAG pipeline...")
+    rag_chain = rag_pipeline.setup_rag()
+
+    if args.no_interactive:
+        # Single query mode
+        print(f"\nQuery: {args.query}")
+        print("\nResponse:")
+        if args.no_stream:
+            response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
+            print(response)
+        else:
+            response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
+            stream_output(response_iter)
+    else:
+        # Interactive mode
+        if args.query:
+            # Handle initial query if provided
+            print(f"\nQuery: {args.query}")
+            print("\nResponse:")
+            if args.no_stream:
+                response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
+                print(response)
+            else:
+                response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
+                stream_output(response_iter)
+
+        # Start interactive chat
+        interactive_chat(rag_pipeline, rag_chain, args.no_stream)
+
+if __name__ == "__main__":
+    main()
diff --git a/src/code_rag/rag.py b/src/code_rag/rag.py
index e1fd5f8..a4043a7 100644
--- a/src/code_rag/rag.py
+++ b/src/code_rag/rag.py
@@ -1,15 +1,17 @@
 import os
-import uuid
 import glob
-from langchain_community.document_loaders import DirectoryLoader, TextLoader
+import uuid
+from typing import List, Optional
+from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import TextLoader
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import RunnablePassthrough
+from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_chroma import Chroma
-from langchain.chains import RetrievalQA
-from langchain.prompts import PromptTemplate
-from langchain_core.documents import Document
 
-from code_rag.doc_tracker import DocumentTracker
-from code_rag.ollama_wrapper import OllamaWrapper
+from .doc_tracker import DocumentTracker
+from .ollama_wrapper import OllamaWrapper
 
 
 class RAG:
@@ -161,55 +163,81 @@ class RAG:
         return vectorstore
 
     def setup_rag(self):
-        """
-        Set up the RAG system with an existing vector database
-        """
-        # Load the embeddings
-        embeddings = self.ollama.embeddings
-
-        # Load the vector store
+        """Set up the RAG pipeline"""
+        # Create vector store
         vectorstore = Chroma(
-            persist_directory=self.db_dir, embedding_function=embeddings
+            persist_directory=self.db_dir, embedding_function=self.ollama.embeddings
         )
 
-        # Create a retriever
-        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
-
-        # Set up the LLM
-        llm = self.ollama.llm
-
-        # Create a custom prompt template
-        template = """
-        Answer the question based on the context provided. If you don't know the answer,
-        just say you don't know. Don't try to make up an answer.
-
-        Context: {context}
-
-        Question: {question}
-
-        Answer:
-        """
-
-        prompt = PromptTemplate(
-            input_variables=["context", "question"], template=template
+        # Create retriever
+        retriever = vectorstore.as_retriever(
+            search_type="similarity", search_kwargs={"k": 4}
         )
 
-        # Create the RAG chain
-        rag_chain = RetrievalQA.from_chain_type(
-            llm=llm,
-            chain_type="stuff",
-            retriever=retriever,
-            chain_type_kwargs={"prompt": prompt},
+        # Create chat history buffer
+        self.chat_history = []
+
+        # Create RAG chain with chat history
+        system_template = """You are an expert at analyzing code and documentation. Use the following pieces of context to answer the question at the end.
+If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+Context:
+{context}
+
+Chat History:
+{chat_history}
+
+Answer in a clear and concise manner. If you're referring to code, use markdown formatting."""
+
+        human_template = "{query}"
+
+        prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", system_template),
+                ("human", human_template),
+            ]
+        )
+
+        def format_docs(docs):
+            return "\n\n".join(doc.page_content for doc in docs)
+
+        rag_chain = (
+            {
+                "context": lambda x: format_docs(retriever.invoke(x["query"])),
+                "chat_history": lambda x: "\n".join(self.chat_history),
+                "query": lambda x: x["query"],
+            }
+            | prompt
+            | self.ollama.llm
         )
 
         return rag_chain
 
-    def query_rag(self, rag_chain, query):
+    def query_rag(self, rag_chain, query, stream=False):
         """
         Query the RAG system
+
+        Args:
+            rag_chain: The RAG chain to use
+            query: Query string
+            stream: If True, return an iterator over response chunks
         """
-        response = rag_chain.invoke({"query": query})
-        return response["result"]
+        if stream:
+            def _stream():
+                # Accumulate the full response so it can be stored in chat history
+                full_response = ""
+                for chunk in rag_chain.stream({"query": query}):
+                    chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
+                    full_response += chunk_text
+                    yield chunk_text
+                self.chat_history.append(f"Human: {query}")
+                self.chat_history.append(f"Assistant: {full_response}")
+            return _stream()
+        response = rag_chain.invoke({"query": query})
+        result = response.content if hasattr(response, "content") else str(response)
+        self.chat_history.append(f"Human: {query}")
+        self.chat_history.append(f"Assistant: {result}")
+        return result
 
 
 def default_file_extensions():
@@ -217,11 +245,6 @@ def default_file_extensions():
     return [
         # Python
         ".py",
-        ".pyi",
-        ".pyx",
-        ".pyc",
-        ".pyd",
-        ".pyw",
         # C/C++
         ".c",
         ".cpp",
@@ -235,14 +258,8 @@
         ".ipp",
         # Rust
         ".rs",
-        ".rlib",
-        ".rmeta",
         # Java
         ".java",
-        ".jsp",
-        ".jav",
-        ".jar",
-        ".class",
         ".kt",
         ".kts",
         ".groovy",
@@ -250,9 +267,6 @@
         ".html",
         ".htm",
         ".css",
-        ".scss",
-        ".sass",
-        ".less",
         ".js",
         ".jsx",
         ".ts",