diff --git a/pyproject.toml b/pyproject.toml
index f81eb07..59f79db 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,9 @@ ollama = "*"
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.3.5"
 
+[tool.poetry.scripts]
+code-rag = "code_rag.cli:main"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
diff --git a/src/code_rag/cli.py b/src/code_rag/cli.py
new file mode 100644
index 0000000..4ec80ae
--- /dev/null
+++ b/src/code_rag/cli.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python3
+
+import os
+import sys
+import argparse
+from code_rag.rag import RAG
+
+
+def stream_output(response_iter):
+    """Stream the response to console, handling each token"""
+    try:
+        for chunk in response_iter:
+            print(chunk, end="", flush=True)
+        print()  # New line at end
+    except KeyboardInterrupt:
+        print("\nStreaming interrupted by user")
+
+
+def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
+    """Run an interactive chat session"""
+    print(
+        "\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n"
+    )
+
+    try:
+        while True:
+            try:
+                query = input("\nQuestion: ").strip()
+
+                if query.lower() in ["exit", "quit", ""]:
+                    print("\nEnding chat session.")
+                    break
+
+                print("\nResponse:")
+                if no_stream:
+                    response = rag_pipeline.query_rag(rag_chain, query, stream=False)
+                    print(response)
+                else:
+                    response_iter = rag_pipeline.query_rag(
+                        rag_chain, query, stream=True
+                    )
+                    stream_output(response_iter)
+
+            except KeyboardInterrupt:
+                print("\n\nEnding chat session.")
+                break
+            except EOFError:
+                break
+    except KeyboardInterrupt:
+        print("\nChat session interrupted.")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Code RAG - Query your codebase using natural language"
+    )
+    parser.add_argument(
+        "docs_dir", help="Directory containing the documents to process", type=str
+    )
+    parser.add_argument(
+        "query",
+        help="Initial query about your codebase (optional in interactive mode)",
+        nargs="?",
+        default=None,
+        type=str,
+    )
+    parser.add_argument(
+        "--db-dir",
+        help="Directory to store the vector database (default: ~/.code_rag_db)",
+        default=os.path.expanduser("~/.code_rag_db"),
+        type=str,
+    )
+    parser.add_argument(
+        "--tracker-file",
+        help="File to track document changes (default: ~/.code_rag_tracker.json)",
+        default=os.path.expanduser("~/.code_rag_tracker.json"),
+        type=str,
+    )
+    parser.add_argument(
+        "--force-refresh",
+        help="Force refresh of the vector database",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--ollama-url",
+        help="URL for the Ollama server (default: 127.0.0.1)",
+        default="127.0.0.1",
+        type=str,
+    )
+    parser.add_argument(
+        "--embedding-model",
+        help="Model to use for embeddings (default: nomic-embed-text)",
+        default="nomic-embed-text",
+        type=str,
+    )
+    parser.add_argument(
+        "--llm-model",
+        help="Model to use for text generation (default: llama3.2)",
+        default="llama3.2",
+        type=str,
+    )
+    parser.add_argument(
+        "--no-stream", help="Disable streaming output", action="store_true"
+    )
+    parser.add_argument(
+        "--no-interactive",
+        help="Run in non-interactive mode (answer single query and exit)",
+        action="store_true",
+    )
+
+    args = parser.parse_args()
+
+    if args.no_interactive and not args.query:
+        parser.error("Query is required in non-interactive mode")
+
+    # Create RAG pipeline
+    rag_pipeline = RAG(
+        docs_dir=args.docs_dir,
+        db_dir=args.db_dir,
+        tracker_file=args.tracker_file,
+        ollama_url=args.ollama_url,
+        embedding_model=args.embedding_model,
+        llm_model=args.llm_model,
+    )
+
+    # Create or update vector database
+    print(f"Processing documents in {args.docs_dir}...")
+    print(f"Using models: embedding={args.embedding_model}, llm={args.llm_model}")
+    print(f"Ollama server: {args.ollama_url}")
+    vectorstore = rag_pipeline.create_vector_db(force_refresh=args.force_refresh)
+
+    # Set up RAG chain
+    print("Setting up RAG pipeline...")
+    rag_chain = rag_pipeline.setup_rag()
+
+    if args.no_interactive:
+        # Single query mode
+        print(f"\nQuery: {args.query}")
+        print("\nResponse:")
+        if args.no_stream:
+            response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
+            stream_output(response)
+        else:
+            response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
+            stream_output(response_iter)
+    else:
+        # Interactive mode
+        if args.query:
+            # Handle initial query if provided
+            print(f"\nQuery: {args.query}")
+            print("\nResponse:")
+            if args.no_stream:
+                response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
+                stream_output(response)
+            else:
+                response_iter = rag_pipeline.query_rag(
+                    rag_chain, args.query, stream=True
+                )
+                stream_output(response_iter)
+
+        # Start interactive chat
+        interactive_chat(rag_pipeline, rag_chain, args.no_stream)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/code_rag/rag.py b/src/code_rag/rag.py
index e1fd5f8..9d68670 100644
--- a/src/code_rag/rag.py
+++ b/src/code_rag/rag.py
@@ -1,15 +1,24 @@
 import os
-import uuid
 import glob
-from langchain_community.document_loaders import DirectoryLoader, TextLoader
+import uuid
+from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import TextLoader
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_community.chat_message_histories import ChatMessageHistory
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_chroma import Chroma
-from langchain.chains import RetrievalQA
-from langchain.prompts import PromptTemplate
-from langchain_core.documents import Document
+import time
 
-from code_rag.doc_tracker import DocumentTracker
-from code_rag.ollama_wrapper import OllamaWrapper
+from .doc_tracker import DocumentTracker
+from .ollama_wrapper import OllamaWrapper
 
 
 class RAG:
@@ -28,6 +37,7 @@ class RAG:
         self.ollama = OllamaWrapper(
             ollama_url, embedding_model=embedding_model, llm_model=llm_model
         )
+        self.session_id = time.time()
 
     def process_documents(self, files, text_splitter):
         """Process document files into chunks with tracking metadata"""
@@ -84,9 +94,7 @@
         )
 
         # Create embeddings
-        print("Before embedding")
         embeddings = self.ollama.embeddings
-        print("after embedding")
 
         # Load or create vector store
         if os.path.exists(self.db_dir) and not force_refresh:
@@ -161,55 +169,117 @@
         return vectorstore
 
     def setup_rag(self):
-        """
-        Set up the RAG system with an existing vector database
-        """
-        # Load the embeddings
-        embeddings = self.ollama.embeddings
-
-        # Load the vector store
store + """Set up the RAG pipeline""" + # Create vector store vectorstore = Chroma( - persist_directory=self.db_dir, embedding_function=embeddings + persist_directory=self.db_dir, embedding_function=self.ollama.embeddings ) - # Create a retriever - retriever = vectorstore.as_retriever(search_kwargs={"k": 4}) - - # Set up the LLM - llm = self.ollama.llm - - # Create a custom prompt template - template = """ - Answer the question based on the context provided. If you don't know the answer, - just say you don't know. Don't try to make up an answer. - - Context: {context} - - Question: {question} - - Answer: - """ - - prompt = PromptTemplate( - input_variables=["context", "question"], template=template + # Create retriever + retriever = vectorstore.as_retriever( + search_type="similarity", search_kwargs={"k": 4} ) - # Create the RAG chain - rag_chain = RetrievalQA.from_chain_type( - llm=llm, - chain_type="stuff", - retriever=retriever, - chain_type_kwargs={"prompt": prompt}, + # Create chat history buffer + self.chat_history = [] + + ### Contextualize question ### + contextualize_q_system_prompt = ( + "Given a chat history and the latest user question " + "which might reference context in the chat history, " + "formulate a standalone question which can be understood " + "without the chat history. Do NOT answer the question, " + "just reformulate it if needed and otherwise return it as is." + ) + contextualize_q_prompt = ChatPromptTemplate.from_messages( + [ + ("system", contextualize_q_system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] + ) + history_aware_retriever = create_history_aware_retriever( + self.ollama.llm, retriever, contextualize_q_prompt ) - return rag_chain + ### Answer question ### + system_prompt = ( + "You are an expert at analyzing code and documentation. " + "Use the following pieces of context to answer the question at the end." + "If you don't know the answer, just say that you don't know, " + "don't try to make up an answer" + "\n\n" + "{context}\n\n" + "Answer in a clear and concise manner. " + "If you're referring to code, use markdown formatting." 
+        )
+        qa_prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", system_prompt),
+                MessagesPlaceholder("chat_history"),
+                ("human", "{input}"),
+            ]
+        )
+        question_answer_chain = create_stuff_documents_chain(self.ollama.llm, qa_prompt)
 
-    def query_rag(self, rag_chain, query):
+        rag_chain = create_retrieval_chain(
+            history_aware_retriever, question_answer_chain
+        )
+
+        ### Statefully manage chat history ###
+        self.store = {}
+
+        def get_session_history(session_id: str) -> BaseChatMessageHistory:
+            if session_id not in self.store:
+                self.store[session_id] = ChatMessageHistory()
+            return self.store[session_id]
+
+        return RunnableWithMessageHistory(
+            rag_chain,
+            get_session_history,
+            input_messages_key="input",
+            history_messages_key="chat_history",
+            output_messages_key="answer",
+        )
+
+    def query_rag(self, rag_chain, query, stream=False):
         """
         Query the RAG system
+
+        Args:
+            rag_chain: The RAG chain to use
+            query: Query string
+            stream: If True, stream the response
         """
-        response = rag_chain.invoke({"query": query})
-        return response["result"]
+        if stream:
+            response = rag_chain.stream(
+                {"input": query},
+                config={"configurable": {"session_id": self.session_id}},
+            )
+            # Accumulate the full response as it streams
+            full_response = ""
+            for chunk in response:
+                # Extract only the LLM answer chunk
+                if "answer" in chunk:
+                    chunk_text = (
+                        chunk["answer"].content
+                        if hasattr(chunk["answer"], "content")
+                        else str(chunk["answer"])
+                    )
+                    full_response += chunk_text
+                    yield chunk_text
+        else:
+            response = rag_chain.invoke(
+                {"input": query},
+                config={"configurable": {"session_id": self.session_id}},
+            )
+            # Extract only the LLM answer from the response
+            result = (
+                response["answer"].content
+                if hasattr(response["answer"], "content")
+                else str(response["answer"])
+            )
+            return result
 
 
 def default_file_extensions():
@@ -217,11 +287,6 @@
     return [
        # Python
         ".py",
-        ".pyi",
-        ".pyx",
-        ".pyc",
-        ".pyd",
-        ".pyw",
         # C/C++
         ".c",
         ".cpp",
@@ -235,14 +300,8 @@
         ".ipp",
         # Rust
         ".rs",
-        ".rlib",
-        ".rmeta",
         # Java
         ".java",
-        ".jsp",
-        ".jav",
-        ".jar",
-        ".class",
         ".kt",
         ".kts",
         ".groovy",
@@ -250,9 +309,6 @@
         ".html",
         ".htm",
         ".css",
-        ".scss",
-        ".sass",
-        ".less",
         ".js",
         ".jsx",
         ".ts",
diff --git a/tests/test_rag.py b/tests/test_rag.py
index c0bef74..e7547af 100644
--- a/tests/test_rag.py
+++ b/tests/test_rag.py
@@ -51,16 +51,16 @@ def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipe
 def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
     """Test creating a vector database"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
-    
+
     # Create files with different extensions
     files = {
         "test.py": "def hello():\n    print('Hello, World!')",
         "main.cpp": "#include <iostream>\nint main() { std::cout << 'Hello'; return 0; }",
         "lib.rs": "fn main() { println!('Hello from Rust!'); }",
         "config.toml": "[package]\nname = 'test'",
-        "doc.md": "# Documentation\nThis is a test file."
+ "doc.md": "# Documentation\nThis is a test file.", } - + for filename, content in files.items(): filepath = os.path.join(docs_dir, filename) with open(filepath, "w") as f: @@ -72,14 +72,15 @@ def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs): # Verify it was created assert os.path.exists(rag_pipeline.db_dir) assert vectorstore is not None - + # Check the database has content from all file types loaded_db = Chroma( - persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings + persist_directory=rag_pipeline.db_dir, + embedding_function=rag_pipeline.ollama.embeddings, ) # Should have content from all files assert loaded_db._collection.count() > 0 - + # Verify each file type is included docs = loaded_db._collection.get() sources = {os.path.basename(m["source"]) for m in docs["metadatas"]} @@ -94,11 +95,12 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do """Test updating a vector database with document changes""" rag_pipeline = RAG(docs_dir, db_dir, tracker_file) # Create initial vector database with only Python files - vectorstore = rag_pipeline.create_vector_db(extensions=['.py'], force_refresh=True) + vectorstore = rag_pipeline.create_vector_db(extensions=[".py"], force_refresh=True) # Get initial count initial_db = Chroma( - persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings + persist_directory=rag_pipeline.db_dir, + embedding_function=rag_pipeline.ollama.embeddings, ) initial_count = initial_db._collection.count() @@ -107,9 +109,9 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do new_files = { "newdoc.cpp": "#include \nint main() { return 0; }", "lib.rs": "fn main() { println!('Hello'); }", - "config.toml": "[package]\nname = 'test'" + "config.toml": "[package]\nname = 'test'", } - + for filename, content in new_files.items(): filepath = os.path.join(docs_dir, filename) with open(filepath, "w") as f: @@ -120,10 +122,11 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do # Check the database has been updated updated_db = Chroma( - persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings + persist_directory=rag_pipeline.db_dir, + embedding_function=rag_pipeline.ollama.embeddings, ) assert updated_db._collection.count() > initial_count - + # Verify new files are included docs = updated_db._collection.get() sources = {os.path.basename(m["source"]) for m in docs["metadatas"]} @@ -137,16 +140,16 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs): """Test the entire RAG pipeline from document processing to querying""" rag_pipeline = RAG(docs_dir, db_dir, tracker_file) - + # Create documents with mixed content types test_files = { "python_info.py": """# Python Information def describe_python(): \"\"\"Python is a high-level programming language known for its readability and versatility.\"\"\" pass""", - "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation." + "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation.", } - + for filename, content in test_files.items(): filepath = os.path.join(rag_pipeline.docs_dir, filename) with open(filepath, "w") as f: @@ -160,7 +163,10 @@ def describe_python(): # Query the system query = "What is Python?" 
-    response = rag_pipeline.query_rag(rag_chain, query)
+    try:
+        response = next(rag_pipeline.query_rag(rag_chain, query))
+    except StopIteration as ex:
+        response = ex.value
 
     # Check if response contains relevant information
     # This is a soft test since the exact response will depend on the LLM