From 121e815f5265de605802936a8e57c1feb6781d71 Mon Sep 17 00:00:00 2001
From: Alex Selimov
Date: Sun, 23 Mar 2025 00:23:35 -0400
Subject: [PATCH] Fix outputs to only include llm model answers

---
 src/code_rag/cli.py | 57 ++++++++++++++++++++++++---------------------
 src/code_rag/rag.py | 29 +++++++++++++----------
 2 files changed, 48 insertions(+), 38 deletions(-)

diff --git a/src/code_rag/cli.py b/src/code_rag/cli.py
index 1355c06..4ec80ae 100644
--- a/src/code_rag/cli.py
+++ b/src/code_rag/cli.py
@@ -14,27 +14,32 @@ def stream_output(response_iter):
     except KeyboardInterrupt:
         print("\nStreaming interrupted by user")
 
+
 def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
     """Run an interactive chat session"""
-    print("\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n")
-
+    print(
+        "\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n"
+    )
+
     try:
         while True:
             try:
                 query = input("\nQuestion: ").strip()
-
-                if query.lower() in ['exit', 'quit', '']:
+
+                if query.lower() in ["exit", "quit", ""]:
                     print("\nEnding chat session.")
                     break
-
+
                 print("\nResponse:")
                 if no_stream:
                     response = rag_pipeline.query_rag(rag_chain, query, stream=False)
                     print(response)
                 else:
-                    response_iter = rag_pipeline.query_rag(rag_chain, query, stream=True)
+                    response_iter = rag_pipeline.query_rag(
+                        rag_chain, query, stream=True
+                    )
                     stream_output(response_iter)
-
+
             except KeyboardInterrupt:
                 print("\n\nEnding chat session.")
                 break
@@ -43,66 +48,63 @@ def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
     except KeyboardInterrupt:
         print("\nChat session interrupted.")
 
+
 def main():
     parser = argparse.ArgumentParser(
         description="Code RAG - Query your codebase using natural language"
     )
     parser.add_argument(
-        "docs_dir",
-        help="Directory containing the documents to process",
-        type=str
+        "docs_dir", help="Directory containing the documents to process", type=str
     )
     parser.add_argument(
         "query",
         help="Initial query about your codebase (optional in interactive mode)",
         nargs="?",
         default=None,
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--db-dir",
         help="Directory to store the vector database (default: .code_rag_db)",
         default=os.path.expanduser("~/.code_rag_db"),
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--tracker-file",
         help="File to track document changes (default: .code_rag_tracker.json)",
         default=os.path.expanduser("~/.code_rag_tracker.json"),
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--force-refresh",
         help="Force refresh of the vector database",
-        action="store_true"
+        action="store_true",
     )
     parser.add_argument(
         "--ollama-url",
         help="URL for the Ollama server (default: 127.0.0.1)",
         default="127.0.0.1",
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--embedding-model",
         help="Model to use for embeddings (default: nomic-embed-text)",
         default="nomic-embed-text",
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--llm-model",
         help="Model to use for text generation (default: llama3.2)",
         default="llama3.2",
-        type=str
+        type=str,
     )
     parser.add_argument(
-        "--no-stream",
-        help="Disable streaming output",
-        action="store_true"
+        "--no-stream", help="Disable streaming output", action="store_true"
     )
     parser.add_argument(
         "--no-interactive",
         help="Run in non-interactive mode (answer single query and exit)",
-        action="store_true"
+        action="store_true",
     )
 
     args = parser.parse_args()
@@ -117,7 +119,7 @@ def main():
         tracker_file=args.tracker_file,
         ollama_url=args.ollama_url,
         embedding_model=args.embedding_model,
-        llm_model=args.llm_model
+        llm_model=args.llm_model,
     )
 
     # Create or update vector database
@@ -136,7 +138,7 @@ def main():
         print("\nResponse:")
         if args.no_stream:
             response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
-            print(response)
+            stream_output(response)
         else:
             response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
             stream_output(response_iter)
@@ -148,13 +150,16 @@ def main():
         print("\nResponse:")
         if args.no_stream:
             response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
-            print(response)
+            stream_output(response)
         else:
-            response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
+            response_iter = rag_pipeline.query_rag(
+                rag_chain, args.query, stream=True
+            )
             stream_output(response_iter)
-
+
     # Start interactive chat
     interactive_chat(rag_pipeline, rag_chain, args.no_stream)
 
+
 if __name__ == "__main__":
     main()
diff --git a/src/code_rag/rag.py b/src/code_rag/rag.py
index 08c0ae2..9d68670 100644
--- a/src/code_rag/rag.py
+++ b/src/code_rag/rag.py
@@ -94,9 +94,7 @@ class RAG:
         )
 
         # Create embeddings
-        print("Before embedding")
         embeddings = self.ollama.embeddings
-        print("after embedding")
 
         # Load or create vector store
         if os.path.exists(self.db_dir) and not force_refresh:
@@ -256,24 +254,31 @@ class RAG:
         if stream:
             response = rag_chain.stream(
                 {"input": query},
-                config={
-                    "configurable": {"session_id": self.session_id}
-                },
+                config={"configurable": {"session_id": self.session_id}},
             )
             # Store in chat history after getting full response
             full_response = ""
             for chunk in response:
-                chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
-                full_response += chunk_text
-                yield chunk_text
+                # Extract only the LLM answer chunk
+                if "answer" in chunk:
+                    chunk_text = (
+                        chunk["answer"].content
+                        if hasattr(chunk["answer"], "content")
+                        else str(chunk["answer"])
+                    )
+                    full_response += chunk_text
+                    yield chunk_text
         else:
             response = rag_chain.invoke(
                 {"input": query},
-                config={
-                    "configurable": {"session_id": self.session_id}
-                },
+                config={"configurable": {"session_id": self.session_id}},
+            )
+            # Extract only the LLM answer from the response
+            result = (
+                response["answer"].content
+                if hasattr(response["answer"], "content")
+                else str(response["answer"])
             )
-            result = response.content if hasattr(response, "content") else str(response)
             return result