mirror of https://github.com/microsoft/autogen.git
Add support for a custom text splitter (#270)

* Add support for a custom text splitter function and a list of files or urls
* Add parameter to retrieve_config, add tests
* Fix tests
* Fix tests
This commit is contained in:
parent 294e006ac9
commit f2d7553cdc
@@ -122,6 +122,8 @@ class RetrieveUserProxyAgent(UserProxyAgent):
                 - custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
                     The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
                     Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
+                - custom_text_split_function(Optional, Callable): a custom function to split a string into a list of strings.
+                    Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
             **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).

         Example of overriding retrieve_docs:
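
To make the new docstring entry concrete, here is a minimal usage sketch (not part of the diff): the import path is autogen's contrib package, while the agent name, task, file paths, and the splitter function are illustrative assumptions.

from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent


def paragraph_split(text):
    # Contract from the docstring: take a string, return a list of strings.
    return [part for part in text.split("\n\n") if part.strip()]


ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",  # illustrative name
    human_input_mode="NEVER",
    retrieve_config={
        "task": "qa",
        # after this change, docs_path may also be a list of files or urls
        "docs_path": ["./docs/overview.txt", "https://example.com/guide.md"],
        "custom_text_split_function": paragraph_split,
    },
)
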
@@ -175,6 +177,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
             self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False
         )
         self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
+        self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
         self._context_max_tokens = self._max_tokens * 0.8
         self._collection = True if self._docs_path is None else False  # whether the collection is created
         self._ipython = get_ipython()
@@ -364,6 +367,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
                 embedding_model=self._embedding_model,
                 get_or_create=self._get_or_create,
                 embedding_function=self._embedding_function,
+                custom_text_split_function=self.custom_text_split_function,
             )
             self._collection = True
             self._get_or_create = False
@@ -180,7 +180,11 @@ def extract_text_from_pdf(file: str) -> str:


 def split_files_to_chunks(
-    files: list, max_tokens: int = 4000, chunk_mode: str = "multi_lines", must_break_at_empty_line: bool = True
+    files: list,
+    max_tokens: int = 4000,
+    chunk_mode: str = "multi_lines",
+    must_break_at_empty_line: bool = True,
+    custom_text_split_function: Callable = None,
 ):
     """Split a list of files into chunks of max_tokens."""

@@ -200,18 +204,33 @@ def split_files_to_chunks(
             logger.warning(f"No text available in file: {file}")
             continue  # Skip to the next file if no text is available

-        chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
+        if custom_text_split_function is not None:
+            chunks += custom_text_split_function(text)
+        else:
+            chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)

     return chunks


-def get_files_from_dir(dir_path: str, types: list = TEXT_FORMATS, recursive: bool = True):
+def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):
     """Return a list of all the files in a given directory."""
     if len(types) == 0:
         raise ValueError("types cannot be empty.")
     types = [t[1:].lower() if t.startswith(".") else t.lower() for t in set(types)]
     types += [t.upper() for t in types]

+    files = []
+    # If the path is a list of files or urls, process and return them
+    if isinstance(dir_path, list):
+        for item in dir_path:
+            if os.path.isfile(item):
+                files.append(item)
+            elif is_url(item):
+                files.append(get_file_from_url(item))
+            else:
+                logger.warning(f"File {item} does not exist. Skipping.")
+        return files
+
     # If the path is a file, return it
     if os.path.isfile(dir_path):
         return [dir_path]
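
As a hedged illustration of the dispatch added above, a direct call to the updated utility; the file name and the fixed-width splitter are made up for the example, while the keyword argument is the one introduced in this commit.

from autogen.retrieve_utils import split_files_to_chunks


def split_every_500_chars(text):
    # A custom splitter takes a string and returns a list of strings.
    return [text[i : i + 500] for i in range(0, len(text), 500)]


# When a custom splitter is supplied, max_tokens, chunk_mode and must_break_at_empty_line
# are not consulted; chunk sizing is entirely the splitter's responsibility.
chunks = split_files_to_chunks(["./docs/overview.txt"], custom_text_split_function=split_every_500_chars)
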
@@ -220,7 +239,6 @@ def get_files_from_dir(dir_path: str, types: list = TEXT_FORMATS, recursive: boo
     if is_url(dir_path):
         return [get_file_from_url(dir_path)]

-    files = []
     if os.path.exists(dir_path):
         for type in types:
             if recursive:
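
With the Union[str, List[str]] signature above, get_files_from_dir also accepts a list that mixes local paths and urls; a small sketch with placeholder inputs:

from autogen.retrieve_utils import get_files_from_dir

# A list input short-circuits the directory walk: existing files are kept as-is,
# urls are fetched via get_file_from_url, and anything else is skipped with a warning.
files = get_files_from_dir(["./docs/overview.txt", "https://example.com/guide.md"])
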
@@ -265,6 +283,7 @@ def create_vector_db_from_dir(
     must_break_at_empty_line: bool = True,
     embedding_model: str = "all-MiniLM-L6-v2",
     embedding_function: Callable = None,
+    custom_text_split_function: Callable = None,
 ):
     """Create a vector db from all the files in a given directory, the directory can also be a single file or a url to
     a single file. We support chromadb compatible APIs to create the vector db, this function is not required if
@@ -304,7 +323,14 @@ def create_vector_db_from_dir(
             metadata={"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32},  # ip, l2, cosine
         )

-        chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
+        if custom_text_split_function is not None:
+            chunks = split_files_to_chunks(
+                get_files_from_dir(dir_path), custom_text_split_function=custom_text_split_function
+            )
+        else:
+            chunks = split_files_to_chunks(
+                get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line
+            )
         logger.info(f"Found {len(chunks)} chunks.")
         # Upsert in batch of 40000 or less if the total number of chunks is less than 40000
         for i in range(0, len(chunks), min(40000, len(chunks))):
@@ -74,6 +74,10 @@ class TestRetrieveUtils:
     def test_get_files_from_dir(self):
         files = get_files_from_dir(test_dir)
         assert all(os.path.isfile(file) for file in files)
+        pdf_file_path = os.path.join(test_dir, "example.pdf")
+        txt_file_path = os.path.join(test_dir, "example.txt")
+        files = get_files_from_dir([pdf_file_path, txt_file_path])
+        assert all(os.path.isfile(file) for file in files)

     def test_is_url(self):
         assert is_url("https://www.example.com")
@@ -164,6 +168,24 @@ class TestRetrieveUtils:
         ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark")
         assert ragragproxyagent._results["ids"] == [3, 1, 5]

+    def test_custom_text_split_function(self):
+        def custom_text_split_function(text):
+            return [text[: len(text) // 2], text[len(text) // 2 :]]
+
+        db_path = "/tmp/test_retrieve_utils_chromadb.db"
+        client = chromadb.PersistentClient(path=db_path)
+        create_vector_db_from_dir(
+            os.path.join(test_dir, "example.txt"),
+            client=client,
+            collection_name="mytestcollection",
+            custom_text_split_function=custom_text_split_function,
+        )
+        results = query_vector_db(["autogen"], client=client, collection_name="mytestcollection", n_results=1)
+        assert (
+            results.get("documents")[0][0]
+            == "AutoGen is an advanced tool designed to assist developers in harnessing the capabilities\nof Large Language Models (LLMs) for various applications. The primary purpose o"
+        )
+

 if __name__ == "__main__":
     pytest.main()