Support setting vector_db as a param (#2313)

* Added vectordb base and chromadb

* Remove timer and unused functions

* Added filter by distance

* Added test utils

* Fix format

* Fix type hint of dict

* Rename test

* Add test chromadb

* Fix test no chromadb

* Add coverage

* Don't skip test vectordb utils

* Add types

* Fix tests

* Fix docs build error

* Add types to base

* Update base

* Update utils

* Update chromadb

* Add get_docs_by_ids

* Improve docstring

* Update init params

* Update init vector db

* Add get all docs

* Move chroma_results_to_query_results to utils

* Add init vectordb

* Convert format of results for old version

* Improve type hints

* Update get_context for new query results format

* Fix typo

* Improve init db

* Update default folder

* Update logger

* Update init, add embedding func

* Update distance_threshold

* Fix logger name

* Update qdrant

* Fix init db

* Update notebooks

* Use kwargs to improve readability

* Improve docstring of vectordb, add two attributes

* Add db_config

* Update gitignore

* Update comments

* Add source

* Fix files downloaded from URLs having the same name

* Remove files added by mistake

* Improve docstring

* Update docstring

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* Update docstring

* Update docstring

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Li Jiang 2024-04-17 16:30:05 +08:00 committed by GitHub
parent 4ab8a88487
commit c4e570393d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1388 additions and 1263 deletions

1
.gitignore vendored
View File

@ -183,6 +183,7 @@ test/agentchat/test_agent_scripts/*
# test cache
.cache_test
.db
local_cache
notebook/result.png

View File

@ -1,17 +1,21 @@
import logging
from typing import Callable, Dict, List, Optional
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
from autogen.agentchat.contrib.vectordb.utils import (
chroma_results_to_query_results,
filter_results_by_distance,
get_logger,
)
from autogen.retrieve_utils import TEXT_FORMATS, get_files_from_dir, split_files_to_chunks
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
try:
import fastembed
from qdrant_client import QdrantClient, models
from qdrant_client.fastembed_common import QueryResponse
except ImportError as e:
logging.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'")
logger.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'")
raise e
@ -136,6 +140,11 @@ class QdrantRetrieveUserProxyAgent(RetrieveUserProxyAgent):
collection_name=self._collection_name,
embedding_model=self._embedding_model,
)
results["contents"] = results.pop("documents")
results = chroma_results_to_query_results(results, "distances")
results = filter_results_by_distance(results, self._distance_threshold)
self._search_string = search_string
self._results = results
@ -298,6 +307,7 @@ def query_qdrant(
data = {
"ids": [[result.id for result in sublist] for sublist in results],
"documents": [[result.document for result in sublist] for sublist in results],
"distances": [[result.score for result in sublist] for sublist in results],
"metadatas": [[result.metadata for result in sublist] for sublist in results],
}
return data

View File

@ -1,3 +1,5 @@
import hashlib
import os
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@ -7,15 +9,28 @@ try:
import chromadb
except ImportError:
raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
from autogen import logger
from autogen.agentchat import UserProxyAgent
from autogen.agentchat.agent import Agent
from autogen.agentchat.contrib.vectordb.base import Document, QueryResults, VectorDB, VectorDBFactory
from autogen.agentchat.contrib.vectordb.utils import (
chroma_results_to_query_results,
filter_results_by_distance,
get_logger,
)
from autogen.code_utils import extract_code
from autogen.retrieve_utils import TEXT_FORMATS, create_vector_db_from_dir, query_vector_db
from autogen.retrieve_utils import (
TEXT_FORMATS,
create_vector_db_from_dir,
get_files_from_dir,
query_vector_db,
split_files_to_chunks,
)
from autogen.token_count_utils import count_token
from ...formatting_utils import colored
logger = get_logger(__name__)
PROMPT_DEFAULT = """You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the
context provided by the user. You should follow the following steps to answer a question:
Step 1, you estimate the user's intent based on the question and context. The intent can be a code generation task or
@ -65,6 +80,8 @@ User's question is: {input_question}
Context is: {input_context}
"""
HASH_LENGTH = int(os.environ.get("HASH_LENGTH", 8))
class RetrieveUserProxyAgent(UserProxyAgent):
"""(In preview) The Retrieval-Augmented User Proxy retrieves document chunks based on the embedding
@ -107,9 +124,17 @@ class RetrieveUserProxyAgent(UserProxyAgent):
"code", "qa" and "default". System prompt will be different for different tasks.
The default value is `default`, which supports both code and qa, and provides
source information in the end of the response.
- `vector_db` (Optional, Union[str, VectorDB]) - the vector db for the retrieve chat.
If it's a string, it should be the type of the vector db, such as "chroma"; otherwise,
it should be an instance of the VectorDB protocol. Default is "chroma".
Set `None` to use the deprecated `client`.
- `db_config` (Optional, Dict) - the config for the vector db. Default is `{}`. Please make
sure you understand the config for the vector db you are using, otherwise, leave it as `{}`.
Only valid when `vector_db` is a string.
- `client` (Optional, chromadb.Client) - the chromadb client. If key not provided, a
default client `chromadb.Client()` will be used. If you want to use other
vector db, extend this class and override the `retrieve_docs` function.
**Deprecated**: use `vector_db` instead.
- `docs_path` (Optional, Union[str, List[str]]) - the path to the docs directory. It
can also be the path to a single file, the url to a single file or a list
of directories, files and urls. Default is None, which works only if the
@ -123,8 +148,11 @@ class RetrieveUserProxyAgent(UserProxyAgent):
By default, "extra_docs" is set to false, starting document IDs from zero.
This poses a risk as new documents might overwrite existing ones, potentially
causing unintended loss or alteration of data in the collection.
- `collection_name` (Optional, str) - the name of the collection.
If key not provided, a default name `autogen-docs` will be used.
**Deprecated**: use `new_docs` when using `vector_db` instead of `client`.
- `new_docs` (Optional, bool) - when True, only adds new documents to the collection;
when False, updates existing documents and adds new ones. Default is True.
Document id is used to determine if a document is new or existing. By default, the
id is the hash value of the content.
- `model` (Optional, str) - the model to use for the retrieve chat.
If key not provided, a default model `gpt-4` will be used.
- `chunk_token_size` (Optional, int) - the chunk token size for the retrieve chat.
@ -143,6 +171,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
models can be found at `https://www.sbert.net/docs/pretrained_models.html`.
The default model is a fast model. If you want to use a high performance model,
`all-mpnet-base-v2` is recommended.
**Deprecated**: not needed when using `vector_db` instead of `client`.
- `embedding_function` (Optional, Callable) - the embedding function for creating the
vector db. Default is None, SentenceTransformer with the given `embedding_model`
will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
@ -156,10 +185,14 @@ class RetrieveUserProxyAgent(UserProxyAgent):
`Update Context` will be triggered.
- `update_context` (Optional, bool) - if False, will not apply `Update Context` for
interactive retrieval. Default is True.
- `get_or_create` (Optional, bool) - if True, will create/return a collection for the
retrieve chat. This is the same as that used in chromadb.
Default is False. Will raise ValueError if the collection already exists and
get_or_create is False. Will be set to True if docs_path is None.
- `collection_name` (Optional, str) - the name of the collection.
If key not provided, a default name `autogen-docs` will be used.
- `get_or_create` (Optional, bool) - Whether to get the collection if it exists. Default is True.
- `overwrite` (Optional, bool) - Whether to overwrite the collection if it exists. Default is False.
Case 1. if the collection does not exist, create the collection.
Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
otherwise it raises a ValueError.
- `custom_token_count_function` (Optional, Callable) - a custom function to count the
number of tokens in a string.
The function should take (text:str, model:str) as input and return the
@ -176,6 +209,8 @@ class RetrieveUserProxyAgent(UserProxyAgent):
included files and urls will be chunked regardless of their types.
- `recursive` (Optional, bool) - whether to search documents recursively in the
docs_path. Default is True.
- `distance_threshold` (Optional, float) - the threshold for the distance score, only
distance smaller than it will be returned. Will be ignored if < 0. Default is -1.
`**kwargs` (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
@ -183,6 +218,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
Example of overriding retrieve_docs - If you have set up a customized vector db, and it's
not compatible with chromadb, you can easily plug it in with the code below.
**Deprecated**: Use `vector_db` instead. You can extend VectorDB and pass it to the agent.
```python
class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
def query_vector_db(
@ -215,9 +251,12 @@ class RetrieveUserProxyAgent(UserProxyAgent):
self._retrieve_config = {} if retrieve_config is None else retrieve_config
self._task = self._retrieve_config.get("task", "default")
self._vector_db = self._retrieve_config.get("vector_db", "chroma")
self._db_config = self._retrieve_config.get("db_config", {})
self._client = self._retrieve_config.get("client", chromadb.Client())
self._docs_path = self._retrieve_config.get("docs_path", None)
self._extra_docs = self._retrieve_config.get("extra_docs", False)
self._new_docs = self._retrieve_config.get("new_docs", True)
self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs")
if "docs_path" not in self._retrieve_config:
logger.warning(
@ -236,6 +275,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
self.update_context = self._retrieve_config.get("update_context", True)
self._get_or_create = self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else True
self._overwrite = self._retrieve_config.get("overwrite", False)
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token)
self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
self._custom_text_types = self._retrieve_config.get("custom_text_types", TEXT_FORMATS)
@ -244,18 +284,95 @@ class RetrieveUserProxyAgent(UserProxyAgent):
self._collection = True if self._docs_path is None else False # whether the collection is created
self._ipython = get_ipython()
self._doc_idx = -1 # the index of the current used doc
self._results = {} # the results of the current query
self._results = [] # the results of the current query
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
self._current_docs_in_context = [] # the ids of the current context sources
self._search_string = "" # the search string used in the current query
self._distance_threshold = self._retrieve_config.get("distance_threshold", -1)
# update the termination message function
self._is_termination_msg = (
self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg
)
if isinstance(self._vector_db, str):
if not isinstance(self._db_config, dict):
raise ValueError("`db_config` should be a dictionary.")
if "embedding_function" in self._retrieve_config:
self._db_config["embedding_function"] = self._embedding_function
self._vector_db = VectorDBFactory.create_vector_db(db_type=self._vector_db, **self._db_config)
self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=2)
def _init_db(self):
if not self._vector_db:
return
IS_TO_CHUNK = False # whether to chunk the raw files
if self._new_docs:
IS_TO_CHUNK = True
if not self._docs_path:
try:
self._vector_db.get_collection(self._collection_name)
logger.warning(f"`docs_path` is not provided. Use the existing collection `{self._collection_name}`.")
self._overwrite = False
self._get_or_create = True
IS_TO_CHUNK = False
except ValueError:
raise ValueError(
"`docs_path` is not provided. "
f"The collection `{self._collection_name}` doesn't exist either. "
"Please provide `docs_path` or create the collection first."
)
elif self._get_or_create and not self._overwrite:
try:
self._vector_db.get_collection(self._collection_name)
logger.info(f"Use the existing collection `{self._collection_name}`.", color="green")
except ValueError:
IS_TO_CHUNK = True
else:
IS_TO_CHUNK = True
self._vector_db.active_collection = self._vector_db.create_collection(
self._collection_name, overwrite=self._overwrite, get_or_create=self._get_or_create
)
docs = None
if IS_TO_CHUNK:
if self.custom_text_split_function is not None:
chunks, sources = split_files_to_chunks(
get_files_from_dir(self._docs_path, self._custom_text_types, self._recursive),
custom_text_split_function=self.custom_text_split_function,
)
else:
chunks, sources = split_files_to_chunks(
get_files_from_dir(self._docs_path, self._custom_text_types, self._recursive),
self._max_tokens,
self._chunk_mode,
self._must_break_at_empty_line,
)
logger.info(f"Found {len(chunks)} chunks.")
if self._new_docs:
all_docs_ids = set(
[
doc["id"]
for doc in self._vector_db.get_docs_by_ids(ids=None, collection_name=self._collection_name)
]
)
else:
all_docs_ids = set()
chunk_ids = [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
chunk_ids_set = set(chunk_ids)
chunk_ids_set_idx = [chunk_ids.index(hash_value) for hash_value in chunk_ids_set]
docs = [
Document(id=chunk_ids[idx], content=chunks[idx], metadata=sources[idx])
for idx in chunk_ids_set_idx
if chunk_ids[idx] not in all_docs_ids
]
self._vector_db.insert_docs(docs=docs, collection_name=self._collection_name, upsert=True)
def _is_termination_msg_retrievechat(self, message):
"""Check if a message is a termination message.
For code generation, terminate when no code block is detected. Currently only detect python code blocks.
@ -288,41 +405,42 @@ class RetrieveUserProxyAgent(UserProxyAgent):
def _reset(self, intermediate=False):
self._doc_idx = -1 # the index of the current used doc
self._results = {} # the results of the current query
self._results = [] # the results of the current query
if not intermediate:
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
def _get_context(self, results: QueryResults):
doc_contents = ""
self._current_docs_in_context = []
current_tokens = 0
_doc_idx = self._doc_idx
_tmp_retrieve_count = 0
for idx, doc in enumerate(results["documents"][0]):
for idx, doc in enumerate(results[0]):
doc = doc[0]
if idx <= _doc_idx:
continue
if results["ids"][0][idx] in self._doc_ids:
if doc["id"] in self._doc_ids:
continue
_doc_tokens = self.custom_token_count_function(doc, self._model)
_doc_tokens = self.custom_token_count_function(doc["content"], self._model)
if _doc_tokens > self._context_max_tokens:
func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context."
func_print = f"Skip doc_id {doc['id']} as it is too long to fit in the context."
print(colored(func_print, "green"), flush=True)
self._doc_idx = idx
continue
if current_tokens + _doc_tokens > self._context_max_tokens:
break
func_print = f"Adding doc_id {results['ids'][0][idx]} to context."
func_print = f"Adding content of doc {doc['id']} to context."
print(colored(func_print, "green"), flush=True)
current_tokens += _doc_tokens
doc_contents += doc + "\n"
_metadatas = results.get("metadatas")
if isinstance(_metadatas, list) and isinstance(_metadatas[0][idx], dict):
self._current_docs_in_context.append(results["metadatas"][0][idx].get("source", ""))
doc_contents += doc["content"] + "\n"
_metadata = doc.get("metadata")
if isinstance(_metadata, dict):
self._current_docs_in_context.append(_metadata.get("source", ""))
self._doc_idx = idx
self._doc_ids.append(results["ids"][0][idx])
self._doc_contents.append(doc)
self._doc_ids.append(doc["id"])
self._doc_contents.append(doc["content"])
_tmp_retrieve_count += 1
if _tmp_retrieve_count >= self.n_results:
break
@ -416,21 +534,40 @@ class RetrieveUserProxyAgent(UserProxyAgent):
def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
"""Retrieve docs based on the given problem and assign the results to the class property `_results`.
In case you want to customize the retrieval process, such as using a different vector db whose APIs are not
compatible with chromadb or filter results with metadata, you can override this function. Just keep the current
parameters and add your own parameters with default values, and keep the results in below type.
Type of the results: Dict[str, List[List[Any]]], should have keys "ids" and "documents", "ids" for the ids of
the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer
to `chromadb.api.types.QueryResult` as an example.
ids: List[string]
documents: List[List[string]]
The retrieved docs should be type of `QueryResults` which is a list of tuples containing the document and
the distance.
Args:
problem (str): the problem to be solved.
n_results (int): the number of results to be retrieved. Default is 20.
search_string (str): only docs that contain an exact match of this string will be retrieved. Default is "".
Not used if the vector_db doesn't support it.
Returns:
None.
"""
if isinstance(self._vector_db, VectorDB):
if not self._collection or not self._get_or_create:
print("Trying to create collection.")
self._init_db()
self._collection = True
self._get_or_create = True
kwargs = {}
if hasattr(self._vector_db, "type") and self._vector_db.type == "chroma":
kwargs["where_document"] = {"$contains": search_string} if search_string else None
results = self._vector_db.retrieve_docs(
queries=[problem],
n_results=n_results,
collection_name=self._collection_name,
distance_threshold=self._distance_threshold,
**kwargs,
)
self._search_string = search_string
self._results = results
print("VectorDB returns doc_ids: ", [[r[0]["id"] for r in rr] for rr in results])
return
if not self._collection or not self._get_or_create:
print("Trying to create collection.")
self._client = create_vector_db_from_dir(
@ -460,9 +597,13 @@ class RetrieveUserProxyAgent(UserProxyAgent):
embedding_model=self._embedding_model,
embedding_function=self._embedding_function,
)
results["contents"] = results.pop("documents")
results = chroma_results_to_query_results(results, "distances")
results = filter_results_by_distance(results, self._distance_threshold)
self._search_string = search_string
self._results = results
print("doc_ids: ", results["ids"])
print("doc_ids: ", [[r[0]["id"] for r in rr] for rr in results])
@staticmethod
def message_generator(sender, recipient, context):

View File

@ -24,7 +24,7 @@ class ChromaVectorDB(VectorDB):
"""
def __init__(
self, *, client=None, path: str = None, embedding_function: Callable = None, metadata: dict = None, **kwargs
self, *, client=None, path: str = "tmp/db", embedding_function: Callable = None, metadata: dict = None, **kwargs
) -> None:
"""
Initialize the vector database.
@ -32,7 +32,7 @@ class ChromaVectorDB(VectorDB):
Args:
client: chromadb.Client | The client object of the vector database. Default is None.
If provided, it will use the client object directly and ignore other arguments.
path: str | The path to the vector database. Default is None.
path: str | The path to the vector database. Default is `tmp/db`. The default was `None` for version <=0.2.24.
embedding_function: Callable | The embedding function used to generate the vector representation
of the documents. Default is None, SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") will be used.
metadata: dict | The metadata of the vector database. Default is None. If None, it will use this

View File

@ -25,6 +25,9 @@ class ColoredLogger(logging.Logger):
def critical(self, msg, *args, color="red", **kwargs):
super().critical(colored(msg, color), *args, **kwargs)
def fatal(self, msg, *args, color="red", **kwargs):
super().fatal(colored(msg, color), *args, **kwargs)
def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger:
logger = ColoredLogger(name, level)

View File

@ -1,4 +1,5 @@
import glob
import hashlib
import os
import re
from typing import Callable, List, Tuple, Union
@ -156,7 +157,7 @@ def split_files_to_chunks(
chunk_mode: str = "multi_lines",
must_break_at_empty_line: bool = True,
custom_text_split_function: Callable = None,
):
) -> Tuple[List[str], List[dict]]:
"""Split a list of files into chunks of max_tokens."""
chunks = []
@ -275,15 +276,22 @@ def parse_html_to_markdown(html: str, url: str = None) -> str:
return webpage_text
def _generate_file_name_from_url(url: str, max_length=255) -> str:
url_bytes = url.encode("utf-8")
hash = hashlib.blake2b(url_bytes).hexdigest()
parsed_url = urlparse(url)
file_name = os.path.basename(url)
file_name = f"{parsed_url.netloc}_{file_name}_{hash[:min(8, max_length-len(parsed_url.netloc)-len(file_name)-1)]}"
return file_name
def get_file_from_url(url: str, save_path: str = None) -> Tuple[str, str]:
"""Download a file from a URL."""
if save_path is None:
save_path = "tmp/chromadb"
os.makedirs(save_path, exist_ok=True)
if os.path.isdir(save_path):
filename = os.path.basename(url)
if filename == "": # "www.example.com/"
filename = url.split("/")[-2]
filename = _generate_file_name_from_url(url)
save_path = os.path.join(save_path, filename)
else:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
@ -327,7 +335,7 @@ def create_vector_db_from_dir(
dir_path: Union[str, List[str]],
max_tokens: int = 4000,
client: API = None,
db_path: str = "/tmp/chromadb.db",
db_path: str = "tmp/chromadb.db",
collection_name: str = "all-my-documents",
get_or_create: bool = False,
chunk_mode: str = "multi_lines",
@ -347,7 +355,7 @@ def create_vector_db_from_dir(
dir_path (Union[str, List[str]]): the path to the directory, file, url or a list of them.
max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000.
client (Optional, API): the chromadb client. Default is None.
db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db".
db_path (Optional, str): the path to the chromadb. Default is "tmp/chromadb.db". The default was `/tmp/chromadb.db` for version <=0.2.24.
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
get_or_create (Optional, bool): Whether to get or create the collection. Default is False. If True, the collection
will be returned if it already exists. Will raise ValueError if the collection already exists and get_or_create is False.
@ -420,7 +428,7 @@ def query_vector_db(
query_texts: List[str],
n_results: int = 10,
client: API = None,
db_path: str = "/tmp/chromadb.db",
db_path: str = "tmp/chromadb.db",
collection_name: str = "all-my-documents",
search_string: str = "",
embedding_model: str = "all-MiniLM-L6-v2",
@ -433,7 +441,7 @@ def query_vector_db(
query_texts (List[str]): the list of strings which will be used to query the vector db.
n_results (Optional, int): the number of results to return. Default is 10.
client (Optional, API): the chromadb compatible client. Default is None, a chromadb client will be used.
db_path (Optional, str): the path to the vector db. Default is "/tmp/chromadb.db".
db_path (Optional, str): the path to the vector db. Default is "tmp/chromadb.db". The default was `/tmp/chromadb.db` for version <=0.2.24.
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
search_string (Optional, str): the search string. Only docs that contain an exact match of this string will be retrieved. Default is "".
embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long