mirror of https://github.com/microsoft/autogen.git
Support setting vector_db as a param (#2313)
* Added vectordb base and chromadb
* Remove timer and unused functions
* Added filter by distance
* Added test utils
* Fix format
* Fix type hint of dict
* Rename test
* Add test chromadb
* Fix test no chromadb
* Add coverage
* Don't skip test vectordb utils
* Add types
* Fix tests
* Fix docs build error
* Add types to base
* Update base
* Update utils
* Update chromadb
* Add get_docs_by_ids
* Improve docstring
* Update init params
* Update init vector db
* Add get all docs
* Move chroma_results_to_query_results to utils
* Add init vectordb
* Convert format of results for old version
* Improve type hints
* Update get_context for new query results format
* Fix typo
* Improve init db
* Update default folder
* Update logger
* Update init, add embedding func
* Update distance_threshold
* Fix logger name
* Update qdrant
* Fix init db
* Update notebooks
* Use kwargs to improve readability
* Improve docstring of vectordb, add two attributes
* Add db_config
* Update gitignore
* Update comments
* Add source
* Fix files downloaded from urls having the same name
* Remove files added by mistake
* Improve docstring
* Update docstring
* Update docstring
* Update docstring

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
parent 4ab8a88487
commit c4e570393d
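Net effect of the commit, before the per-file hunks: `retrieve_config` gains `vector_db`, `db_config`, `new_docs`, `overwrite`, and `distance_threshold` keys. A minimal usage sketch; values and paths here are illustrative placeholders, not taken from the commit:

```python
# Hedged sketch of the new retrieve_config keys added by this commit.
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    human_input_mode="NEVER",
    retrieve_config={
        "task": "qa",
        "docs_path": "./docs",            # dir, file, or URL(s) to chunk and index
        "vector_db": "chroma",            # type string resolved via VectorDBFactory,
                                          # or pass a VectorDB instance directly
        "db_config": {"path": "tmp/db"},  # forwarded to the vector db constructor
        "new_docs": True,                 # insert only chunks whose hash id is unseen
        "overwrite": False,               # see the get_or_create/overwrite cases below
        "distance_threshold": 0.7,        # keep only hits with distance below this
    },
)
```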
.gitignore

@@ -183,6 +183,7 @@ test/agentchat/test_agent_scripts/*
 # test cache
 .cache_test
 .db
 local_cache

+notebook/result.png
autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py

@@ -1,17 +1,21 @@
 import logging
 from typing import Callable, Dict, List, Optional

 from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
+from autogen.agentchat.contrib.vectordb.utils import (
+    chroma_results_to_query_results,
+    filter_results_by_distance,
+    get_logger,
+)
 from autogen.retrieve_utils import TEXT_FORMATS, get_files_from_dir, split_files_to_chunks

-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)

 try:
     import fastembed
     from qdrant_client import QdrantClient, models
     from qdrant_client.fastembed_common import QueryResponse
 except ImportError as e:
-    logging.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'")
+    logger.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'")
     raise e
@@ -136,6 +140,11 @@ class QdrantRetrieveUserProxyAgent(RetrieveUserProxyAgent):
             collection_name=self._collection_name,
             embedding_model=self._embedding_model,
         )
+        results["contents"] = results.pop("documents")
+        results = chroma_results_to_query_results(results, "distances")
+        results = filter_results_by_distance(results, self._distance_threshold)
+
+        self._search_string = search_string
         self._results = results
@@ -298,6 +307,7 @@ def query_qdrant(
     data = {
         "ids": [[result.id for result in sublist] for sublist in results],
         "documents": [[result.document for result in sublist] for sublist in results],
         "distances": [[result.score for result in sublist] for sublist in results],
+        "metadatas": [[result.metadata for result in sublist] for sublist in results],
     }
     return data
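The qdrant responses are now massaged into the same chroma-style mapping before conversion. A shape sketch, inferred from the hunks above; the converted output format is an assumption based on how `_get_context` consumes it later in this commit:

```python
# Assumed intermediate shape: parallel nested lists, one inner list per query.
chroma_style = {
    "ids": [["d1", "d2"]],
    "documents": [["text one", "text two"]],
    "distances": [[0.12, 0.38]],
    "metadatas": [[{"source": "a.md"}, {"source": "b.md"}]],
}
# chroma_results_to_query_results(chroma_style, "distances") is then expected to
# pair each document with its distance, roughly:
# [[({"id": "d1", "content": "text one", "metadata": {"source": "a.md"}}, 0.12),
#   ({"id": "d2", "content": "text two", "metadata": {"source": "b.md"}}, 0.38)]]
```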
autogen/agentchat/contrib/retrieve_user_proxy_agent.py

@@ -1,3 +1,5 @@
+import hashlib
+import os
 import re
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -7,15 +9,28 @@ try:
     import chromadb
 except ImportError:
     raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
-from autogen import logger
 from autogen.agentchat import UserProxyAgent
 from autogen.agentchat.agent import Agent
+from autogen.agentchat.contrib.vectordb.base import Document, QueryResults, VectorDB, VectorDBFactory
+from autogen.agentchat.contrib.vectordb.utils import (
+    chroma_results_to_query_results,
+    filter_results_by_distance,
+    get_logger,
+)
 from autogen.code_utils import extract_code
-from autogen.retrieve_utils import TEXT_FORMATS, create_vector_db_from_dir, query_vector_db
+from autogen.retrieve_utils import (
+    TEXT_FORMATS,
+    create_vector_db_from_dir,
+    get_files_from_dir,
+    query_vector_db,
+    split_files_to_chunks,
+)
 from autogen.token_count_utils import count_token

 from ...formatting_utils import colored

+logger = get_logger(__name__)

 PROMPT_DEFAULT = """You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the
 context provided by the user. You should follow the following steps to answer a question:
 Step 1, you estimate the user's intent based on the question and context. The intent can be a code generation task or
@@ -65,6 +80,8 @@ User's question is: {input_question}
 Context is: {input_context}
 """

+HASH_LENGTH = int(os.environ.get("HASH_LENGTH", 8))
+

 class RetrieveUserProxyAgent(UserProxyAgent):
     """(In preview) The Retrieval-Augmented User Proxy retrieves document chunks based on the embedding
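The new `HASH_LENGTH` constant drives the content-hash document ids used by `_init_db` further down. A quick, self-contained illustration of the scheme:

```python
import hashlib
import os

HASH_LENGTH = int(os.environ.get("HASH_LENGTH", 8))

def chunk_id(chunk: str) -> str:
    # Truncated blake2b digest of the chunk text, matching the _init_db hunk below.
    return hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH]

print(chunk_id("hello"))  # identical text always yields the same short id
```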
@@ -107,9 +124,17 @@ class RetrieveUserProxyAgent(UserProxyAgent):
                 "code", "qa" and "default". System prompt will be different for different tasks.
                 The default value is `default`, which supports both code and qa, and provides
                 source information in the end of the response.
+            - `vector_db` (Optional, Union[str, VectorDB]) - the vector db for the retrieve chat.
+                If it's a string, it should be the type of the vector db, such as "chroma"; otherwise,
+                it should be an instance of the VectorDB protocol. Default is "chroma".
+                Set `None` to use the deprecated `client`.
+            - `db_config` (Optional, Dict) - the config for the vector db. Default is `{}`. Please make
+                sure you understand the config for the vector db you are using, otherwise, leave it as `{}`.
+                Only valid when `vector_db` is a string.
             - `client` (Optional, chromadb.Client) - the chromadb client. If key not provided, a
                 default client `chromadb.Client()` will be used. If you want to use other
                 vector db, extend this class and override the `retrieve_docs` function.
+                **Deprecated**: use `vector_db` instead.
             - `docs_path` (Optional, Union[str, List[str]]) - the path to the docs directory. It
                 can also be the path to a single file, the url to a single file or a list
                 of directories, files and urls. Default is None, which works only if the
@@ -123,8 +148,11 @@ class RetrieveUserProxyAgent(UserProxyAgent):
                 By default, "extra_docs" is set to false, starting document IDs from zero.
                 This poses a risk as new documents might overwrite existing ones, potentially
                 causing unintended loss or alteration of data in the collection.
-            - `collection_name` (Optional, str) - the name of the collection.
-                If key not provided, a default name `autogen-docs` will be used.
+                **Deprecated**: use `new_docs` when use `vector_db` instead of `client`.
+            - `new_docs` (Optional, bool) - when True, only adds new documents to the collection;
+                when False, updates existing documents and adds new ones. Default is True.
+                Document id is used to determine if a document is new or existing. By default, the
+                id is the hash value of the content.
             - `model` (Optional, str) - the model to use for the retrieve chat.
                 If key not provided, a default model `gpt-4` will be used.
             - `chunk_token_size` (Optional, int) - the chunk token size for the retrieve chat.
@@ -143,6 +171,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
                 models can be found at `https://www.sbert.net/docs/pretrained_models.html`.
                 The default model is a fast model. If you want to use a high performance model,
                 `all-mpnet-base-v2` is recommended.
+                **Deprecated**: no need when use `vector_db` instead of `client`.
             - `embedding_function` (Optional, Callable) - the embedding function for creating the
                 vector db. Default is None, SentenceTransformer with the given `embedding_model`
                 will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
@@ -156,10 +185,14 @@ class RetrieveUserProxyAgent(UserProxyAgent):
                 `Update Context` will be triggered.
             - `update_context` (Optional, bool) - if False, will not apply `Update Context` for
                 interactive retrieval. Default is True.
-            - `get_or_create` (Optional, bool) - if True, will create/return a collection for the
-                retrieve chat. This is the same as that used in chromadb.
-                Default is False. Will raise ValueError if the collection already exists and
-                get_or_create is False. Will be set to True if docs_path is None.
+            - `collection_name` (Optional, str) - the name of the collection.
+                If key not provided, a default name `autogen-docs` will be used.
+            - `get_or_create` (Optional, bool) - Whether to get the collection if it exists. Default is True.
+            - `overwrite` (Optional, bool) - Whether to overwrite the collection if it exists. Default is False.
+                Case 1. if the collection does not exist, create the collection.
+                Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
+                Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
+                    otherwise it raise a ValueError.
             - `custom_token_count_function` (Optional, Callable) - a custom function to count the
                 number of tokens in a string.
                 The function should take (text:str, model:str) as input and return the
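The three documented cases couple `overwrite` with `get_or_create`. A runnable restatement of that decision table; the ValueError-on-missing contract of `get_collection` is taken from the `_init_db` hunk below:

```python
def resolve_collection(db, name: str, overwrite: bool, get_or_create: bool):
    try:
        db.get_collection(name)  # raises ValueError when the collection is missing
        exists = True
    except ValueError:
        exists = False
    if not exists:               # Case 1: create it
        return db.create_collection(name)
    if overwrite:                # Case 2: recreate in place
        return db.create_collection(name, overwrite=True)
    if get_or_create:            # Case 3: reuse the existing collection
        return db.get_collection(name)
    raise ValueError(f"Collection `{name}` exists; set overwrite or get_or_create.")
```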
@@ -176,6 +209,8 @@ class RetrieveUserProxyAgent(UserProxyAgent):
                 included files and urls will be chunked regardless of their types.
             - `recursive` (Optional, bool) - whether to search documents recursively in the
                 docs_path. Default is True.
+            - `distance_threshold` (Optional, float) - the threshold for the distance score, only
+                distance smaller than it will be returned. Will be ignored if < 0. Default is -1.

         `**kwargs` (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
@@ -183,6 +218,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):

     Example of overriding retrieve_docs - If you have set up a customized vector db, and it's
     not compatible with chromadb, you can easily plug in it with below code.
+    **Deprecated**: Use `vector_db` instead. You can extend VectorDB and pass it to the agent.
     ```python
     class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
         def query_vector_db(
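Per the new deprecation note, the supported extension point becomes a `VectorDB` implementation passed via `vector_db` rather than an overridden `retrieve_docs`. A skeleton, with method names and signatures inferred from how this commit calls the protocol; the bodies are placeholders:

```python
from autogen.agentchat.contrib.vectordb.base import VectorDB

class MyVectorDB(VectorDB):
    def create_collection(self, collection_name, overwrite=False, get_or_create=True):
        ...

    def get_collection(self, collection_name=None):
        ...  # expected to raise ValueError when the collection is missing

    def insert_docs(self, docs, collection_name=None, upsert=False, **kwargs):
        ...

    def retrieve_docs(self, queries, collection_name=None, n_results=10, distance_threshold=-1, **kwargs):
        ...  # must return QueryResults: one list of (doc_dict, distance) pairs per query

    def get_docs_by_ids(self, ids=None, collection_name=None, **kwargs):
        ...

# agent = RetrieveUserProxyAgent(..., retrieve_config={"vector_db": MyVectorDB(), ...})
```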
@@ -215,9 +251,12 @@ class RetrieveUserProxyAgent(UserProxyAgent):

         self._retrieve_config = {} if retrieve_config is None else retrieve_config
         self._task = self._retrieve_config.get("task", "default")
+        self._vector_db = self._retrieve_config.get("vector_db", "chroma")
+        self._db_config = self._retrieve_config.get("db_config", {})
         self._client = self._retrieve_config.get("client", chromadb.Client())
         self._docs_path = self._retrieve_config.get("docs_path", None)
         self._extra_docs = self._retrieve_config.get("extra_docs", False)
+        self._new_docs = self._retrieve_config.get("new_docs", True)
         self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs")
         if "docs_path" not in self._retrieve_config:
             logger.warning(
@@ -236,6 +275,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
         self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
         self.update_context = self._retrieve_config.get("update_context", True)
         self._get_or_create = self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else True
+        self._overwrite = self._retrieve_config.get("overwrite", False)
         self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token)
         self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
         self._custom_text_types = self._retrieve_config.get("custom_text_types", TEXT_FORMATS)
@@ -244,18 +284,95 @@ class RetrieveUserProxyAgent(UserProxyAgent):
         self._collection = True if self._docs_path is None else False  # whether the collection is created
         self._ipython = get_ipython()
         self._doc_idx = -1  # the index of the current used doc
-        self._results = {}  # the results of the current query
+        self._results = []  # the results of the current query
         self._intermediate_answers = set()  # the intermediate answers
         self._doc_contents = []  # the contents of the current used doc
         self._doc_ids = []  # the ids of the current used doc
+        self._current_docs_in_context = []  # the ids of the current context sources
         self._search_string = ""  # the search string used in the current query
+        self._distance_threshold = self._retrieve_config.get("distance_threshold", -1)
         # update the termination message function
         self._is_termination_msg = (
             self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg
         )
+        if isinstance(self._vector_db, str):
+            if not isinstance(self._db_config, dict):
+                raise ValueError("`db_config` should be a dictionary.")
+            if "embedding_function" in self._retrieve_config:
+                self._db_config["embedding_function"] = self._embedding_function
+            self._vector_db = VectorDBFactory.create_vector_db(db_type=self._vector_db, **self._db_config)
         self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=2)

+    def _init_db(self):
+        if not self._vector_db:
+            return
+
+        IS_TO_CHUNK = False  # whether to chunk the raw files
+        if self._new_docs:
+            IS_TO_CHUNK = True
+        if not self._docs_path:
+            try:
+                self._vector_db.get_collection(self._collection_name)
+                logger.warning(f"`docs_path` is not provided. Use the existing collection `{self._collection_name}`.")
+                self._overwrite = False
+                self._get_or_create = True
+                IS_TO_CHUNK = False
+            except ValueError:
+                raise ValueError(
+                    "`docs_path` is not provided. "
+                    f"The collection `{self._collection_name}` doesn't exist either. "
+                    "Please provide `docs_path` or create the collection first."
+                )
+        elif self._get_or_create and not self._overwrite:
+            try:
+                self._vector_db.get_collection(self._collection_name)
+                logger.info(f"Use the existing collection `{self._collection_name}`.", color="green")
+            except ValueError:
+                IS_TO_CHUNK = True
+        else:
+            IS_TO_CHUNK = True
+
+        self._vector_db.active_collection = self._vector_db.create_collection(
+            self._collection_name, overwrite=self._overwrite, get_or_create=self._get_or_create
+        )
+
+        docs = None
+        if IS_TO_CHUNK:
+            if self.custom_text_split_function is not None:
+                chunks, sources = split_files_to_chunks(
+                    get_files_from_dir(self._docs_path, self._custom_text_types, self._recursive),
+                    custom_text_split_function=self.custom_text_split_function,
+                )
+            else:
+                chunks, sources = split_files_to_chunks(
+                    get_files_from_dir(self._docs_path, self._custom_text_types, self._recursive),
+                    self._max_tokens,
+                    self._chunk_mode,
+                    self._must_break_at_empty_line,
+                )
+            logger.info(f"Found {len(chunks)} chunks.")
+
+            if self._new_docs:
+                all_docs_ids = set(
+                    [
+                        doc["id"]
+                        for doc in self._vector_db.get_docs_by_ids(ids=None, collection_name=self._collection_name)
+                    ]
+                )
+            else:
+                all_docs_ids = set()
+
+            chunk_ids = [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
+            chunk_ids_set = set(chunk_ids)
+            chunk_ids_set_idx = [chunk_ids.index(hash_value) for hash_value in chunk_ids_set]
+            docs = [
+                Document(id=chunk_ids[idx], content=chunks[idx], metadata=sources[idx])
+                for idx in chunk_ids_set_idx
+                if chunk_ids[idx] not in all_docs_ids
+            ]
+
+        self._vector_db.insert_docs(docs=docs, collection_name=self._collection_name, upsert=True)
+
     def _is_termination_msg_retrievechat(self, message):
         """Check if a message is a termination message.
         For code generation, terminate when no code block is detected. Currently only detect python code blocks.
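The dedup step above keeps the first occurrence of each unique hash id and then drops ids already present in the collection. A tiny illustration with stand-in hashes:

```python
# Same list-comprehension pattern as _init_db, on toy data.
chunks = ["alpha", "beta", "alpha"]
chunk_ids = ["a1", "b2", "a1"]                                    # stand-in hashes
chunk_ids_set_idx = [chunk_ids.index(h) for h in set(chunk_ids)]  # first index per unique id
existing_ids = {"b2"}                                             # ids already stored
new_docs = [(chunk_ids[i], chunks[i]) for i in chunk_ids_set_idx if chunk_ids[i] not in existing_ids]
print(new_docs)  # [('a1', 'alpha')]; 'b2' skipped, duplicate 'a1' collapsed
```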
@@ -288,41 +405,42 @@ class RetrieveUserProxyAgent(UserProxyAgent):

     def _reset(self, intermediate=False):
         self._doc_idx = -1  # the index of the current used doc
-        self._results = {}  # the results of the current query
+        self._results = []  # the results of the current query
         if not intermediate:
             self._intermediate_answers = set()  # the intermediate answers
             self._doc_contents = []  # the contents of the current used doc
             self._doc_ids = []  # the ids of the current used doc

-    def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
+    def _get_context(self, results: QueryResults):
         doc_contents = ""
         self._current_docs_in_context = []
         current_tokens = 0
         _doc_idx = self._doc_idx
         _tmp_retrieve_count = 0
-        for idx, doc in enumerate(results["documents"][0]):
+        for idx, doc in enumerate(results[0]):
+            doc = doc[0]
             if idx <= _doc_idx:
                 continue
-            if results["ids"][0][idx] in self._doc_ids:
+            if doc["id"] in self._doc_ids:
                 continue
-            _doc_tokens = self.custom_token_count_function(doc, self._model)
+            _doc_tokens = self.custom_token_count_function(doc["content"], self._model)
             if _doc_tokens > self._context_max_tokens:
-                func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context."
+                func_print = f"Skip doc_id {doc['id']} as it is too long to fit in the context."
                 print(colored(func_print, "green"), flush=True)
                 self._doc_idx = idx
                 continue
             if current_tokens + _doc_tokens > self._context_max_tokens:
                 break
-            func_print = f"Adding doc_id {results['ids'][0][idx]} to context."
+            func_print = f"Adding content of doc {doc['id']} to context."
             print(colored(func_print, "green"), flush=True)
             current_tokens += _doc_tokens
-            doc_contents += doc + "\n"
-            _metadatas = results.get("metadatas")
-            if isinstance(_metadatas, list) and isinstance(_metadatas[0][idx], dict):
-                self._current_docs_in_context.append(results["metadatas"][0][idx].get("source", ""))
+            doc_contents += doc["content"] + "\n"
+            _metadata = doc.get("metadata")
+            if isinstance(_metadata, dict):
+                self._current_docs_in_context.append(_metadata.get("source", ""))
             self._doc_idx = idx
-            self._doc_ids.append(results["ids"][0][idx])
-            self._doc_contents.append(doc)
+            self._doc_ids.append(doc["id"])
+            self._doc_contents.append(doc["content"])
             _tmp_retrieve_count += 1
             if _tmp_retrieve_count >= self.n_results:
                 break
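`_get_context` now walks the new `QueryResults` layout instead of chroma's parallel lists. The assumed shape, matching the indexing in the hunk above (one inner list per query, each entry a (document, distance) pair with a dict document):

```python
results = [
    [
        ({"id": "a1", "content": "first chunk", "metadata": {"source": "a.md"}}, 0.12),
        ({"id": "b2", "content": "second chunk", "metadata": {"source": "b.md"}}, 0.34),
    ]
]
for document, distance in results[0]:
    print(document["id"], distance, document.get("metadata", {}).get("source", ""))
```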
@@ -416,21 +534,40 @@ class RetrieveUserProxyAgent(UserProxyAgent):

     def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
         """Retrieve docs based on the given problem and assign the results to the class property `_results`.
         In case you want to customize the retrieval process, such as using a different vector db whose APIs are not
         compatible with chromadb or filter results with metadata, you can override this function. Just keep the current
         parameters and add your own parameters with default values, and keep the results in below type.

-        Type of the results: Dict[str, List[List[Any]]], should have keys "ids" and "documents", "ids" for the ids of
-        the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer
-        to `chromadb.api.types.QueryResult` as an example.
-            ids: List[string]
-            documents: List[List[string]]
+        The retrieved docs should be type of `QueryResults` which is a list of tuples containing the document and
+        the distance.

         Args:
             problem (str): the problem to be solved.
             n_results (int): the number of results to be retrieved. Default is 20.
             search_string (str): only docs that contain an exact match of this string will be retrieved. Default is "".
+                Not used if the vector_db doesn't support it.

         Returns:
             None.
         """
+        if isinstance(self._vector_db, VectorDB):
+            if not self._collection or not self._get_or_create:
+                print("Trying to create collection.")
+                self._init_db()
+                self._collection = True
+                self._get_or_create = True
+
+            kwargs = {}
+            if hasattr(self._vector_db, "type") and self._vector_db.type == "chroma":
+                kwargs["where_document"] = {"$contains": search_string} if search_string else None
+            results = self._vector_db.retrieve_docs(
+                queries=[problem],
+                n_results=n_results,
+                collection_name=self._collection_name,
+                distance_threshold=self._distance_threshold,
+                **kwargs,
+            )
+            self._search_string = search_string
+            self._results = results
+            print("VectorDB returns doc_ids: ", [[r[0]["id"] for r in rr] for rr in results])
+            return
+
         if not self._collection or not self._get_or_create:
             print("Trying to create collection.")
             self._client = create_vector_db_from_dir(
@@ -460,9 +597,13 @@ class RetrieveUserProxyAgent(UserProxyAgent):
             embedding_model=self._embedding_model,
             embedding_function=self._embedding_function,
         )
+        results["contents"] = results.pop("documents")
+        results = chroma_results_to_query_results(results, "distances")
+        results = filter_results_by_distance(results, self._distance_threshold)

         self._search_string = search_string
         self._results = results
-        print("doc_ids: ", results["ids"])
+        print("doc_ids: ", [[r[0]["id"] for r in rr] for rr in results])

     @staticmethod
     def message_generator(sender, recipient, context):
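On the legacy chromadb path, results are first converted and then distance-filtered. A minimal sketch of the filtering semantics documented earlier ("only distance smaller than it will be returned. Will be ignored if < 0"); `filter_by_distance` is a local stand-in for illustration, not the library function:

```python
def filter_by_distance(results, threshold: float):
    if threshold < 0:          # documented: negative threshold disables filtering
        return results
    return [[(doc, d) for doc, d in per_query if d < threshold] for per_query in results]

results = [[({"id": "a1", "content": "..."}, 0.2), ({"id": "b2", "content": "..."}, 0.9)]]
print(filter_by_distance(results, 0.7))  # keeps only the 0.2-distance hit
```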
autogen/agentchat/contrib/vectordb/chromadb.py

@@ -24,7 +24,7 @@ class ChromaVectorDB(VectorDB):
     """

     def __init__(
-        self, *, client=None, path: str = None, embedding_function: Callable = None, metadata: dict = None, **kwargs
+        self, *, client=None, path: str = "tmp/db", embedding_function: Callable = None, metadata: dict = None, **kwargs
     ) -> None:
         """
         Initialize the vector database.
@@ -32,7 +32,7 @@ class ChromaVectorDB(VectorDB):
         Args:
             client: chromadb.Client | The client object of the vector database. Default is None.
                 If provided, it will use the client object directly and ignore other arguments.
-            path: str | The path to the vector database. Default is None.
+            path: str | The path to the vector database. Default is `tmp/db`. The default was `None` for version <=0.2.24.
             embedding_function: Callable | The embedding function used to generate the vector representation
                 of the documents. Default is None, SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") will be used.
             metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
autogen/agentchat/contrib/vectordb/utils.py

@@ -25,6 +25,9 @@ class ColoredLogger(logging.Logger):
     def critical(self, msg, *args, color="red", **kwargs):
         super().critical(colored(msg, color), *args, **kwargs)

+    def fatal(self, msg, *args, color="red", **kwargs):
+        super().fatal(colored(msg, color), *args, **kwargs)
+

 def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger:
     logger = ColoredLogger(name, level)
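Both agents now obtain their logger through this helper, which is why the qdrant file's `logging.fatal` call became `logger.fatal` above. A usage sketch; the `color` keyword is specific to `ColoredLogger`, not the stdlib `Logger`:

```python
from autogen.agentchat.contrib.vectordb.utils import get_logger

logger = get_logger(__name__)
logger.info("Using existing collection.")        # default coloring
logger.fatal("Failed to import qdrant_client.")  # red by default, per the new method
```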
autogen/retrieve_utils.py

@@ -1,4 +1,5 @@
 import glob
+import hashlib
 import os
 import re
 from typing import Callable, List, Tuple, Union
@@ -156,7 +157,7 @@ def split_files_to_chunks(
     chunk_mode: str = "multi_lines",
     must_break_at_empty_line: bool = True,
     custom_text_split_function: Callable = None,
-):
+) -> Tuple[List[str], List[dict]]:
     """Split a list of files into chunks of max_tokens."""

     chunks = []
@@ -275,15 +276,22 @@ def parse_html_to_markdown(html: str, url: str = None) -> str:
     return webpage_text


+def _generate_file_name_from_url(url: str, max_length=255) -> str:
+    url_bytes = url.encode("utf-8")
+    hash = hashlib.blake2b(url_bytes).hexdigest()
+    parsed_url = urlparse(url)
+    file_name = os.path.basename(url)
+    file_name = f"{parsed_url.netloc}_{file_name}_{hash[:min(8, max_length-len(parsed_url.netloc)-len(file_name)-1)]}"
+    return file_name
+
+
 def get_file_from_url(url: str, save_path: str = None) -> Tuple[str, str]:
     """Download a file from a URL."""
     if save_path is None:
         save_path = "tmp/chromadb"
     os.makedirs(save_path, exist_ok=True)
     if os.path.isdir(save_path):
-        filename = os.path.basename(url)
-        if filename == "":  # "www.example.com/"
-            filename = url.split("/")[-2]
+        filename = _generate_file_name_from_url(url)
         save_path = os.path.join(save_path, filename)
     else:
         os.makedirs(os.path.dirname(save_path), exist_ok=True)
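The new helper fixes collisions where files downloaded from different URLs shared a basename. A worked example; the function body is reproduced from the hunk above so it runs standalone, and `urlparse` is assumed to be imported from `urllib.parse` in the module:

```python
import hashlib
import os
from urllib.parse import urlparse

def _generate_file_name_from_url(url: str, max_length=255) -> str:
    url_bytes = url.encode("utf-8")
    hash = hashlib.blake2b(url_bytes).hexdigest()
    parsed_url = urlparse(url)
    file_name = os.path.basename(url)
    file_name = f"{parsed_url.netloc}_{file_name}_{hash[:min(8, max_length-len(parsed_url.netloc)-len(file_name)-1)]}"
    return file_name

print(_generate_file_name_from_url("https://example.com/docs/intro.md"))
# -> 'example.com_intro.md_<8 hex chars>': distinct URLs with the same basename
#    no longer collide on disk.
```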
@@ -327,7 +335,7 @@ def create_vector_db_from_dir(
     dir_path: Union[str, List[str]],
     max_tokens: int = 4000,
     client: API = None,
-    db_path: str = "/tmp/chromadb.db",
+    db_path: str = "tmp/chromadb.db",
     collection_name: str = "all-my-documents",
     get_or_create: bool = False,
     chunk_mode: str = "multi_lines",
@@ -347,7 +355,7 @@ def create_vector_db_from_dir(
         dir_path (Union[str, List[str]]): the path to the directory, file, url or a list of them.
         max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000.
         client (Optional, API): the chromadb client. Default is None.
-        db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db".
+        db_path (Optional, str): the path to the chromadb. Default is "tmp/chromadb.db". The default was `/tmp/chromadb.db` for version <=0.2.24.
         collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
         get_or_create (Optional, bool): Whether to get or create the collection. Default is False. If True, the collection
             will be returned if it already exists. Will raise ValueError if the collection already exists and get_or_create is False.
@@ -420,7 +428,7 @@ def query_vector_db(
     query_texts: List[str],
     n_results: int = 10,
     client: API = None,
-    db_path: str = "/tmp/chromadb.db",
+    db_path: str = "tmp/chromadb.db",
     collection_name: str = "all-my-documents",
     search_string: str = "",
     embedding_model: str = "all-MiniLM-L6-v2",
|
|||
query_texts (List[str]): the list of strings which will be used to query the vector db.
|
||||
n_results (Optional, int): the number of results to return. Default is 10.
|
||||
client (Optional, API): the chromadb compatible client. Default is None, a chromadb client will be used.
|
||||
db_path (Optional, str): the path to the vector db. Default is "/tmp/chromadb.db".
|
||||
db_path (Optional, str): the path to the vector db. Default is "tmp/chromadb.db". The default was `/tmp/chromadb.db` for version <=0.2.24.
|
||||
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
|
||||
search_string (Optional, str): the search string. Only docs that contain an exact match of this string will be retrieved. Default is "".
|
||||
embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue