Fix test and porperly add conversational rag
This commit is contained in:
parent
dda45e7155
commit
96801ba8e8
@ -1,14 +1,21 @@
|
||||
import os
|
||||
import glob
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.document_loaders import TextLoader
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
|
||||
from langchain.chains.combine_documents import create_stuff_documents_chain
|
||||
from langchain_community.chat_message_histories import ChatMessageHistory
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from langchain_chroma import Chroma
|
||||
import time
|
||||
|
||||
from .doc_tracker import DocumentTracker
|
||||
from .ollama_wrapper import OllamaWrapper
|
||||
@ -30,6 +37,7 @@ class RAG:
|
||||
self.ollama = OllamaWrapper(
|
||||
ollama_url, embedding_model=embedding_model, llm_model=llm_model
|
||||
)
|
||||
self.session_id = time.time()
|
||||
|
||||
def process_documents(self, files, text_splitter):
|
||||
"""Process document files into chunks with tracking metadata"""
|
||||
@ -177,41 +185,64 @@ class RAG:
|
||||
# Create chat history buffer
|
||||
self.chat_history = []
|
||||
|
||||
# Create RAG chain with chat history
|
||||
system_template = """You are an expert at analyzing code and documentation. Use the following pieces of context to answer the question at the end.
|
||||
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
||||
|
||||
Context:
|
||||
{context}
|
||||
|
||||
Chat History:
|
||||
{chat_history}
|
||||
|
||||
Answer in a clear and concise manner. If you're referring to code, use markdown formatting."""
|
||||
|
||||
human_template = "{query}"
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
### Contextualize question ###
|
||||
contextualize_q_system_prompt = (
|
||||
"Given a chat history and the latest user question "
|
||||
"which might reference context in the chat history, "
|
||||
"formulate a standalone question which can be understood "
|
||||
"without the chat history. Do NOT answer the question, "
|
||||
"just reformulate it if needed and otherwise return it as is."
|
||||
)
|
||||
contextualize_q_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system_template),
|
||||
("human", human_template),
|
||||
("system", contextualize_q_system_prompt),
|
||||
MessagesPlaceholder("chat_history"),
|
||||
("human", "{input}"),
|
||||
]
|
||||
)
|
||||
|
||||
def format_docs(docs):
|
||||
return "\n\n".join(doc.page_content for doc in docs)
|
||||
|
||||
rag_chain = (
|
||||
{
|
||||
"context": lambda x: format_docs(retriever.invoke(x["query"])),
|
||||
"chat_history": lambda x: "\n".join(self.chat_history),
|
||||
"query": lambda x: x["query"],
|
||||
}
|
||||
| prompt
|
||||
| self.ollama.llm
|
||||
history_aware_retriever = create_history_aware_retriever(
|
||||
self.ollama.llm, retriever, contextualize_q_prompt
|
||||
)
|
||||
|
||||
return rag_chain
|
||||
### Answer question ###
|
||||
system_prompt = (
|
||||
"You are an expert at analyzing code and documentation. "
|
||||
"Use the following pieces of context to answer the question at the end."
|
||||
"If you don't know the answer, just say that you don't know, "
|
||||
"don't try to make up an answer"
|
||||
"\n\n"
|
||||
"{context}\n\n"
|
||||
"Answer in a clear and concise manner. "
|
||||
"If you're referring to code, use markdown formatting."
|
||||
)
|
||||
qa_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system_prompt),
|
||||
MessagesPlaceholder("chat_history"),
|
||||
("human", "{input}"),
|
||||
]
|
||||
)
|
||||
question_answer_chain = create_stuff_documents_chain(self.ollama.llm, qa_prompt)
|
||||
|
||||
rag_chain = create_retrieval_chain(
|
||||
history_aware_retriever, question_answer_chain
|
||||
)
|
||||
|
||||
### Statefully manage chat history ###
|
||||
self.store = {}
|
||||
|
||||
def get_session_history(session_id: str) -> BaseChatMessageHistory:
|
||||
if session_id not in self.store:
|
||||
self.store[session_id] = ChatMessageHistory()
|
||||
return self.store[session_id]
|
||||
|
||||
return RunnableWithMessageHistory(
|
||||
rag_chain,
|
||||
get_session_history,
|
||||
input_messages_key="input",
|
||||
history_messages_key="chat_history",
|
||||
output_messages_key="answer",
|
||||
)
|
||||
|
||||
def query_rag(self, rag_chain, query, stream=False):
|
||||
"""
|
||||
@ -223,20 +254,26 @@ Answer in a clear and concise manner. If you're referring to code, use markdown
|
||||
stream: If True, stream the response
|
||||
"""
|
||||
if stream:
|
||||
response = rag_chain.stream({"query": query})
|
||||
response = rag_chain.stream(
|
||||
{"input": query},
|
||||
config={
|
||||
"configurable": {"session_id": self.session_id}
|
||||
},
|
||||
)
|
||||
# Store in chat history after getting full response
|
||||
full_response = ""
|
||||
for chunk in response:
|
||||
chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
|
||||
full_response += chunk_text
|
||||
yield chunk_text
|
||||
self.chat_history.append(f"Human: {query}")
|
||||
self.chat_history.append(f"Assistant: {full_response}")
|
||||
else:
|
||||
response = rag_chain({"query": query})
|
||||
response = rag_chain.invoke(
|
||||
{"input": query},
|
||||
config={
|
||||
"configurable": {"session_id": self.session_id}
|
||||
},
|
||||
)
|
||||
result = response.content if hasattr(response, "content") else str(response)
|
||||
self.chat_history.append(f"Human: {query}")
|
||||
self.chat_history.append(f"Assistant: {result}")
|
||||
return result
|
||||
|
||||
|
||||
|
@ -51,16 +51,16 @@ def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipe
|
||||
def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
|
||||
"""Test creating a vector database"""
|
||||
rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
|
||||
|
||||
|
||||
# Create files with different extensions
|
||||
files = {
|
||||
"test.py": "def hello():\n print('Hello, World!')",
|
||||
"main.cpp": "#include <iostream>\nint main() { std::cout << 'Hello'; return 0; }",
|
||||
"lib.rs": "fn main() { println!('Hello from Rust!'); }",
|
||||
"config.toml": "[package]\nname = 'test'",
|
||||
"doc.md": "# Documentation\nThis is a test file."
|
||||
"doc.md": "# Documentation\nThis is a test file.",
|
||||
}
|
||||
|
||||
|
||||
for filename, content in files.items():
|
||||
filepath = os.path.join(docs_dir, filename)
|
||||
with open(filepath, "w") as f:
|
||||
@ -72,14 +72,15 @@ def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs):
|
||||
# Verify it was created
|
||||
assert os.path.exists(rag_pipeline.db_dir)
|
||||
assert vectorstore is not None
|
||||
|
||||
|
||||
# Check the database has content from all file types
|
||||
loaded_db = Chroma(
|
||||
persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
|
||||
persist_directory=rag_pipeline.db_dir,
|
||||
embedding_function=rag_pipeline.ollama.embeddings,
|
||||
)
|
||||
# Should have content from all files
|
||||
assert loaded_db._collection.count() > 0
|
||||
|
||||
|
||||
# Verify each file type is included
|
||||
docs = loaded_db._collection.get()
|
||||
sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
|
||||
@ -94,11 +95,12 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
|
||||
"""Test updating a vector database with document changes"""
|
||||
rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
|
||||
# Create initial vector database with only Python files
|
||||
vectorstore = rag_pipeline.create_vector_db(extensions=['.py'], force_refresh=True)
|
||||
vectorstore = rag_pipeline.create_vector_db(extensions=[".py"], force_refresh=True)
|
||||
|
||||
# Get initial count
|
||||
initial_db = Chroma(
|
||||
persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
|
||||
persist_directory=rag_pipeline.db_dir,
|
||||
embedding_function=rag_pipeline.ollama.embeddings,
|
||||
)
|
||||
initial_count = initial_db._collection.count()
|
||||
|
||||
@ -107,9 +109,9 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
|
||||
new_files = {
|
||||
"newdoc.cpp": "#include <iostream>\nint main() { return 0; }",
|
||||
"lib.rs": "fn main() { println!('Hello'); }",
|
||||
"config.toml": "[package]\nname = 'test'"
|
||||
"config.toml": "[package]\nname = 'test'",
|
||||
}
|
||||
|
||||
|
||||
for filename, content in new_files.items():
|
||||
filepath = os.path.join(docs_dir, filename)
|
||||
with open(filepath, "w") as f:
|
||||
@ -120,10 +122,11 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
|
||||
|
||||
# Check the database has been updated
|
||||
updated_db = Chroma(
|
||||
persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings
|
||||
persist_directory=rag_pipeline.db_dir,
|
||||
embedding_function=rag_pipeline.ollama.embeddings,
|
||||
)
|
||||
assert updated_db._collection.count() > initial_count
|
||||
|
||||
|
||||
# Verify new files are included
|
||||
docs = updated_db._collection.get()
|
||||
sources = {os.path.basename(m["source"]) for m in docs["metadatas"]}
|
||||
@ -137,16 +140,16 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do
|
||||
def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs):
|
||||
"""Test the entire RAG pipeline from document processing to querying"""
|
||||
rag_pipeline = RAG(docs_dir, db_dir, tracker_file)
|
||||
|
||||
|
||||
# Create documents with mixed content types
|
||||
test_files = {
|
||||
"python_info.py": """# Python Information
|
||||
def describe_python():
|
||||
\"\"\"Python is a high-level programming language known for its readability and versatility.\"\"\"
|
||||
pass""",
|
||||
"readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation."
|
||||
"readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation.",
|
||||
}
|
||||
|
||||
|
||||
for filename, content in test_files.items():
|
||||
filepath = os.path.join(rag_pipeline.docs_dir, filename)
|
||||
with open(filepath, "w") as f:
|
||||
@ -160,7 +163,10 @@ def describe_python():
|
||||
|
||||
# Query the system
|
||||
query = "What is Python?"
|
||||
response = rag_pipeline.query_rag(rag_chain, query)
|
||||
try:
|
||||
response = next(rag_pipeline.query_rag(rag_chain, query))
|
||||
except StopIteration as ex:
|
||||
response = ex.value
|
||||
|
||||
# Check if response contains relevant information
|
||||
# This is a soft test since the exact response will depend on the LLM
|
||||
|
Loading…
x
Reference in New Issue
Block a user