Fix outputs to only include llm model answers
All checks were successful
Pytest / Explore-Gitea-Actions (push) Successful in 13m41s

Alex Selimov 2025-03-23 00:23:35 -04:00
parent 96801ba8e8
commit 121e815f52
2 changed files with 48 additions and 38 deletions

View File

@@ -14,27 +14,32 @@ def stream_output(response_iter):
     except KeyboardInterrupt:
         print("\nStreaming interrupted by user")
 
 
 def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
     """Run an interactive chat session"""
-    print("\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n")
+    print(
+        "\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n"
+    )
 
     try:
         while True:
             try:
                 query = input("\nQuestion: ").strip()
-                if query.lower() in ['exit', 'quit', '']:
+                if query.lower() in ["exit", "quit", ""]:
                     print("\nEnding chat session.")
                     break
 
                 print("\nResponse:")
                 if no_stream:
                     response = rag_pipeline.query_rag(rag_chain, query, stream=False)
                     print(response)
                 else:
-                    response_iter = rag_pipeline.query_rag(rag_chain, query, stream=True)
+                    response_iter = rag_pipeline.query_rag(
+                        rag_chain, query, stream=True
+                    )
                     stream_output(response_iter)
 
             except KeyboardInterrupt:
                 print("\n\nEnding chat session.")
                 break
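Only the tail of stream_output is visible in the hunk above. For orientation, here is a minimal sketch of what such a helper typically does, assuming the iterator yields plain text chunks; the body below is illustrative, not the repository's actual implementation:

def stream_output(response_iter):
    """Print streamed chunks in place as they arrive (illustrative sketch)."""
    try:
        for chunk in response_iter:
            # Write each chunk without a trailing newline so the answer builds up inline
            print(chunk, end="", flush=True)
        print()  # finish with a newline once the stream is exhausted
    except KeyboardInterrupt:
        print("\nStreaming interrupted by user")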
@@ -43,66 +48,63 @@ def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
     except KeyboardInterrupt:
         print("\nChat session interrupted.")
 
 
 def main():
     parser = argparse.ArgumentParser(
         description="Code RAG - Query your codebase using natural language"
     )
     parser.add_argument(
-        "docs_dir",
-        help="Directory containing the documents to process",
-        type=str
+        "docs_dir", help="Directory containing the documents to process", type=str
     )
     parser.add_argument(
         "query",
         help="Initial query about your codebase (optional in interactive mode)",
         nargs="?",
         default=None,
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--db-dir",
         help="Directory to store the vector database (default: .code_rag_db)",
         default=os.path.expanduser("~/.code_rag_db"),
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--tracker-file",
         help="File to track document changes (default: .code_rag_tracker.json)",
         default=os.path.expanduser("~/.code_rag_tracker.json"),
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--force-refresh",
         help="Force refresh of the vector database",
-        action="store_true"
+        action="store_true",
     )
     parser.add_argument(
         "--ollama-url",
         help="URL for the Ollama server (default: 127.0.0.1)",
         default="127.0.0.1",
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--embedding-model",
         help="Model to use for embeddings (default: nomic-embed-text)",
         default="nomic-embed-text",
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--llm-model",
         help="Model to use for text generation (default: llama3.2)",
         default="llama3.2",
-        type=str
+        type=str,
     )
     parser.add_argument(
-        "--no-stream",
-        help="Disable streaming output",
-        action="store_true"
+        "--no-stream", help="Disable streaming output", action="store_true"
     )
     parser.add_argument(
         "--no-interactive",
         help="Run in non-interactive mode (answer single query and exit)",
-        action="store_true"
+        action="store_true",
     )
     args = parser.parse_args()
@@ -117,7 +119,7 @@ def main():
         tracker_file=args.tracker_file,
         ollama_url=args.ollama_url,
         embedding_model=args.embedding_model,
-        llm_model=args.llm_model
+        llm_model=args.llm_model,
     )
 
     # Create or update vector database
@@ -136,7 +138,7 @@ def main():
         print("\nResponse:")
         if args.no_stream:
             response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
-            print(response)
+            stream_output(response)
         else:
             response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
             stream_output(response_iter)
@@ -148,13 +150,16 @@ def main():
             print("\nResponse:")
             if args.no_stream:
                 response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
-                print(response)
+                stream_output(response)
             else:
-                response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
+                response_iter = rag_pipeline.query_rag(
+                    rag_chain, args.query, stream=True
+                )
                 stream_output(response_iter)
 
         # Start interactive chat
         interactive_chat(rag_pipeline, rag_chain, args.no_stream)
 
 
 if __name__ == "__main__":
     main()

View File

@@ -94,9 +94,7 @@ class RAG:
         )
 
         # Create embeddings
-        print("Before embedding")
         embeddings = self.ollama.embeddings
-        print("after embedding")
 
         # Load or create vector store
         if os.path.exists(self.db_dir) and not force_refresh:
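This hunk only drops two debug print calls around the embeddings lookup. For context, the load-or-create pattern visible here usually looks roughly like the sketch below; the langchain_chroma and langchain_ollama imports and keyword arguments are assumptions, since they are not shown in the diff:

import os

from langchain_chroma import Chroma
from langchain_ollama import OllamaEmbeddings


def load_or_create_store(db_dir, documents, force_refresh=False):
    """Reuse a persisted Chroma collection when possible, otherwise rebuild it."""
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    if os.path.exists(db_dir) and not force_refresh:
        # Reopen the existing collection instead of re-embedding every document
        return Chroma(persist_directory=db_dir, embedding_function=embeddings)
    # Embed the documents and persist them for subsequent runs
    return Chroma.from_documents(
        documents=documents, embedding=embeddings, persist_directory=db_dir
    )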
@@ -256,24 +254,31 @@ class RAG:
         if stream:
             response = rag_chain.stream(
                 {"input": query},
-                config={
-                    "configurable": {"session_id": self.session_id}
-                },
+                config={"configurable": {"session_id": self.session_id}},
             )
 
             # Store in chat history after getting full response
             full_response = ""
             for chunk in response:
-                chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
-                full_response += chunk_text
-                yield chunk_text
+                # Extract only the LLM answer chunk
+                if "answer" in chunk:
+                    chunk_text = (
+                        chunk["answer"].content
+                        if hasattr(chunk["answer"], "content")
+                        else str(chunk["answer"])
+                    )
+                    full_response += chunk_text
+                    yield chunk_text
         else:
             response = rag_chain.invoke(
                 {"input": query},
-                config={
-                    "configurable": {"session_id": self.session_id}
-                },
+                config={"configurable": {"session_id": self.session_id}},
             )
 
-            result = response.content if hasattr(response, "content") else str(response)
+            # Extract only the LLM answer from the response
+            result = (
+                response["answer"].content
+                if hasattr(response["answer"], "content")
+                else str(response["answer"])
+            )
             return result
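This hunk is the substance of the commit: the retrieval chain emits dictionary chunks (typically carrying input, context, and answer keys) rather than bare message objects, so both branches now surface only the answer field. Below is a self-contained sketch of the same filtering pattern, using a hypothetical chunk list in place of rag_chain.stream:

def extract_answer_chunks(chunks):
    """Yield only the LLM answer text from retrieval-chain style chunks (illustrative)."""
    for chunk in chunks:
        if "answer" in chunk:
            answer = chunk["answer"]
            # Message-like objects expose .content; plain strings pass through unchanged
            yield answer.content if hasattr(answer, "content") else str(answer)


# Hypothetical stand-in for what rag_chain.stream({"input": ...}) might emit
fake_stream = [
    {"input": "What does stream_output do?"},
    {"context": ["...retrieved documents..."]},
    {"answer": "It prints "},
    {"answer": "chunks as they arrive."},
]

print("".join(extract_answer_chunks(fake_stream)))  # -> It prints chunks as they arrive.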