Fix test and properly add conversational rag

This commit is contained in:
Alex Selimov 2025-03-23 00:08:17 -04:00
parent dda45e7155
commit 96801ba8e8
2 changed files with 98 additions and 55 deletions

View File

@@ -1,14 +1,21 @@
import os
import glob
import uuid
from typing import List, Optional
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
import time
from .doc_tracker import DocumentTracker
from .ollama_wrapper import OllamaWrapper
@@ -30,6 +37,7 @@ class RAG:
self.ollama = OllamaWrapper(
ollama_url, embedding_model=embedding_model, llm_model=llm_model
)
self.session_id = time.time()
def process_documents(self, files, text_splitter):
"""Process document files into chunks with tracking metadata"""
@@ -177,41 +185,64 @@ class RAG:
# Create chat history buffer
self.chat_history = []
# Create RAG chain with chat history
system_template = """You are an expert at analyzing code and documentation. Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context:
{context}
Chat History:
{chat_history}
Answer in a clear and concise manner. If you're referring to code, use markdown formatting."""
human_template = "{query}"
prompt = ChatPromptTemplate.from_messages(
### Contextualize question ###
contextualize_q_system_prompt = (
"Given a chat history and the latest user question "
"which might reference context in the chat history, "
"formulate a standalone question which can be understood "
"without the chat history. Do NOT answer the question, "
"just reformulate it if needed and otherwise return it as is."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
[
("system", system_template),
("human", human_template),
("system", contextualize_q_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
rag_chain = (
{
"context": lambda x: format_docs(retriever.invoke(x["query"])),
"chat_history": lambda x: "\n".join(self.chat_history),
"query": lambda x: x["query"],
}
| prompt
| self.ollama.llm
history_aware_retriever = create_history_aware_retriever(
self.ollama.llm, retriever, contextualize_q_prompt
)
return rag_chain
### Answer question ###
system_prompt = (
"You are an expert at analyzing code and documentation. "
"Use the following pieces of context to answer the question at the end."
"If you don't know the answer, just say that you don't know, "
"don't try to make up an answer"
"\n\n"
"{context}\n\n"
"Answer in a clear and concise manner. "
"If you're referring to code, use markdown formatting."
)
qa_prompt = ChatPromptTemplate.from_messages(
[
("system", system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
question_answer_chain = create_stuff_documents_chain(self.ollama.llm, qa_prompt)
rag_chain = create_retrieval_chain(
history_aware_retriever, question_answer_chain
)
### Statefully manage chat history ###
self.store = {}
def get_session_history(session_id: str) -> BaseChatMessageHistory:
if session_id not in self.store:
self.store[session_id] = ChatMessageHistory()
return self.store[session_id]
return RunnableWithMessageHistory(
rag_chain,
get_session_history,
input_messages_key="input",
history_messages_key="chat_history",
output_messages_key="answer",
)
def query_rag(self, rag_chain, query, stream=False):
"""
@@ -223,20 +254,26 @@ Answer in a clear and concise manner. If you're referring to code, use markdown
stream: If True, stream the response
"""
if stream:
response = rag_chain.stream({"query": query})
response = rag_chain.stream(
{"input": query},
config={
"configurable": {"session_id": self.session_id}
},
)
# Store in chat history after getting full response
full_response = ""
for chunk in response:
chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
full_response += chunk_text
yield chunk_text
self.chat_history.append(f"Human: {query}")
self.chat_history.append(f"Assistant: {full_response}")
else:
response = rag_chain({"query": query})
response = rag_chain.invoke(
{"input": query},
config={
"configurable": {"session_id": self.session_id}
},
)
result = response.content if hasattr(response, "content") else str(response)
self.chat_history.append(f"Human: {query}")
self.chat_history.append(f"Assistant: {result}")
return result
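
For orientation, a minimal sketch of how the reworked conversational chain is meant to be driven end to end. The import path, directory arguments, and the `create_rag_chain()` method name are assumptions inferred from this diff and the tests below, not part of the commit.

```python
# Hedged usage sketch: wiring the history-aware RAG pipeline together.
# The module path, directories, and create_rag_chain() name are assumptions.
from code_rag.rag import RAG  # hypothetical import path

pipeline = RAG("docs/", "db/", "tracker.json")
pipeline.create_vector_db(force_refresh=True)   # embed the tracked documents
rag_chain = pipeline.create_rag_chain()         # assumed chain-building method

# First turn: answered from retrieved context.
for chunk in pipeline.query_rag(rag_chain, "What does the RAG class do?", stream=True):
    print(chunk, end="", flush=True)

# Follow-up turn: the pronoun is resolved by the history-aware retriever
# because both calls share pipeline.session_id and therefore one history.
for chunk in pipeline.query_rag(rag_chain, "How does it split documents?", stream=True):
    print(chunk, end="", flush=True)
```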

View File

@@ -51,16 +51,16 @@ def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipe
def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
"""Test creating a vector database"""
rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
# Create files with different extensions
files = {
"test.py": "def hello():\n print('Hello, World!')",
"main.cpp": "#include <iostream>\nint main() { std::cout << 'Hello'; return 0; }",
"lib.rs": "fn main() { println!('Hello from Rust!'); }",
"config.toml": "[package]\nname = 'test'",
"doc.md": "# Documentation\nThis is a test file."
"doc.md": "# Documentation\nThis is a test file.",
}
for filename, content in files.items():
filepath = os.path.join(docs_dir, filename)
with open(filepath, "w") as f:
@@ -72,14 +72,15 @@ def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
# Verify it was created
assert os.path.exists(rag_pipeline.db_dir)
assert vectorstore is not None
# Check the database has content from all file types
loaded_db = Chroma(
persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
persist_directory=rag_pipeline.db_dir,
embedding_function=rag_pipeline.ollama.embeddings,
)
# Should have content from all files
assert loaded_db._collection.count() > 0
# Verify each file type is included
docs = loaded_db._collection.get()
sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
@@ -94,11 +95,12 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
"""Test updating a vector database with document changes"""
rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
# Create initial vector database with only Python files
vectorstore = rag_pipeline.create_vector_db(extensions=['.py'], force_refresh=True)
vectorstore = rag_pipeline.create_vector_db(extensions=[".py"], force_refresh=True)
# Get initial count
initial_db = Chroma(
persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
persist_directory=rag_pipeline.db_dir,
embedding_function=rag_pipeline.ollama.embeddings,
)
initial_count = initial_db._collection.count()
@@ -107,9 +109,9 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
new_files = {
"newdoc.cpp": "#include <iostream>\nint main() { return 0; }",
"lib.rs": "fn main() { println!('Hello'); }",
"config.toml": "[package]\nname = 'test'"
"config.toml": "[package]\nname = 'test'",
}
for filename, content in new_files.items():
filepath = os.path.join(docs_dir, filename)
with open(filepath, "w") as f:
@@ -120,10 +122,11 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
# Check the database has been updated
updated_db = Chroma(
persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
persist_directory=rag_pipeline.db_dir,
embedding_function=rag_pipeline.ollama.embeddings,
)
assert updated_db._collection.count() > initial_count
# Verify new files are included
docs = updated_db._collection.get()
sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
@@ -137,16 +140,16 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
"""Test the entire RAG pipeline from document processing to querying"""
rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
# Create documents with mixed content types
test_files = {
"python_info.py": """# Python Information
def describe_python():
\"\"\"Python is a high-level programming language known for its readability and versatility.\"\"\"
pass""",
"readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation."
"readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation.",
}
for filename, content in test_files.items():
filepath = os.path.join(rag_pipeline.docs_dir, filename)
with open(filepath, "w") as f:
@@ -160,7 +163,10 @@ def describe_python():
# Query the system
query = "What is Python?"
response = rag_pipeline.query_rag(rag_chain, query)
try:
response = next(rag_pipeline.query_rag(rag_chain, query))
except StopIteration as ex:
response = ex.value
# Check if response contains relevant information
# This is a soft test since the exact response will depend on the LLM
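
The try/except added in this test works around a subtlety of the updated `query_rag`: because the method body contains a `yield`, calling it always returns a generator, even on the non-streaming path, and the final `return result` surfaces as `StopIteration.value`. A standalone illustration of that pattern, with purely hypothetical names:

```python
def answer(stream=False):
    # Mirrors query_rag's shape: a `yield` anywhere in the body makes this a
    # generator function, no matter which branch actually runs.
    if stream:
        for chunk in ("Py", "thon"):
            yield chunk
    else:
        return "Python"  # becomes StopIteration.value for the caller

# Non-streaming: the generator yields nothing, so next() raises
# StopIteration carrying the returned string.
try:
    result = next(answer(stream=False))
except StopIteration as ex:
    result = ex.value
assert result == "Python"

# Streaming: iterate the generator normally.
assert "".join(answer(stream=True)) == "Python"
```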