diff --git a/src/code_rag/rag.py b/src/code_rag/rag.py index 95cbed1..e1fd5f8 100644 --- a/src/code_rag/rag.py +++ b/src/code_rag/rag.py @@ -66,7 +66,7 @@ class RAG: def create_vector_db(self, extensions=None, force_refresh=False): """ Create or update a vector database, with complete handling of changes. - + Args: extensions (list[str], optional): List of file extensions to include. If None, defaults to common programming languages. @@ -74,18 +74,7 @@ class RAG: """ # Set default extensions for common programming languages if none provided if extensions is None: - extensions = [ - # Python - '.py', '.pyi', '.pyx', - # C/C++ - '.c', '.cpp', '.cc', '.cxx', '.h', '.hpp', '.hxx', - # Rust - '.rs', - # Documentation - '.txt', '.md', - # Build/Config - '.toml', '.yaml', '.json' - ] + extensions = default_file_extensions() elif isinstance(extensions, str): extensions = [extensions] @@ -103,14 +92,12 @@ class RAG: if os.path.exists(self.db_dir) and not force_refresh: print("Loading existing vector store") vectorstore = Chroma( - persist_directory=self.db_dir, - embedding_function=embeddings + persist_directory=self.db_dir, embedding_function=embeddings ) else: print("Creating new vector store") vectorstore = Chroma( - persist_directory=self.db_dir, - embedding_function=embeddings + persist_directory=self.db_dir, embedding_function=embeddings ) # Find all files that match the extensions @@ -127,14 +114,14 @@ class RAG: all_documents = [] for file_path in all_files: try: - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: content = f.read() doc = Document( page_content=content, metadata={ "source": os.path.abspath(file_path), - "source_id": os.path.abspath(file_path) - } + "source_id": os.path.abspath(file_path), + }, ) all_documents.append(doc) print(f"Successfully loaded {file_path}") @@ -151,9 +138,7 @@ class RAG: if force_refresh: # Create new vector store from scratch vectorstore = Chroma.from_documents( - 
def default_file_extensions():
    """Return the default file extensions for common plain-text and code files.

    Only formats that are stored as human-readable text are listed, because
    matched files are read as UTF-8 text before being indexed. Compiled or
    binary artefacts (for example ``.pyc``, ``.pyd``, ``.jar``, ``.class``,
    ``.rlib``, ``.rmeta``, ``.swiftmodule``, ``.luac``, ``.rds`` and
    ``.sqlite``) are deliberately excluded: they cannot be decoded as text
    and would only produce load errors or meaningless chunks.

    Returns:
        list[str]: Lower-case extensions, each including the leading dot.
            A new list is built on every call, so callers may modify the
            result without affecting later calls.
    """
    # NOTE(review): every entry is lower-case; whether files such as
    # "script.R" match depends on how the caller compares extensions.
    return [
        # Python (source only; .pyc/.pyd are compiled binaries)
        ".py",
        ".pyi",
        ".pyx",
        ".pyw",
        # C/C++
        ".c",
        ".cpp",
        ".cc",
        ".cxx",
        ".h",
        ".hpp",
        ".hxx",
        ".inc",
        ".inl",
        ".ipp",
        # Rust (source only; .rlib/.rmeta are compiler outputs)
        ".rs",
        # Java / JVM languages (source only; .jar/.class are binaries)
        ".java",
        ".jsp",
        ".jav",
        ".kt",
        ".kts",
        ".groovy",
        # Web
        ".html",
        ".htm",
        ".css",
        ".scss",
        ".sass",
        ".less",
        ".js",
        ".jsx",
        ".ts",
        ".tsx",
        ".vue",
        ".svelte",
        # Fortran
        ".f",
        ".for",
        ".f90",
        ".f95",
        ".f03",
        ".f08",
        # Go
        ".go",
        ".mod",
        # Ruby
        ".rb",
        ".rbw",
        ".rake",
        ".gemspec",
        # PHP
        ".php",
        ".phtml",
        ".php3",
        ".php4",
        ".php5",
        ".phps",
        # C# / .NET
        ".cs",
        ".csx",
        ".vb",
        # Swift (source only; .swiftmodule is a compiled module)
        ".swift",
        # Shell/Scripts
        ".sh",
        ".bash",
        ".zsh",
        ".fish",
        ".ps1",
        ".bat",
        ".cmd",
        # Scala
        ".scala",
        ".sc",
        # Haskell
        ".hs",
        ".lhs",
        ".hsc",
        # Lua (source only; .luac is compiled bytecode)
        ".lua",
        # R (source only; .rds is a serialized binary object)
        ".r",
        ".rmd",
        # Perl
        ".pl",
        ".pm",
        ".t",
        # Documentation
        ".txt",
        ".md",
        ".rst",
        ".adoc",
        ".wiki",
        # Build/Config
        ".toml",
        ".yaml",
        ".yml",
        ".json",
        ".xml",
        ".ini",
        ".conf",
        ".cfg",
        # SQL (scripts only; .sqlite is a binary database file)
        ".sql",
        ".mysql",
        ".pgsql",
        # Lisp family
        ".lisp",
        ".cl",
        ".el",
        ".clj",
        ".cljc",
        ".cljs",
        ".edn",
    ]