Fix outputs to only include LLM model answers
All checks were successful
Pytest / Explore-Gitea-Actions (push) Successful in -13m41s
This commit is contained in:
parent 96801ba8e8
commit 121e815f52
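The whole change reduces to one pattern: the retrieval chain's output is a mapping that bundles the retrieved context with the generated answer, and only the "answer" part should reach the user. Below is a minimal, self-contained sketch of that extraction logic; the plain dicts standing in for chain output and the helper name extract_answer_chunks are illustrative assumptions, not objects from this repository.

def extract_answer_chunks(chunks):
    """Yield only the text of 'answer' chunks, skipping retrieval metadata."""
    for chunk in chunks:
        if "answer" in chunk:
            answer = chunk["answer"]
            # Message-like objects carry .content; anything else falls back to str().
            yield answer.content if hasattr(answer, "content") else str(answer)


if __name__ == "__main__":
    # Hypothetical stream: retrieval chains typically interleave metadata with answer chunks.
    simulated_stream = [
        {"input": "What does stream_output do?"},
        {"context": ["<retrieved documents>"]},
        {"answer": "It prints "},
        {"answer": "each chunk as it arrives."},
    ]
    print("".join(extract_answer_chunks(simulated_stream)))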
@@ -14,27 +14,32 @@ def stream_output(response_iter):
     except KeyboardInterrupt:
         print("\nStreaming interrupted by user")


 def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
     """Run an interactive chat session"""
-    print("\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n")
+    print(
+        "\nEnter your questions about the codebase. Type 'exit', 'quit', or press Ctrl+C to end the session.\n"
+    )

     try:
         while True:
             try:
                 query = input("\nQuestion: ").strip()

-                if query.lower() in ['exit', 'quit', '']:
+                if query.lower() in ["exit", "quit", ""]:
                     print("\nEnding chat session.")
                     break

                 print("\nResponse:")
                 if no_stream:
                     response = rag_pipeline.query_rag(rag_chain, query, stream=False)
                     print(response)
                 else:
-                    response_iter = rag_pipeline.query_rag(rag_chain, query, stream=True)
+                    response_iter = rag_pipeline.query_rag(
+                        rag_chain, query, stream=True
+                    )
                     stream_output(response_iter)

             except KeyboardInterrupt:
                 print("\n\nEnding chat session.")
                 break
@@ -43,66 +48,63 @@ def interactive_chat(rag_pipeline, rag_chain, no_stream=False):
     except KeyboardInterrupt:
         print("\nChat session interrupted.")


 def main():
     parser = argparse.ArgumentParser(
         description="Code RAG - Query your codebase using natural language"
     )
     parser.add_argument(
-        "docs_dir",
-        help="Directory containing the documents to process",
-        type=str
+        "docs_dir", help="Directory containing the documents to process", type=str
     )
     parser.add_argument(
         "query",
         help="Initial query about your codebase (optional in interactive mode)",
         nargs="?",
         default=None,
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--db-dir",
         help="Directory to store the vector database (default: .code_rag_db)",
         default=os.path.expanduser("~/.code_rag_db"),
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--tracker-file",
         help="File to track document changes (default: .code_rag_tracker.json)",
         default=os.path.expanduser("~/.code_rag_tracker.json"),
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--force-refresh",
         help="Force refresh of the vector database",
-        action="store_true"
+        action="store_true",
     )
     parser.add_argument(
         "--ollama-url",
         help="URL for the Ollama server (default: 127.0.0.1)",
         default="127.0.0.1",
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--embedding-model",
         help="Model to use for embeddings (default: nomic-embed-text)",
         default="nomic-embed-text",
-        type=str
+        type=str,
     )
     parser.add_argument(
         "--llm-model",
         help="Model to use for text generation (default: llama3.2)",
         default="llama3.2",
-        type=str
+        type=str,
     )
     parser.add_argument(
-        "--no-stream",
-        help="Disable streaming output",
-        action="store_true"
+        "--no-stream", help="Disable streaming output", action="store_true"
     )
     parser.add_argument(
         "--no-interactive",
         help="Run in non-interactive mode (answer single query and exit)",
-        action="store_true"
+        action="store_true",
     )

     args = parser.parse_args()
@@ -117,7 +119,7 @@ def main():
         tracker_file=args.tracker_file,
         ollama_url=args.ollama_url,
         embedding_model=args.embedding_model,
-        llm_model=args.llm_model
+        llm_model=args.llm_model,
     )

     # Create or update vector database
@@ -136,7 +138,7 @@ def main():
         print("\nResponse:")
         if args.no_stream:
             response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
-            print(response)
+            stream_output(response)
         else:
             response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
             stream_output(response_iter)
@@ -148,13 +150,16 @@ def main():
         print("\nResponse:")
         if args.no_stream:
             response = rag_pipeline.query_rag(rag_chain, args.query, stream=False)
-            print(response)
+            stream_output(response)
         else:
-            response_iter = rag_pipeline.query_rag(rag_chain, args.query, stream=True)
+            response_iter = rag_pipeline.query_rag(
+                rag_chain, args.query, stream=True
+            )
             stream_output(response_iter)


     # Start interactive chat
     interactive_chat(rag_pipeline, rag_chain, args.no_stream)


 if __name__ == "__main__":
     main()
@@ -94,9 +94,7 @@ class RAG:
         )

         # Create embeddings
-        print("Before embedding")
         embeddings = self.ollama.embeddings
-        print("after embedding")

         # Load or create vector store
         if os.path.exists(self.db_dir) and not force_refresh:
@@ -256,24 +254,31 @@ class RAG:
         if stream:
             response = rag_chain.stream(
                 {"input": query},
-                config={
-                    "configurable": {"session_id": self.session_id}
-                },
+                config={"configurable": {"session_id": self.session_id}},
             )
             # Store in chat history after getting full response
             full_response = ""
             for chunk in response:
-                chunk_text = chunk.content if hasattr(chunk, "content") else str(chunk)
-                full_response += chunk_text
-                yield chunk_text
+                # Extract only the LLM answer chunk
+                if "answer" in chunk:
+                    chunk_text = (
+                        chunk["answer"].content
+                        if hasattr(chunk["answer"], "content")
+                        else str(chunk["answer"])
+                    )
+                    full_response += chunk_text
+                    yield chunk_text
         else:
             response = rag_chain.invoke(
                 {"input": query},
-                config={
-                    "configurable": {"session_id": self.session_id}
-                },
+                config={"configurable": {"session_id": self.session_id}},
             )
-            result = response.content if hasattr(response, "content") else str(response)
+            # Extract only the LLM answer from the response
+            result = (
+                response["answer"].content
+                if hasattr(response["answer"], "content")
+                else str(response["answer"])
+            )
             return result
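The body of stream_output is not part of this diff (only its KeyboardInterrupt handler appears above), yet the CLI now hands it a plain string on the non-streaming path (stream_output(response)) and a generator on the streaming path. A compatible helper could look roughly like the sketch below; this is an assumption about its implementation, not the repository's actual code.

def stream_output(response_iter):
    # Assumed implementation: accept either a full string or an iterable of chunks.
    try:
        if isinstance(response_iter, str):
            print(response_iter)
            return
        for chunk_text in response_iter:
            print(chunk_text, end="", flush=True)
        print()
    except KeyboardInterrupt:
        print("\nStreaming interrupted by user")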