Fix outputs to only include LLM model answers

Alex Selimov 2025-03-23 00:23:35 -04:00
parent 96801ba8e8
commit 121e815f52
2 changed files with 48 additions and 38 deletions

View File

@@ -14,27 +14,32 @@ def stream_output(response_iter):
     except KeyboardInterrupt:
         print("\nStreaming interrupted by user")
 def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
     """Run an interactive chat session"""
-    print("\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n")
+    print(
+        "\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n"
+    )
     try:
         while True:
             try:
                 query = input("\nQuestion: ").strip()
-                if query.lower() in ['exit', 'quit', '']:
+                if query.lower() in ["exit", "quit", ""]:
                     print("\nEnding chat session.")
                     break
                 print("\nResponse:")
                 if no_stream:
                     response = rag_pipeline.query_rag(rag_chain, query, stream=False)
                     print(response)
                 else:
-                    response_iter = rag_pipeline.query_rag(rag_chain, query, stream=True)
+                    response_iter = rag_pipeline.query_rag(
+                        rag_chain, query, stream=True
+                    )
                     stream_output(response_iter)
             except KeyboardInterrupt:
                 print("\n\nEnding chat session.")
                 break
@@ -43,66 +48,63 @@ def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
     except KeyboardInterrupt:
         print("\nChat session interrupted.")
 def main():
     parser = argparse.ArgumentParser(
         description="Code RAG - Query your codebase using natural language"
     )
     parser.add_argument(
-        "docs_dir",
-        help="Directory containing the documents to process",
-        type=str
+        "docs_dir", help="Directory containing the documents to process", type=str
     )
     parser.add_argument(
         "query",
         help="Initial query about your codebase (optional in interactive mode)",
         nargs="?",
         default=None,
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--db-dir",
         help="Directory to store the vector database (default: .code_rag_db)",
         default=os.path.expanduser("~/.code_rag_db"),
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--tracker-file",
         help="File to track document changes (default: .code_rag_tracker.json)",
         default=os.path.expanduser("~/.code_rag_tracker.json"),
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--force-refresh",
         help="Force refresh of the vector database",
-        action="store_true"
+        action="store_true",
     )
     parser.add_argument(
         "--ollama-url",
         help="URL for the Ollama server (default: 127.0.0.1)",
         default="127.0.0.1",
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--embedding-model",
         help="Model to use for embeddings (default: nomic-embed-text)",
         default="nomic-embed-text",
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--llm-model",
         help="Model to use for text generation (default: llama3.2)",
         default="llama3.2",
-        type=str
+        type=str,
     )
     parser.add_argument(
-        "--no-stream",
-        help="Disable streaming output",
-        action="store_true"
+        "--no-stream", help="Disable streaming output", action="store_true"
     )
     parser.add_argument(
         "--no-interactive",
         help="Run in non-interactive mode (answer single query and exit)",
-        action="store_true"
+        action="store_true",
     )
     args = parser.parse_args()
@@ -117,7 +119,7 @@ def main():
         tracker_file=args.tracker_file,
         ollama_url=args.ollama_url,
         embedding_model=args.embedding_model,
-        llm_model=args.llm_model
+        llm_model=args.llm_model,
     )
     # Create or update vector database
@@ -136,7 +138,7 @@ def main():
         print("\nResponse:")
         if args.no_stream:
             response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
-            print(response)
+            stream_output(response)
         else:
             response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
             stream_output(response_iter)
@@ -148,13 +150,16 @@ def main():
             print("\nResponse:")
             if args.no_stream:
                 response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
-                print(response)
+                stream_output(response)
             else:
-                response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
+                response_iter = rag_pipeline.query_rag(
+                    rag_chain, args.query, stream=True
+                )
                 stream_output(response_iter)
         # Start interactive chat
         interactive_chat(rag_pipeline, rag_chain, args.no_stream)
 if __name__ == "__main__":
     main()
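
Note that the diff above never shows the body of stream_output, only its KeyboardInterrupt handler, yet the commit now routes the non-streaming response through it as well. A minimal sketch of a compatible implementation follows; the exact body is an assumption, not part of this commit, and it only presumes that each chunk is a plain string or an object exposing a .content attribute:

def stream_output(response_iter):
    """Print text pieces as they arrive, flushing so output appears immediately."""
    try:
        for chunk in response_iter:
            # A chunk may be a plain string or a message object with .content.
            text = chunk.content if hasattr(chunk, "content") else str(chunk)
            print(text, end="", flush=True)
        print()  # finish with a newline once the iterator is exhausted
    except KeyboardInterrupt:
        print("\nStreaming interrupted by user")

Under that assumption the same helper also accepts the fully materialized string returned when --no-stream is set: iterating a string yields its characters, which are printed back in order, so swapping print(response) for stream_output(response) keeps the two code paths symmetric.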

View File

@@ -94,9 +94,7 @@ class RAG:
         )
         # Create embeddings
-        print("Before embedding")
         embeddings = self.ollama.embeddings
-        print("after embedding")
         # Load or create vector store
         if os.path.exists(self.db_dir) and not force_refresh:
@@ -256,24 +254,31 @@ class RAG:
         if stream:
             response = rag_chain.stream(
                 {"input": query},
-                config={
-                    "configurable": {"session_id": self.session_id}
-                },
+                config={"configurable": {"session_id": self.session_id}},
             )
             # Store in chat history after getting full response
             full_response = ""
             for chunk in response:
-                chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
-                full_response += chunk_text
-                yield chunk_text
+                # Extract only the LLM answer chunk
+                if "answer" in chunk:
+                    chunk_text = (
+                        chunk["answer"].content
+                        if hasattr(chunk["answer"], "content")
+                        else str(chunk["answer"])
+                    )
+                    full_response += chunk_text
+                    yield chunk_text
         else:
             response = rag_chain.invoke(
                 {"input": query},
-                config={
-                    "configurable": {"session_id": self.session_id}
-                },
+                config={"configurable": {"session_id": self.session_id}},
             )
-            result = response.content if hasattr(response, "content") else str(response)
+            # Extract only the LLM answer from the response
+            result = (
+                response["answer"].content
+                if hasattr(response["answer"], "content")
+                else str(response["answer"])
+            )
             return result
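
The change in query_rag is the substance of the commit: the LangChain-style retrieval chain returns a mapping whose generated text sits under the "answer" key alongside other fields, and streamed chunks may omit that key entirely, so converting whole chunks or responses with str() pulled those extra fields into the printed output. A small helper equivalent to the extraction logic above might look like this; the name extract_answer is an illustration, not part of the commit:

def extract_answer(chunk):
    """Return only the generated answer text from one chain output, or None.

    Mirrors the extraction in query_rag: chunks without an "answer" key
    (for example, streamed chunks carrying only retrieved context) are
    skipped by returning None.
    """
    if "answer" not in chunk:
        return None
    answer = chunk["answer"]
    # The answer may be a bare string or a message object with .content.
    return answer.content if hasattr(answer, "content") else str(answer)

With such a helper, the streaming branch reduces to yielding extract_answer(chunk) whenever it is not None, and the non-streaming branch to returning extract_answer(response).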