From dda45e715594868eb66433400de80b60a8af508c Mon Sep 17 00:00:00 2001 From: Alex Selimov Date: Sat, 22 Mar 2025 22:16:14 -0400 Subject: [PATCH 1/3] Add a cli, update RAG to keep chat history and stream output --- pyproject.toml | 3 + src/code_rag/cli.py | 160 ++++++++++++++++++++++++++++++++++++++++++++ src/code_rag/rag.py | 130 +++++++++++++++++++---------------- 3 files changed, 235 insertions(+), 58 deletions(-) create mode 100644 src/code_rag/cli.py diff --git a/pyproject.toml b/pyproject.toml index f81eb07..59f79db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,9 @@ ollama = "*" [tool.poetry.group.dev.dependencies] pytest = "^8.3.5" +[tool.poetry.scripts] +code-rag = "code_rag.cli:main" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/src/code_rag/cli.py b/src/code_rag/cli.py new file mode 100644 index 0000000..1355c06 --- /dev/null +++ b/src/code_rag/cli.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 + +import os +import sys +import argparse +from code_rag.rag import RAG + +def stream_output(response_iter): + """Stream the response to console, handling each token""" + try: + for chunk in response_iter: + print(chunk, end="", flush=True) + print() # New line at end + except KeyboardInterrupt: + print("\nStreaming interrupted by user") + +def interactive_chat(rag_pipeline, rag_chain, no_stream=False): + """Run an interactive chat session""" + print("\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n") + + try: + while True: + try: + query = input("\nQuestion: ").strip() + + if query.lower() in ['exit', 'quit', '']: + print("\nEnding chat session.") + break + + print("\nResponse:") + if no_stream: + response = rag_pipeline.query_rag(rag_chain, query, stream=False) + print(response) + else: + response_iter = rag_pipeline.query_rag(rag_chain, query, stream=True) + stream_output(response_iter) + + except KeyboardInterrupt: + print("\n\nEnding chat session.") + break + except EOFError: + break + except KeyboardInterrupt: + print("\nChat session interrupted.") + +def main(): + parser = argparse.ArgumentParser( + description="Code RAG - Query your codebase using natural language" + ) + parser.add_argument( + "docs_dir", + help="Directory containing the documents to process", + type=str + ) + parser.add_argument( + "query", + help="Initial query about your codebase (optional in interactive mode)", + nargs="?", + default=None, + type=str + ) + parser.add_argument( + "--db-dir", + help="Directory to store the vector database (default: .code_rag_db)", + default=os.path.expanduser("~/.code_rag_db"), + type=str + ) + parser.add_argument( + "--tracker-file", + help="File to track document changes (default: .code_rag_tracker.json)", + default=os.path.expanduser("~/.code_rag_tracker.json"), + type=str + ) + parser.add_argument( + "--force-refresh", + help="Force refresh of the vector database", + action="store_true" + ) + parser.add_argument( + "--ollama-url", + help="URL for the Ollama server (default: 127.0.0.1)", + default="127.0.0.1", + type=str + ) + parser.add_argument( + "--embedding-model", + help="Model to use for embeddings (default: nomic-embed-text)", + default="nomic-embed-text", + type=str + ) + parser.add_argument( + "--llm-model", + help="Model to use for text generation (default: llama3.2)", + default="llama3.2", + type=str + ) + parser.add_argument( + "--no-stream", + help="Disable streaming output", + action="store_true" + ) + parser.add_argument( + "--no-interactive", 
+ help="Run in non-interactive mode (answer single query and exit)", + action="store_true" + ) + + args = parser.parse_args() + + if args.no_interactive and not args.query: + parser.error("Query is required in non-interactive mode") + + # Create RAG pipeline + rag_pipeline = RAG( + docs_dir=args.docs_dir, + db_dir=args.db_dir, + tracker_file=args.tracker_file, + ollama_url=args.ollama_url, + embedding_model=args.embedding_model, + llm_model=args.llm_model + ) + + # Create or update vector database + print(f"Processing documents in {args.docs_dir}...") + print(f"Using models: embedding={args.embedding_model}, llm={args.llm_model}") + print(f"Ollama server: {args.ollama_url}") + vectorstore = rag_pipeline.create_vector_db(force_refresh=args.force_refresh) + + # Set up RAG chain + print("Setting up RAG pipeline...") + rag_chain = rag_pipeline.setup_rag() + + if args.no_interactive: + # Single query mode + print(f"\nQuery: {args.query}") + print("\nResponse:") + if args.no_stream: + response = rag_pipeline.query_rag(rag_chain, args.query, stream=False) + print(response) + else: + response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True) + stream_output(response_iter) + else: + # Interactive mode + if args.query: + # Handle initial query if provided + print(f"\nQuery: {args.query}") + print("\nResponse:") + if args.no_stream: + response = rag_pipeline.query_rag(rag_chain, args.query, stream=False) + print(response) + else: + response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True) + stream_output(response_iter) + + # Start interactive chat + interactive_chat(rag_pipeline, rag_chain, args.no_stream) + +if __name__ == "__main__": + main() diff --git a/src/code_rag/rag.py b/src/code_rag/rag.py index e1fd5f8..a4043a7 100644 --- a/src/code_rag/rag.py +++ b/src/code_rag/rag.py @@ -1,15 +1,17 @@ import os -import uuid import glob -from langchain_community.document_loaders import DirectoryLoader, TextLoader +import uuid +from typing import List, Optional +from langchain.schema import Document from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.document_loaders import TextLoader +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.runnables import RunnablePassthrough +from langchain_core.messages import HumanMessage, SystemMessage from langchain_chroma import Chroma -from langchain.chains import RetrievalQA -from langchain.prompts import PromptTemplate -from langchain_core.documents import Document -from code_rag.doc_tracker import DocumentTracker -from code_rag.ollama_wrapper import OllamaWrapper +from .doc_tracker import DocumentTracker +from .ollama_wrapper import OllamaWrapper class RAG: @@ -161,55 +163,81 @@ class RAG: return vectorstore def setup_rag(self): - """ - Set up the RAG system with an existing vector database - """ - # Load the embeddings - embeddings = self.ollama.embeddings - - # Load the vector store + """Set up the RAG pipeline""" + # Create vector store vectorstore = Chroma( - persist_directory=self.db_dir, embedding_function=embeddings + persist_directory=self.db_dir, embedding_function=self.ollama.embeddings ) - # Create a retriever - retriever = vectorstore.as_retriever(search_kwargs={"k": 4}) - - # Set up the LLM - llm = self.ollama.llm - - # Create a custom prompt template - template = """ - Answer the question based on the context provided. If you don't know the answer, - just say you don't know. Don't try to make up an answer. 
- - Context: {context} - - Question: {question} - - Answer: - """ - - prompt = PromptTemplate( - input_variables=["context", "question"], template=template + # Create retriever + retriever = vectorstore.as_retriever( + search_type="similarity", search_kwargs={"k": 4} ) - # Create the RAG chain - rag_chain = RetrievalQA.from_chain_type( - llm=llm, - chain_type="stuff", - retriever=retriever, - chain_type_kwargs={"prompt": prompt}, + # Create chat history buffer + self.chat_history = [] + + # Create RAG chain with chat history + system_template = """You are an expert at analyzing code and documentation. Use the following pieces of context to answer the question at the end. +If you don't know the answer, just say that you don't know, don't try to make up an answer. + +Context: +{context} + +Chat History: +{chat_history} + +Answer in a clear and concise manner. If you're referring to code, use markdown formatting.""" + + human_template = "{query}" + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", system_template), + ("human", human_template), + ] + ) + + def format_docs(docs): + return "\n\n".join(doc.page_content for doc in docs) + + rag_chain = ( + { + "context": lambda x: format_docs(retriever.invoke(x["query"])), + "chat_history": lambda x: "\n".join(self.chat_history), + "query": lambda x: x["query"], + } + | prompt + | self.ollama.llm ) return rag_chain - def query_rag(self, rag_chain, query): + def query_rag(self, rag_chain, query, stream=False): """ Query the RAG system + + Args: + rag_chain: The RAG chain to use + query: Query string + stream: If True, stream the response """ - response = rag_chain.invoke({"query": query}) - return response["result"] + if stream: + response = rag_chain.stream({"query": query}) + # Store in chat history after getting full response + full_response = "" + for chunk in response: + chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk) + full_response += chunk_text + yield chunk_text + self.chat_history.append(f"Human: {query}") + self.chat_history.append(f"Assistant: {full_response}") + else: + response = rag_chain({"query": query}) + result = response.content if hasattr(response, "content") else str(response) + self.chat_history.append(f"Human: {query}") + self.chat_history.append(f"Assistant: {result}") + return result def default_file_extensions(): @@ -217,11 +245,6 @@ def default_file_extensions(): return [ # Python ".py", - ".pyi", - ".pyx", - ".pyc", - ".pyd", - ".pyw", # C/C++ ".c", ".cpp", @@ -235,14 +258,8 @@ def default_file_extensions(): ".ipp", # Rust ".rs", - ".rlib", - ".rmeta", # Java ".java", - ".jsp", - ".jav", - ".jar", - ".class", ".kt", ".kts", ".groovy", @@ -250,9 +267,6 @@ def default_file_extensions(): ".html", ".htm", ".css", - ".scss", - ".sass", - ".less", ".js", ".jsx", ".ts", From 96801ba8e8425706ddf9a9c7d9d37d644919fc9f Mon Sep 17 00:00:00 2001 From: Alex Selimov Date: Sun, 23 Mar 2025 00:08:17 -0400 Subject: [PATCH 2/3] Fix test and properly add conversational rag --- src/code_rag/rag.py | 115 +++++++++++++++++++++++++++++--------------- tests/test_rag.py | 38 +++++++++------ 2 files changed, 98 insertions(+), 55 deletions(-) diff --git a/src/code_rag/rag.py b/src/code_rag/rag.py index a4043a7..08c0ae2 100644 --- a/src/code_rag/rag.py +++ b/src/code_rag/rag.py @@ -1,14 +1,21 @@ import os import glob import uuid -from typing import List, Optional from langchain.schema import Document from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders
import TextLoader from langchain_core.prompts import ChatPromptTemplate -from langchain_core.runnables import RunnablePassthrough -from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain.chains import create_history_aware_retriever, create_retrieval_chain +from langchain.chains.combine_documents import create_stuff_documents_chain +from langchain_community.chat_message_histories import ChatMessageHistory +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_chroma import Chroma +import time from .doc_tracker import DocumentTracker from .ollama_wrapper import OllamaWrapper @@ -30,6 +37,7 @@ class RAG: self.ollama = OllamaWrapper( ollama_url, embedding_model=embedding_model, llm_model=llm_model ) + self.session_id = time.time() def process_documents(self, files, text_splitter): """Process document files into chunks with tracking metadata""" @@ -177,41 +185,64 @@ class RAG: # Create chat history buffer self.chat_history = [] - # Create RAG chain with chat history - system_template = """You are an expert at analyzing code and documentation. Use the following pieces of context to answer the question at the end. -If you don't know the answer, just say that you don't know, don't try to make up an answer. - -Context: -{context} - -Chat History: -{chat_history} - -Answer in a clear and concise manner. If you're referring to code, use markdown formatting.""" - - human_template = "{query}" - - prompt = ChatPromptTemplate.from_messages( + ### Contextualize question ### + contextualize_q_system_prompt = ( + "Given a chat history and the latest user question " + "which might reference context in the chat history, " + "formulate a standalone question which can be understood " + "without the chat history. Do NOT answer the question, " + "just reformulate it if needed and otherwise return it as is." + ) + contextualize_q_prompt = ChatPromptTemplate.from_messages( [ - ("system", system_template), - ("human", human_template), + ("system", contextualize_q_system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), ] ) - - def format_docs(docs): - return "\n\n".join(doc.page_content for doc in docs) - - rag_chain = ( - { - "context": lambda x: format_docs(retriever.invoke(x["query"])), - "chat_history": lambda x: "\n".join(self.chat_history), - "query": lambda x: x["query"], - } - | prompt - | self.ollama.llm + history_aware_retriever = create_history_aware_retriever( + self.ollama.llm, retriever, contextualize_q_prompt ) - return rag_chain + ### Answer question ### + system_prompt = ( + "You are an expert at analyzing code and documentation. " + "Use the following pieces of context to answer the question at the end." + "If you don't know the answer, just say that you don't know, " + "don't try to make up an answer" + "\n\n" + "{context}\n\n" + "Answer in a clear and concise manner. " + "If you're referring to code, use markdown formatting." 
+ ) + qa_prompt = ChatPromptTemplate.from_messages( + [ + ("system", system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] + ) + question_answer_chain = create_stuff_documents_chain(self.ollama.llm, qa_prompt) + + rag_chain = create_retrieval_chain( + history_aware_retriever, question_answer_chain + ) + + ### Statefully manage chat history ### + self.store = {} + + def get_session_history(session_id: str) -> BaseChatMessageHistory: + if session_id not in self.store: + self.store[session_id] = ChatMessageHistory() + return self.store[session_id] + + return RunnableWithMessageHistory( + rag_chain, + get_session_history, + input_messages_key="input", + history_messages_key="chat_history", + output_messages_key="answer", + ) def query_rag(self, rag_chain, query, stream=False): """ @@ -223,20 +254,26 @@ Answer in a clear and concise manner. If you're referring to code, use markdown stream: If True, stream the response """ if stream: - response = rag_chain.stream({"query": query}) + response = rag_chain.stream( + {"input": query}, + config={ + "configurable": {"session_id": self.session_id} + }, + ) # Store in chat history after getting full response full_response = "" for chunk in response: chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk) full_response += chunk_text yield chunk_text - self.chat_history.append(f"Human: {query}") - self.chat_history.append(f"Assistant: {full_response}") else: - response = rag_chain({"query": query}) + response = rag_chain.invoke( + {"input": query}, + config={ + "configurable": {"session_id": self.session_id} + }, + ) result = response.content if hasattr(response, "content") else str(response) - self.chat_history.append(f"Human: {query}") - self.chat_history.append(f"Assistant: {result}") return result diff --git a/tests/test_rag.py b/tests/test_rag.py index c0bef74..e7547af 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -51,16 +51,16 @@ def test_process_documents(tracker_file, docs_dir, db_dir, sample_docs, rag_pipe def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs): """Test creating a vector database""" rag_pipeline = RAG(docs_dir, db_dir, tracker_file) - + # Create files with different extensions files = { "test.py": "def hello():\n print('Hello, World!')", "main.cpp": "#include \nint main() { std::cout << 'Hello'; return 0; }", "lib.rs": "fn main() { println!('Hello from Rust!'); }", "config.toml": "[package]\nname = 'test'", - "doc.md": "# Documentation\nThis is a test file." 
+ "doc.md": "# Documentation\nThis is a test file.", } - + for filename, content in files.items(): filepath = os.path.join(docs_dir, filename) with open(filepath, "w") as f: @@ -72,14 +72,15 @@ def test_create_vector_db(docs_dir, db_dir, tracker_file, sample_docs): # Verify it was created assert os.path.exists(rag_pipeline.db_dir) assert vectorstore is not None - + # Check the database has content from all file types loaded_db = Chroma( - persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings + persist_directory=rag_pipeline.db_dir, + embedding_function=rag_pipeline.ollama.embeddings, ) # Should have content from all files assert loaded_db._collection.count() > 0 - + # Verify each file type is included docs = loaded_db._collection.get() sources = {os.path.basename(m["source"]) for m in docs["metadatas"]} @@ -94,11 +95,12 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do """Test updating a vector database with document changes""" rag_pipeline = RAG(docs_dir, db_dir, tracker_file) # Create initial vector database with only Python files - vectorstore = rag_pipeline.create_vector_db(extensions=['.py'], force_refresh=True) + vectorstore = rag_pipeline.create_vector_db(extensions=[".py"], force_refresh=True) # Get initial count initial_db = Chroma( - persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings + persist_directory=rag_pipeline.db_dir, + embedding_function=rag_pipeline.ollama.embeddings, ) initial_count = initial_db._collection.count() @@ -107,9 +109,9 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do new_files = { "newdoc.cpp": "#include \nint main() { return 0; }", "lib.rs": "fn main() { println!('Hello'); }", - "config.toml": "[package]\nname = 'test'" + "config.toml": "[package]\nname = 'test'", } - + for filename, content in new_files.items(): filepath = os.path.join(docs_dir, filename) with open(filepath, "w") as f: @@ -120,10 +122,11 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do # Check the database has been updated updated_db = Chroma( - persist_directory=rag_pipeline.db_dir, embedding_function=rag_pipeline.ollama.embeddings + persist_directory=rag_pipeline.db_dir, + embedding_function=rag_pipeline.ollama.embeddings, ) assert updated_db._collection.count() > initial_count - + # Verify new files are included docs = updated_db._collection.get() sources = {os.path.basename(m["source"]) for m in docs["metadatas"]} @@ -137,16 +140,16 @@ def test_update_vector_db_with_changes(docs_dir, db_dir, tracker_file, sample_do def test_full_rag_pipeline(docs_dir, db_dir, tracker_file, sample_docs): """Test the entire RAG pipeline from document processing to querying""" rag_pipeline = RAG(docs_dir, db_dir, tracker_file) - + # Create documents with mixed content types test_files = { "python_info.py": """# Python Information def describe_python(): \"\"\"Python is a high-level programming language known for its readability and versatility.\"\"\" pass""", - "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation." + "readme.md": "# Python\nPython is a popular programming language used in web development, data science, and automation.", } - + for filename, content in test_files.items(): filepath = os.path.join(rag_pipeline.docs_dir, filename) with open(filepath, "w") as f: @@ -160,7 +163,10 @@ def describe_python(): # Query the system query = "What is Python?" 
- response = rag_pipeline.query_rag(rag_chain, query) + try: + response = next(rag_pipeline.query_rag(rag_chain, query)) + except StopIteration as ex: + response = ex.value # Check if response contains relevant information # This is a soft test since the exact response will depend on the LLM From 121e815f5265de605802936a8e57c1feb6781d71 Mon Sep 17 00:00:00 2001 From: Alex Selimov Date: Sun, 23 Mar 2025 00:23:35 -0400 Subject: [PATCH 3/3] Fix outputs to only include llm model answers --- src/code_rag/cli.py | 57 ++++++++++++++++++++++++--------------------- src/code_rag/rag.py | 29 +++++++++++++---------- 2 files changed, 48 insertions(+), 38 deletions(-) diff --git a/src/code_rag/cli.py b/src/code_rag/cli.py index 1355c06..4ec80ae 100644 --- a/src/code_rag/cli.py +++ b/src/code_rag/cli.py @@ -14,27 +14,32 @@ def stream_output(response_iter): except KeyboardInterrupt: print("\nStreaming interrupted by user") + def interactive_chat(rag_pipeline, rag_chain, no_stream=False): """Run an interactive chat session""" - print("\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n") - + print( + "\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n" + ) + try: while True: try: query = input("\nQuestion: ").strip() - - if query.lower() in ['exit', 'quit', '']: + + if query.lower() in ["exit", "quit", ""]: print("\nEnding chat session.") break - + print("\nResponse:") if no_stream: response = rag_pipeline.query_rag(rag_chain, query, stream=False) print(response) else: - response_iter = rag_pipeline.query_rag(rag_chain, query, stream=True) + response_iter = rag_pipeline.query_rag( + rag_chain, query, stream=True + ) stream_output(response_iter) - + except KeyboardInterrupt: print("\n\nEnding chat session.") break @@ -43,66 +48,63 @@ def interactive_chat(rag_pipeline, rag_chain, no_stream=False): except KeyboardInterrupt: print("\nChat session interrupted.") + def main(): parser = argparse.ArgumentParser( description="Code RAG - Query your codebase using natural language" ) parser.add_argument( - "docs_dir", - help="Directory containing the documents to process", - type=str + "docs_dir", help="Directory containing the documents to process", type=str ) parser.add_argument( "query", help="Initial query about your codebase (optional in interactive mode)", nargs="?", default=None, - type=str + type=str, ) parser.add_argument( "--db-dir", help="Directory to store the vector database (default: .code_rag_db)", default=os.path.expanduser("~/.code_rag_db"), - type=str + type=str, ) parser.add_argument( "--tracker-file", help="File to track document changes (default: .code_rag_tracker.json)", default=os.path.expanduser("~/.code_rag_tracker.json"), - type=str + type=str, ) parser.add_argument( "--force-refresh", help="Force refresh of the vector database", - action="store_true" + action="store_true", ) parser.add_argument( "--ollama-url", help="URL for the Ollama server (default: 127.0.0.1)", default="127.0.0.1", - type=str + type=str, ) parser.add_argument( "--embedding-model", help="Model to use for embeddings (default: nomic-embed-text)", default="nomic-embed-text", - type=str + type=str, ) parser.add_argument( "--llm-model", help="Model to use for text generation (default: llama3.2)", default="llama3.2", - type=str + type=str, ) parser.add_argument( - "--no-stream", - help="Disable streaming output", - action="store_true" + "--no-stream", help="Disable streaming output", action="store_true" ) parser.add_argument( 
"--no-interactive", help="Run in non-interactive mode (answer single query and exit)", - action="store_true" + action="store_true", ) args = parser.parse_args() @@ -117,7 +119,7 @@ def main(): tracker_file=args.tracker_file, ollama_url=args.ollama_url, embedding_model=args.embedding_model, - llm_model=args.llm_model + llm_model=args.llm_model, ) # Create or update vector database @@ -136,7 +138,7 @@ def main(): print("\nResponse:") if args.no_stream: response = rag_pipeline.query_rag(rag_chain, args.query, stream=False) - print(response) + stream_output(response) else: response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True) stream_output(response_iter) @@ -148,13 +150,16 @@ def main(): print("\nResponse:") if args.no_stream: response = rag_pipeline.query_rag(rag_chain, args.query, stream=False) - print(response) + stream_output(response) else: - response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True) + response_iter = rag_pipeline.query_rag( + rag_chain, args.query, stream=True + ) stream_output(response_iter) - + # Start interactive chat interactive_chat(rag_pipeline, rag_chain, args.no_stream) + if __name__ == "__main__": main() diff --git a/src/code_rag/rag.py b/src/code_rag/rag.py index 08c0ae2..9d68670 100644 --- a/src/code_rag/rag.py +++ b/src/code_rag/rag.py @@ -94,9 +94,7 @@ class RAG: ) # Create embeddings - print("Before embedding") embeddings = self.ollama.embeddings - print("after embedding") # Load or create vector store if os.path.exists(self.db_dir) and not force_refresh: @@ -256,24 +254,31 @@ class RAG: if stream: response = rag_chain.stream( {"input": query}, - config={ - "configurable": {"session_id": self.session_id} - }, + config={"configurable": {"session_id": self.session_id}}, ) # Store in chat history after getting full response full_response = "" for chunk in response: - chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk) - full_response += chunk_text - yield chunk_text + # Extract only the LLM answer chunk + if "answer" in chunk: + chunk_text = ( + chunk["answer"].content + if hasattr(chunk["answer"], "content") + else str(chunk["answer"]) + ) + full_response += chunk_text + yield chunk_text else: response = rag_chain.invoke( {"input": query}, - config={ - "configurable": {"session_id": self.session_id} - }, + config={"configurable": {"session_id": self.session_id}}, + ) + # Extract only the LLM answer from the response + result = ( + response["answer"].content + if hasattr(response["answer"], "content") + else str(response["answer"]) ) - result = response.content if hasattr(response, "content") else str(response) return result