Add support to custom text spliter (#270)

* Add support to custom text spliter function and a list of files or urls

* Add parameter to retrieve_config, add tests

* Fix tests

* Fix tests
This commit is contained in:
Li Jiang 2023-10-17 22:53:40 +08:00 committed by GitHub
parent 294e006ac9
commit f2d7553cdc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 57 additions and 5 deletions

View File

@ -122,6 +122,8 @@ class RetrieveUserProxyAgent(UserProxyAgent):
- custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
- custom_text_split_function(Optional, Callable): a custom function to split a string into a list of strings.
Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
Example of overriding retrieve_docs:
@ -175,6 +177,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False
)
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
self._context_max_tokens = self._max_tokens * 0.8
self._collection = True if self._docs_path is None else False # whether the collection is created
self._ipython = get_ipython()
@ -364,6 +367,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
embedding_model=self._embedding_model,
get_or_create=self._get_or_create,
embedding_function=self._embedding_function,
custom_text_split_function=self.custom_text_split_function,
)
self._collection = True
self._get_or_create = False

View File

@ -180,7 +180,11 @@ def extract_text_from_pdf(file: str) -> str:
def split_files_to_chunks(
files: list, max_tokens: int = 4000, chunk_mode: str = "multi_lines", must_break_at_empty_line: bool = True
files: list,
max_tokens: int = 4000,
chunk_mode: str = "multi_lines",
must_break_at_empty_line: bool = True,
custom_text_split_function: Callable = None,
):
"""Split a list of files into chunks of max_tokens."""
@ -200,18 +204,33 @@ def split_files_to_chunks(
logger.warning(f"No text available in file: {file}")
continue # Skip to the next file if no text is available
chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
if custom_text_split_function is not None:
chunks += custom_text_split_function(text)
else:
chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
return chunks
def get_files_from_dir(dir_path: str, types: list = TEXT_FORMATS, recursive: bool = True):
def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):
"""Return a list of all the files in a given directory."""
if len(types) == 0:
raise ValueError("types cannot be empty.")
types = [t[1:].lower() if t.startswith(".") else t.lower() for t in set(types)]
types += [t.upper() for t in types]
files = []
# If the path is a list of files or urls, process and return them
if isinstance(dir_path, list):
for item in dir_path:
if os.path.isfile(item):
files.append(item)
elif is_url(item):
files.append(get_file_from_url(item))
else:
logger.warning(f"File {item} does not exist. Skipping.")
return files
# If the path is a file, return it
if os.path.isfile(dir_path):
return [dir_path]
@ -220,7 +239,6 @@ def get_files_from_dir(dir_path: str, types: list = TEXT_FORMATS, recursive: boo
if is_url(dir_path):
return [get_file_from_url(dir_path)]
files = []
if os.path.exists(dir_path):
for type in types:
if recursive:
@ -265,6 +283,7 @@ def create_vector_db_from_dir(
must_break_at_empty_line: bool = True,
embedding_model: str = "all-MiniLM-L6-v2",
embedding_function: Callable = None,
custom_text_split_function: Callable = None,
):
"""Create a vector db from all the files in a given directory, the directory can also be a single file or a url to
a single file. We support chromadb compatible APIs to create the vector db, this function is not required if
@ -304,7 +323,14 @@ def create_vector_db_from_dir(
metadata={"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}, # ip, l2, cosine
)
chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
if custom_text_split_function is not None:
chunks = split_files_to_chunks(
get_files_from_dir(dir_path), custom_text_split_function=custom_text_split_function
)
else:
chunks = split_files_to_chunks(
get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line
)
logger.info(f"Found {len(chunks)} chunks.")
# Upsert in batch of 40000 or less if the total number of chunks is less than 40000
for i in range(0, len(chunks), min(40000, len(chunks))):

View File

@ -74,6 +74,10 @@ class TestRetrieveUtils:
def test_get_files_from_dir(self):
files = get_files_from_dir(test_dir)
assert all(os.path.isfile(file) for file in files)
pdf_file_path = os.path.join(test_dir, "example.pdf")
txt_file_path = os.path.join(test_dir, "example.txt")
files = get_files_from_dir([pdf_file_path, txt_file_path])
assert all(os.path.isfile(file) for file in files)
def test_is_url(self):
assert is_url("https://www.example.com")
@ -164,6 +168,24 @@ class TestRetrieveUtils:
ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark")
assert ragragproxyagent._results["ids"] == [3, 1, 5]
def test_custom_text_split_function(self):
def custom_text_split_function(text):
return [text[: len(text) // 2], text[len(text) // 2 :]]
db_path = "/tmp/test_retrieve_utils_chromadb.db"
client = chromadb.PersistentClient(path=db_path)
create_vector_db_from_dir(
os.path.join(test_dir, "example.txt"),
client=client,
collection_name="mytestcollection",
custom_text_split_function=custom_text_split_function,
)
results = query_vector_db(["autogen"], client=client, collection_name="mytestcollection", n_results=1)
assert (
results.get("documents")[0][0]
== "AutoGen is an advanced tool designed to assist developers in harnessing the capabilities\nof Large Language Models (LLMs) for various applications. The primary purpose o"
)
if __name__ == "__main__":
pytest.main()