Fix outputs to only include llm model answers
All checks were successful
Pytest / Explore-Gitea-Actions (push) Successful in 13m41s
This commit is contained in:
parent 96801ba8e8
commit 121e815f52
@@ -14,27 +14,32 @@ def stream_output(response_iter):
     except KeyboardInterrupt:
         print("\nStreaming interrupted by user")


 def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
     """Run an interactive chat session"""
-    print("\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n")
+    print(
+        "\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n"
+    )

     try:
         while True:
             try:
                 query = input("\nQuestion: ").strip()

-                if query.lower() in ['exit', 'quit', '']:
+                if query.lower() in ["exit", "quit", ""]:
                     print("\nEnding chat session.")
                     break

                 print("\nResponse:")
                 if no_stream:
                     response = rag_pipeline.query_rag(rag_chain, query, stream=False)
                     print(response)
                 else:
-                    response_iter = rag_pipeline.query_rag(rag_chain, query, stream=True)
+                    response_iter = rag_pipeline.query_rag(
+                        rag_chain, query, stream=True
+                    )
                     stream_output(response_iter)

             except KeyboardInterrupt:
                 print("\n\nEnding chat session.")
                 break
@@ -43,66 +48,63 @@ def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
     except KeyboardInterrupt:
         print("\nChat session interrupted.")


 def main():
     parser = argparse.ArgumentParser(
         description="Code RAG - Query your codebase using natural language"
     )
     parser.add_argument(
-        "docs_dir",
-        help="Directory containing the documents to process",
-        type=str
+        "docs_dir", help="Directory containing the documents to process", type=str
     )
     parser.add_argument(
         "query",
         help="Initial query about your codebase (optional in interactive mode)",
         nargs="?",
         default=None,
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--db-dir",
         help="Directory to store the vector database (default: .code_rag_db)",
         default=os.path.expanduser("~/.code_rag_db"),
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--tracker-file",
         help="File to track document changes (default: .code_rag_tracker.json)",
         default=os.path.expanduser("~/.code_rag_tracker.json"),
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--force-refresh",
         help="Force refresh of the vector database",
-        action="store_true"
+        action="store_true",
     )
     parser.add_argument(
         "--ollama-url",
         help="URL for the Ollama server (default: 127.0.0.1)",
         default="127.0.0.1",
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--embedding-model",
         help="Model to use for embeddings (default: nomic-embed-text)",
         default="nomic-embed-text",
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--llm-model",
         help="Model to use for text generation (default: llama3.2)",
         default="llama3.2",
-        type=str
+        type=str,
     )
     parser.add_argument(
-        "--no-stream",
-        help="Disable streaming output",
-        action="store_true"
+        "--no-stream", help="Disable streaming output", action="store_true"
     )
     parser.add_argument(
         "--no-interactive",
         help="Run in non-interactive mode (answer single query and exit)",
-        action="store_true"
+        action="store_true",
     )

     args = parser.parse_args()
@@ -117,7 +119,7 @@ def main():
         tracker_file=args.tracker_file,
         ollama_url=args.ollama_url,
         embedding_model=args.embedding_model,
-        llm_model=args.llm_model
+        llm_model=args.llm_model,
     )

     # Create or update vector database
@@ -136,7 +138,7 @@ def main():
         print("\nResponse:")
         if args.no_stream:
             response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
-            print(response)
+            stream_output(response)
         else:
             response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
             stream_output(response_iter)
@@ -148,13 +150,16 @@ def main():
         print("\nResponse:")
         if args.no_stream:
             response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
-            print(response)
+            stream_output(response)
         else:
-            response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
+            response_iter = rag_pipeline.query_rag(
+                rag_chain, args.query, stream=True
+            )
             stream_output(response_iter)

     # Start interactive chat
     interactive_chat(rag_pipeline, rag_chain, args.no_stream)


 if __name__ == "__main__":
     main()
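Note that in the two single-query paths above, the stream=False branch now calls stream_output(response) instead of a bare print, so the same helper handles both a complete answer string and an iterator of chunks. The body of stream_output is not part of this diff; the following is a hypothetical sketch of what such a helper could look like, assuming it accepts either form (the isinstance check and the flushing behaviour are assumptions, not taken from the source):

import sys


def stream_output(response_iter):
    """Print a response that is either a complete string or an iterator of text chunks."""
    try:
        if isinstance(response_iter, str):
            # stream=False path: query_rag returned the whole answer at once
            print(response_iter)
            return
        for chunk in response_iter:
            # stream=True path: emit each chunk as it arrives, without extra newlines
            sys.stdout.write(chunk)
            sys.stdout.flush()
        print()
    except KeyboardInterrupt:
        print("\nStreaming interrupted by user")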
@@ -94,9 +94,7 @@ class RAG:
         )

         # Create embeddings
-        print("Before embedding")
         embeddings = self.ollama.embeddings
-        print("after embedding")

         # Load or create vector store
         if os.path.exists(self.db_dir) and not force_refresh:
@@ -256,24 +254,31 @@ class RAG:
         if stream:
             response = rag_chain.stream(
                 {"input": query},
-                config={
-                    "configurable": {"session_id": self.session_id}
-                },
+                config={"configurable": {"session_id": self.session_id}},
             )
             # Store in chat history after getting full response
             full_response = ""
             for chunk in response:
-                chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
-                full_response += chunk_text
-                yield chunk_text
+                # Extract only the LLM answer chunk
+                if "answer" in chunk:
+                    chunk_text = (
+                        chunk["answer"].content
+                        if hasattr(chunk["answer"], "content")
+                        else str(chunk["answer"])
+                    )
+                    full_response += chunk_text
+                    yield chunk_text
         else:
             response = rag_chain.invoke(
                 {"input": query},
-                config={
-                    "configurable": {"session_id": self.session_id}
-                },
+                config={"configurable": {"session_id": self.session_id}},
             )
-            result = response.content if hasattr(response, "content") else str(response)
+            # Extract only the LLM answer from the response
+            result = (
+                response["answer"].content
+                if hasattr(response["answer"], "content")
+                else str(response["answer"])
+            )
             return result
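This hunk is the substance of the commit: a retrieval-style chain emits partial dicts rather than bare text, so str(chunk) previously surfaced the echoed input and the retrieved context alongside the answer; filtering on the "answer" key keeps only the LLM's text. A minimal, self-contained illustration of the difference, assuming chunks shaped like a LangChain create_retrieval_chain-style stream (the fake_stream keys and values are invented for the example):

def fake_stream():
    # Invented chunk shapes for illustration only
    yield {"input": "How is the vector store created?"}
    yield {"context": ["<retrieved Document objects>"]}
    yield {"answer": "The store is built "}
    yield {"answer": "from the embedded documents."}


# Old behaviour: every chunk was stringified, so input and context leaked into the output
old_output = "".join(str(chunk) for chunk in fake_stream())

# New behaviour: only answer chunks are accumulated and yielded
new_output = "".join(chunk["answer"] for chunk in fake_stream() if "answer" in chunk)

print(old_output)  # dict reprs including the input echo and retrieved context
print(new_output)  # "The store is built from the embedded documents."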