Fix test and properly add conversational RAG

Alex Selimov 2025-03-23 00:08:17 -04:00
parent dda45e7155
commit 96801ba8e8
2 changed files with 98 additions and 55 deletions

View File

@@ -1,14 +1,21 @@
 import os
 import glob
 import uuid
-from typing import List, Optional
 from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import TextLoader
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.runnables import RunnablePassthrough
-from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_community.chat_message_histories import ChatMessageHistory
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_chroma import Chroma
+import time

 from .doc_tracker import DocumentTracker
 from .ollama_wrapper import OllamaWrapper
@@ -30,6 +37,7 @@ class RAG:
         self.ollama = OllamaWrapper(
             ollama_url, embedding_model=embedding_model, llm_model=llm_model
         )
+        self.session_id = time.time()

     def process_documents(self, files, text_splitter):
         """Process document files into chunks with tracking metadata"""
@@ -177,41 +185,64 @@ class RAG:
         # Create chat history buffer
         self.chat_history = []

-        # Create RAG chain with chat history
-        system_template = """You are an expert at analyzing code and documentation. Use the following pieces of context to answer the question at the end.
-If you don't know the answer, just say that you don't know, don't try to make up an answer.
-
-Context:
-{context}
-
-Chat History:
-{chat_history}
-
-Answer in a clear and concise manner. If you're referring to code, use markdown formatting."""
-
-        human_template = "{query}"
-
-        prompt = ChatPromptTemplate.from_messages(
+        ### Contextualize question ###
+        contextualize_q_system_prompt = (
+            "Given a chat history and the latest user question "
+            "which might reference context in the chat history, "
+            "formulate a standalone question which can be understood "
+            "without the chat history. Do NOT answer the question, "
+            "just reformulate it if needed and otherwise return it as is."
+        )
+        contextualize_q_prompt = ChatPromptTemplate.from_messages(
             [
-                ("system", system_template),
-                ("human", human_template),
+                ("system", contextualize_q_system_prompt),
+                MessagesPlaceholder("chat_history"),
+                ("human", "{input}"),
             ]
         )
+        history_aware_retriever = create_history_aware_retriever(
+            self.ollama.llm, retriever, contextualize_q_prompt
+        )

-        def format_docs(docs):
-            return "\n\n".join(doc.page_content for doc in docs)
-
-        rag_chain = (
-            {
-                "context": lambda x: format_docs(retriever.invoke(x["query"])),
-                "chat_history": lambda x: "\n".join(self.chat_history),
-                "query": lambda x: x["query"],
-            }
-            | prompt
-            | self.ollama.llm
-        )
-
-        return rag_chain
+        ### Answer question ###
+        system_prompt = (
+            "You are an expert at analyzing code and documentation. "
+            "Use the following pieces of context to answer the question at the end. "
+            "If you don't know the answer, just say that you don't know, "
+            "don't try to make up an answer."
+            "\n\n"
+            "{context}\n\n"
+            "Answer in a clear and concise manner. "
+            "If you're referring to code, use markdown formatting."
+        )
+        qa_prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", system_prompt),
+                MessagesPlaceholder("chat_history"),
+                ("human", "{input}"),
+            ]
+        )
+        question_answer_chain = create_stuff_documents_chain(self.ollama.llm, qa_prompt)
+        rag_chain = create_retrieval_chain(
+            history_aware_retriever, question_answer_chain
+        )
+
+        ### Statefully manage chat history ###
+        self.store = {}
+
+        def get_session_history(session_id: str) -> BaseChatMessageHistory:
+            if session_id not in self.store:
+                self.store[session_id] = ChatMessageHistory()
+            return self.store[session_id]
+
+        return RunnableWithMessageHistory(
+            rag_chain,
+            get_session_history,
+            input_messages_key="input",
+            history_messages_key="chat_history",
+            output_messages_key="answer",
+        )

     def query_rag(self, rag_chain, query, stream=False):
         """
@@ -223,20 +254,26 @@ Answer in a clear and concise manner. If you're referring to code, use markdown
             stream: If True, stream the response
         """
         if stream:
-            response = rag_chain.stream({"query": query})
+            response = rag_chain.stream(
+                {"input": query},
+                config={
+                    "configurable": {"session_id": self.session_id}
+                },
+            )

             # Store in chat history after getting full response
             full_response = ""
             for chunk in response:
                 chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
                 full_response += chunk_text
                 yield chunk_text
-            self.chat_history.append(f"Human: {query}")
-            self.chat_history.append(f"Assistant: {full_response}")
         else:
-            response = rag_chain({"query": query})
+            response = rag_chain.invoke(
+                {"input": query},
+                config={
+                    "configurable": {"session_id": self.session_id}
+                },
+            )
             result = response.content if hasattr(response, "content") else str(response)
-            self.chat_history.append(f"Human: {query}")
-            self.chat_history.append(f"Assistant: {result}")
             return result
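
A note on the wiring above: the RunnableWithMessageHistory returned by this method looks up a per-session ChatMessageHistory through the configurable session_id, injects it into the prompts under chat_history, and records each input/answer turn after the call, which is why query_rag no longer appends to self.chat_history by hand. Below is a minimal, self-contained sketch of that mechanism (illustrative only, not part of this commit; the echo RunnableLambda and the "demo" session id are stand-ins so it runs without Ollama):

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory

store = {}

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    # Same pattern as the commit: one ChatMessageHistory per session id.
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

# Stand-in chain: echoes the input and reports how many history messages
# RunnableWithMessageHistory injected under "chat_history".
echo = RunnableLambda(
    lambda x: {"answer": f"{x['input']} (seen {len(x['chat_history'])} history messages)"}
)

chain = RunnableWithMessageHistory(
    echo,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

cfg = {"configurable": {"session_id": "demo"}}
print(chain.invoke({"input": "first question"}, config=cfg)["answer"])  # 0 history messages
print(chain.invoke({"input": "follow-up"}, config=cfg)["answer"])       # 2: prior human + AI turn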

View File

@@ -51,16 +51,16 @@ def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipe
 def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
     """Test creating a vector database"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)

     # Create files with different extensions
     files = {
         "test.py": "def hello():\n print('Hello, World!')",
         "main.cpp": "#include <iostream>\nint main() { std::cout << 'Hello'; return 0; }",
         "lib.rs": "fn main() { println!('Hello from Rust!'); }",
         "config.toml": "[package]\nname = 'test'",
-        "doc.md": "# Documentation\nThis is a test file."
+        "doc.md": "# Documentation\nThis is a test file.",
     }

     for filename, content in files.items():
         filepath = os.path.join(docs_dir, filename)
         with open(filepath, "w") as f:
@@ -72,14 +72,15 @@ def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
     # Verify it was created
     assert os.path.exists(rag_pipeline.db_dir)
     assert vectorstore is not None

     # Check the database has content from all file types
     loaded_db = Chroma(
-        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
+        persist_directory=rag_pipeline.db_dir,
+        embedding_function=rag_pipeline.ollama.embeddings,
     )
     # Should have content from all files
     assert loaded_db._collection.count() > 0

     # Verify each file type is included
     docs = loaded_db._collection.get()
     sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
@@ -94,11 +95,12 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
     """Test updating a vector database with document changes"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)

     # Create initial vector database with only Python files
-    vectorstore = rag_pipeline.create_vector_db(extensions=['.py'], force_refresh=True)
+    vectorstore = rag_pipeline.create_vector_db(extensions=[".py"], force_refresh=True)

     # Get initial count
     initial_db = Chroma(
-        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
+        persist_directory=rag_pipeline.db_dir,
+        embedding_function=rag_pipeline.ollama.embeddings,
     )
     initial_count = initial_db._collection.count()
@@ -107,9 +109,9 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
     new_files = {
         "newdoc.cpp": "#include <iostream>\nint main() { return 0; }",
         "lib.rs": "fn main() { println!('Hello'); }",
-        "config.toml": "[package]\nname = 'test'"
+        "config.toml": "[package]\nname = 'test'",
     }

     for filename, content in new_files.items():
         filepath = os.path.join(docs_dir, filename)
         with open(filepath, "w") as f:
@@ -120,10 +122,11 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
     # Check the database has been updated
     updated_db = Chroma(
-        persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
+        persist_directory=rag_pipeline.db_dir,
+        embedding_function=rag_pipeline.ollama.embeddings,
     )
     assert updated_db._collection.count() > initial_count

     # Verify new files are included
     docs = updated_db._collection.get()
     sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
@@ -137,16 +140,16 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
 def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
     """Test the entire RAG pipeline from document processing to querying"""
     rag_pipeline = RAG(docs_dir, db_dir, tracker_file)

     # Create documents with mixed content types
     test_files = {
         "python_info.py": """# Python Information
 def describe_python():
     \"\"\"Python is a high-level programming language known for its readability and versatility.\"\"\"
     pass""",
-        "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation."
+        "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation.",
     }

     for filename, content in test_files.items():
         filepath = os.path.join(rag_pipeline.docs_dir, filename)
         with open(filepath, "w") as f:
@@ -160,7 +163,10 @@ def describe_python():

     # Query the system
     query = "What is Python?"
-    response = rag_pipeline.query_rag(rag_chain, query)
+    try:
+        response = next(rag_pipeline.query_rag(rag_chain, query))
+    except StopIteration as ex:
+        response = ex.value

     # Check if response contains relevant information
     # This is a soft test since the exact response will depend on the LLM
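
The try/next/StopIteration change in this last hunk is the actual test fix: because query_rag contains a yield, Python treats the whole function as a generator function, so even with stream=False the call returns a generator and the return result only surfaces as the value attribute of the StopIteration raised when the generator finishes. A self-contained sketch of that behavior (illustrative only; query_like is a hypothetical stand-in for query_rag):

def query_like(stream=False):
    # Any function containing `yield` is a generator function, even on
    # code paths that never actually yield.
    if stream:
        yield "chunk"
    else:
        return "full answer"  # becomes StopIteration.value

gen = query_like(stream=False)  # returns a generator; body has not run yet
try:
    next(gen)  # executes the body, which hits `return` without yielding
except StopIteration as ex:
    print(ex.value)  # -> full answer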