mirror of https://github.com/microsoft/autogen.git
Add support to custom text spliter (#270)
* Add support to custom text spliter function and a list of files or urls * Add parameter to retrieve_config, add tests * Fix tests * Fix tests
This commit is contained in:
parent
294e006ac9
commit
f2d7553cdc
|
@ -122,6 +122,8 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
|||
- custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
|
||||
The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
|
||||
Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
|
||||
- custom_text_split_function(Optional, Callable): a custom function to split a string into a list of strings.
|
||||
Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
|
||||
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
|
||||
|
||||
Example of overriding retrieve_docs:
|
||||
|
@ -175,6 +177,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
|||
self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False
|
||||
)
|
||||
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
|
||||
self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
|
||||
self._context_max_tokens = self._max_tokens * 0.8
|
||||
self._collection = True if self._docs_path is None else False # whether the collection is created
|
||||
self._ipython = get_ipython()
|
||||
|
@ -364,6 +367,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
|||
embedding_model=self._embedding_model,
|
||||
get_or_create=self._get_or_create,
|
||||
embedding_function=self._embedding_function,
|
||||
custom_text_split_function=self.custom_text_split_function,
|
||||
)
|
||||
self._collection = True
|
||||
self._get_or_create = False
|
||||
|
|
|
@ -180,7 +180,11 @@ def extract_text_from_pdf(file: str) -> str:
|
|||
|
||||
|
||||
def split_files_to_chunks(
|
||||
files: list, max_tokens: int = 4000, chunk_mode: str = "multi_lines", must_break_at_empty_line: bool = True
|
||||
files: list,
|
||||
max_tokens: int = 4000,
|
||||
chunk_mode: str = "multi_lines",
|
||||
must_break_at_empty_line: bool = True,
|
||||
custom_text_split_function: Callable = None,
|
||||
):
|
||||
"""Split a list of files into chunks of max_tokens."""
|
||||
|
||||
|
@ -200,18 +204,33 @@ def split_files_to_chunks(
|
|||
logger.warning(f"No text available in file: {file}")
|
||||
continue # Skip to the next file if no text is available
|
||||
|
||||
chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
|
||||
if custom_text_split_function is not None:
|
||||
chunks += custom_text_split_function(text)
|
||||
else:
|
||||
chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def get_files_from_dir(dir_path: str, types: list = TEXT_FORMATS, recursive: bool = True):
|
||||
def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):
|
||||
"""Return a list of all the files in a given directory."""
|
||||
if len(types) == 0:
|
||||
raise ValueError("types cannot be empty.")
|
||||
types = [t[1:].lower() if t.startswith(".") else t.lower() for t in set(types)]
|
||||
types += [t.upper() for t in types]
|
||||
|
||||
files = []
|
||||
# If the path is a list of files or urls, process and return them
|
||||
if isinstance(dir_path, list):
|
||||
for item in dir_path:
|
||||
if os.path.isfile(item):
|
||||
files.append(item)
|
||||
elif is_url(item):
|
||||
files.append(get_file_from_url(item))
|
||||
else:
|
||||
logger.warning(f"File {item} does not exist. Skipping.")
|
||||
return files
|
||||
|
||||
# If the path is a file, return it
|
||||
if os.path.isfile(dir_path):
|
||||
return [dir_path]
|
||||
|
@ -220,7 +239,6 @@ def get_files_from_dir(dir_path: str, types: list = TEXT_FORMATS, recursive: boo
|
|||
if is_url(dir_path):
|
||||
return [get_file_from_url(dir_path)]
|
||||
|
||||
files = []
|
||||
if os.path.exists(dir_path):
|
||||
for type in types:
|
||||
if recursive:
|
||||
|
@ -265,6 +283,7 @@ def create_vector_db_from_dir(
|
|||
must_break_at_empty_line: bool = True,
|
||||
embedding_model: str = "all-MiniLM-L6-v2",
|
||||
embedding_function: Callable = None,
|
||||
custom_text_split_function: Callable = None,
|
||||
):
|
||||
"""Create a vector db from all the files in a given directory, the directory can also be a single file or a url to
|
||||
a single file. We support chromadb compatible APIs to create the vector db, this function is not required if
|
||||
|
@ -304,7 +323,14 @@ def create_vector_db_from_dir(
|
|||
metadata={"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}, # ip, l2, cosine
|
||||
)
|
||||
|
||||
chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
|
||||
if custom_text_split_function is not None:
|
||||
chunks = split_files_to_chunks(
|
||||
get_files_from_dir(dir_path), custom_text_split_function=custom_text_split_function
|
||||
)
|
||||
else:
|
||||
chunks = split_files_to_chunks(
|
||||
get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line
|
||||
)
|
||||
logger.info(f"Found {len(chunks)} chunks.")
|
||||
# Upsert in batch of 40000 or less if the total number of chunks is less than 40000
|
||||
for i in range(0, len(chunks), min(40000, len(chunks))):
|
||||
|
|
|
@ -74,6 +74,10 @@ class TestRetrieveUtils:
|
|||
def test_get_files_from_dir(self):
|
||||
files = get_files_from_dir(test_dir)
|
||||
assert all(os.path.isfile(file) for file in files)
|
||||
pdf_file_path = os.path.join(test_dir, "example.pdf")
|
||||
txt_file_path = os.path.join(test_dir, "example.txt")
|
||||
files = get_files_from_dir([pdf_file_path, txt_file_path])
|
||||
assert all(os.path.isfile(file) for file in files)
|
||||
|
||||
def test_is_url(self):
|
||||
assert is_url("https://www.example.com")
|
||||
|
@ -164,6 +168,24 @@ class TestRetrieveUtils:
|
|||
ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark")
|
||||
assert ragragproxyagent._results["ids"] == [3, 1, 5]
|
||||
|
||||
def test_custom_text_split_function(self):
|
||||
def custom_text_split_function(text):
|
||||
return [text[: len(text) // 2], text[len(text) // 2 :]]
|
||||
|
||||
db_path = "/tmp/test_retrieve_utils_chromadb.db"
|
||||
client = chromadb.PersistentClient(path=db_path)
|
||||
create_vector_db_from_dir(
|
||||
os.path.join(test_dir, "example.txt"),
|
||||
client=client,
|
||||
collection_name="mytestcollection",
|
||||
custom_text_split_function=custom_text_split_function,
|
||||
)
|
||||
results = query_vector_db(["autogen"], client=client, collection_name="mytestcollection", n_results=1)
|
||||
assert (
|
||||
results.get("documents")[0][0]
|
||||
== "AutoGen is an advanced tool designed to assist developers in harnessing the capabilities\nof Large Language Models (LLMs) for various applications. The primary purpose o"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
|
|
Loading…
Reference in New Issue