From 96801ba8e8425706ddf9a9c7d9d37d644919fc9f Mon Sep 17 00:00:00 2001
From: Alex Selimov
Date: Sun, 23 Mar 2025 00:08:17 -0400
Subject: [PATCH] Fix test and properly add conversational RAG

---
 src/code_rag/rag.py | 115 +++++++++++++++++++++++++++++---------------
 tests/test_rag.py   |  38 +++++++++------
 2 files changed, 98 insertions(+), 55 deletions(-)

diff --git a/src/code_rag/rag.py b/src/code_rag/rag.py
index a4043a7..08c0ae2 100644
--- a/src/code_rag/rag.py
+++ b/src/code_rag/rag.py
@@ -1,14 +1,21 @@
 import os
 import glob
 import uuid
-from typing import List, Optional
 from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import TextLoader
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.runnables import RunnablePassthrough
-from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_community.chat_message_histories import ChatMessageHistory
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_chroma import Chroma
+import time
 
 from .doc_tracker import DocumentTracker
 from .ollama_wrapper import OllamaWrapper
@@ -30,6 +37,7 @@ class RAG:
         self.ollama = OllamaWrapper(
             ollama_url, embedding_model=embedding_model, llm_model=llm_model
         )
+        self.session_id = time.time()
 
     def process_documents(self, files, text_splitter):
         """Process document files into chunks with tracking metadata"""
@@ -177,41 +185,64 @@ class RAG:
 
         # Create chat history buffer
        self.chat_history = []
 
-        # Create RAG chain with chat history
-        system_template = """You are an expert at analyzing code and documentation. Use the following pieces of context to answer the question at the end.
-If you don't know the answer, just say that you don't know, don't try to make up an answer.
-
-Context:
-{context}
-
-Chat History:
-{chat_history}
-
-Answer in a clear and concise manner. If you're referring to code, use markdown formatting."""
-
-        human_template = "{query}"
-
-        prompt = ChatPromptTemplate.from_messages(
+        ### Contextualize question ###
+        contextualize_q_system_prompt = (
+            "Given a chat history and the latest user question "
+            "which might reference context in the chat history, "
+            "formulate a standalone question which can be understood "
+            "without the chat history. Do NOT answer the question, "
+            "just reformulate it if needed and otherwise return it as is."
+        )
+        contextualize_q_prompt = ChatPromptTemplate.from_messages(
             [
-                ("system", system_template),
-                ("human", human_template),
+                ("system", contextualize_q_system_prompt),
+                MessagesPlaceholder("chat_history"),
+                ("human", "{input}"),
             ]
         )
-
-        def format_docs(docs):
-            return "\n\n".join(doc.page_content for doc in docs)
-
-        rag_chain = (
-            {
-                "context": lambda x: format_docs(retriever.invoke(x["query"])),
-                "chat_history": lambda x: "\n".join(self.chat_history),
-                "query": lambda x: x["query"],
-            }
-            | prompt
-            | self.ollama.llm
+        history_aware_retriever = create_history_aware_retriever(
+            self.ollama.llm, retriever, contextualize_q_prompt
         )
-        return rag_chain
+        ### Answer question ###
+        system_prompt = (
+            "You are an expert at analyzing code and documentation. "
+            "Use the following pieces of context to answer the question at the end."
+            "If you don't know the answer, just say that you don't know, "
+            "don't try to make up an answer"
+            "\n\n"
+            "{context}\n\n"
+            "Answer in a clear and concise manner. "
+            "If you're referring to code, use markdown formatting."
+        )
+        qa_prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", system_prompt),
+                MessagesPlaceholder("chat_history"),
+                ("human", "{input}"),
+            ]
+        )
+        question_answer_chain = create_stuff_documents_chain(self.ollama.llm, qa_prompt)
+
+        rag_chain = create_retrieval_chain(
+            history_aware_retriever, question_answer_chain
+        )
+
+        ### Statefully manage chat history ###
+        self.store = {}
+
+        def get_session_history(session_id: str) -> BaseChatMessageHistory:
+            if session_id not in self.store:
+                self.store[session_id] = ChatMessageHistory()
+            return self.store[session_id]
+
+        return RunnableWithMessageHistory(
+            rag_chain,
+            get_session_history,
+            input_messages_key="input",
+            history_messages_key="chat_history",
+            output_messages_key="answer",
+        )
 
     def query_rag(self, rag_chain, query, stream=False):
         """
@@ -223,20 +254,26 @@ Answer in a clear and concise manner. If you're referring to code, use markdown
             stream: If True, stream the response
         """
         if stream:
-            response = rag_chain.stream({"query": query})
+            response = rag_chain.stream(
+                {"input": query},
+                config={
+                    "configurable": {"session_id": self.session_id}
+                },
+            )
             # Store in chat history after getting full response
             full_response = ""
             for chunk in response:
                 chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
                 full_response += chunk_text
                 yield chunk_text
-            self.chat_history.append(f"Human: {query}")
-            self.chat_history.append(f"Assistant: {full_response}")
         else:
-            response = rag_chain({"query": query})
+            response = rag_chain.invoke(
+                {"input": query},
+                config={
+                    "configurable": {"session_id": self.session_id}
+                },
+            )
             result = response.content if hasattr(response, "content") else str(response)
-            self.chat_history.append(f"Human: {query}")
-            self.chat_history.append(f"Assistant: {result}")
 
             return result
 
diff --git a/tests/test_rag.py b/tests/test_rag.py
index c0bef74..e7547af 100644
--- a/tests/test_rag.py
+++ b/tests/test_rag.py
@@ -51,16 +51,16 @@ def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipe
 def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
     """Test creating a vector database"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
-    
+
     # Create files with different extensions
     files = {
         "test.py": "def hello():\n    print('Hello, World!')",
         "main.cpp": "#include <iostream>\nint main() { std::cout << 'Hello'; return 0; }",
         "lib.rs": "fn main() { println!('Hello from Rust!'); }",
         "config.toml": "[package]\nname = 'test'",
-        "doc.md": "# Documentation\nThis is a test file."
+        "doc.md": "# Documentation\nThis is a test file.",
     }
-    
+
     for filename, content in files.items():
         filepath = os.path.join(docs_dir, filename)
         with open(filepath, "w") as f:
@@ -72,14 +72,15 @@ def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
     # Verify it was created
     assert os.path.exists(rag_pipeline.db_dir)
     assert vectorstore is not None
-    
+
     # Check the database has content from all file types
     loaded_db = Chroma(
-        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
+        persist_directory=rag_pipeline.db_dir,
+        embedding_function=rag_pipeline.ollama.embeddings,
     )
     # Should have content from all files
     assert loaded_db._collection.count() > 0
-    
+
     # Verify each file type is included
     docs = loaded_db._collection.get()
     sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
@@ -94,11 +95,12 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
     """Test updating a vector database with document changes"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
     # Create initial vector database with only Python files
-    vectorstore = rag_pipeline.create_vector_db(extensions=['.py'], force_refresh=True)
+    vectorstore = rag_pipeline.create_vector_db(extensions=[".py"], force_refresh=True)
 
     # Get initial count
     initial_db = Chroma(
-        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
+        persist_directory=rag_pipeline.db_dir,
+        embedding_function=rag_pipeline.ollama.embeddings,
     )
     initial_count = initial_db._collection.count()
 
@@ -107,9 +109,9 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
     new_files = {
         "newdoc.cpp": "#include <iostream>\nint main() { return 0; }",
         "lib.rs": "fn main() { println!('Hello'); }",
-        "config.toml": "[package]\nname = 'test'"
+        "config.toml": "[package]\nname = 'test'",
     }
-    
+
     for filename, content in new_files.items():
         filepath = os.path.join(docs_dir, filename)
         with open(filepath, "w") as f:
@@ -120,10 +122,11 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
 
     # Check the database has been updated
     updated_db = Chroma(
-        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
+        persist_directory=rag_pipeline.db_dir,
+        embedding_function=rag_pipeline.ollama.embeddings,
     )
     assert updated_db._collection.count() > initial_count
-    
+
     # Verify new files are included
     docs = updated_db._collection.get()
     sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
@@ -137,16 +140,16 @@ def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
     """Test the entire RAG pipeline from document processing to querying"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
-    
+
     # Create documents with mixed content types
     test_files = {
         "python_info.py": """# Python Information
 def describe_python():
     \"\"\"Python is a high-level programming language known for its readability
     and versatility.\"\"\"
     pass""",
-        "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation."
+        "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation.",
     }
-    
+
     for filename, content in test_files.items():
         filepath = os.path.join(rag_pipeline.docs_dir, filename)
         with open(filepath, "w") as f:
@@ -160,7 +163,10 @@ def describe_python():
 
     # Query the system
     query = "What is Python?"
-    response = rag_pipeline.query_rag(rag_chain, query)
+    try:
+        response = next(rag_pipeline.query_rag(rag_chain, query))
+    except StopIteration as ex:
+        response = ex.value
 
     # Check if response contains relevant information
     # This is a soft test since the exact response will depend on the LLM
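
Usage note (not part of the patch): a minimal sketch of how the conversational chain added in src/code_rag/rag.py is exercised, following the updated tests. The method that builds the chain is not visible in these hunks, so create_rag_chain below is an assumed placeholder name, and the constructor paths are illustrative; the RAG constructor arguments, query_rag signature, and generator behavior are taken from the hunks above and from tests/test_rag.py.

    from code_rag.rag import RAG

    rag_pipeline = RAG("docs/", "db/", "tracker.json")  # placeholder paths
    rag_pipeline.create_vector_db(extensions=[".py", ".md"], force_refresh=True)
    rag_chain = rag_pipeline.create_rag_chain()  # assumed method name, not shown in this diff

    # Streaming: query_rag is a generator, so consume chunks as they arrive.
    for chunk in rag_pipeline.query_rag(rag_chain, "What is Python?", stream=True):
        print(chunk, end="", flush=True)

    # Non-streaming: because query_rag contains a yield, calling it still returns
    # a generator; the final answer is carried in StopIteration.value, which is
    # why the updated test drives it with next() / except StopIteration.
    try:
        answer = next(rag_pipeline.query_rag(rag_chain, "How is chat history stored?"))
    except StopIteration as ex:
        answer = ex.value
    print(answer)

    # Both calls share self.session_id (set in __init__), so RunnableWithMessageHistory
    # replays the earlier turns to the history-aware retriever on follow-up questions.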