mirror of https://github.com/microsoft/autogen.git
Add source to the answer for default prompt (#2289)
* Add source to the answer for default prompt
* Fix qdrant
* Fix tests
* Update docstring
* Fix check files
* Fix qdrant test error
This commit is contained in:
parent
5292024839
commit
5a96dc2c29
|
@@ -190,12 +190,12 @@ def create_qdrant_from_dir(
|
|||
client.set_model(embedding_model)
|
||||
|
||||
if custom_text_split_function is not None:
|
||||
chunks = split_files_to_chunks(
|
||||
chunks, sources = split_files_to_chunks(
|
||||
get_files_from_dir(dir_path, custom_text_types, recursive),
|
||||
custom_text_split_function=custom_text_split_function,
|
||||
)
|
||||
else:
|
||||
chunks = split_files_to_chunks(
|
||||
chunks, sources = split_files_to_chunks(
|
||||
get_files_from_dir(dir_path, custom_text_types, recursive), max_tokens, chunk_mode, must_break_at_empty_line
|
||||
)
|
||||
logger.info(f"Found {len(chunks)} chunks.")
|
||||
|
@@ -298,5 +298,6 @@ def query_qdrant(
|
|||
data = {
|
||||
"ids": [[result.id for result in sublist] for sublist in results],
|
||||
"documents": [[result.document for result in sublist] for sublist in results],
|
||||
"metadatas": [[result.metadata for result in sublist] for sublist in results],
|
||||
}
|
||||
return data
|
||||
|
|
|
@@ -34,6 +34,10 @@ If user's intent is question answering, you must give as short an answer as poss
|
|||
User's question is: {input_question}
|
||||
|
||||
Context is: {input_context}
|
||||
|
||||
The source of the context is: {input_sources}
|
||||
|
||||
If you can answer the question, in the end of your answer, add the source of the context in the format of `Sources: source1, source2, ...`.
|
||||
"""
|
||||
|
||||
PROMPT_CODE = """You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the
|
||||
|
@@ -101,7 +105,8 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
|||
following keys:
|
||||
- `task` (Optional, str) - the task of the retrieve chat. Possible values are
|
||||
"code", "qa" and "default". System prompt will be different for different tasks.
|
||||
The default value is `default`, which supports both code and qa.
|
||||
The default value is `default`, which supports both code and qa, and provides
|
||||
source information in the end of the response.
|
||||
- `client` (Optional, chromadb.Client) - the chromadb client. If key not provided, a
|
||||
default client `chromadb.Client()` will be used. If you want to use other
|
||||
vector db, extend this class and override the `retrieve_docs` function.
|
||||
|
@@ -243,6 +248,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
|||
self._intermediate_answers = set() # the intermediate answers
|
||||
self._doc_contents = [] # the contents of the current used doc
|
||||
self._doc_ids = [] # the ids of the current used doc
|
||||
self._current_docs_in_context = [] # the ids of the current context sources
|
||||
self._search_string = "" # the search string used in the current query
|
||||
# update the termination message function
|
||||
self._is_termination_msg = (
|
||||
|
@@ -290,6 +296,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
|||
|
||||
def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
|
||||
doc_contents = ""
|
||||
self._current_docs_in_context = []
|
||||
current_tokens = 0
|
||||
_doc_idx = self._doc_idx
|
||||
_tmp_retrieve_count = 0
|
||||
|
@@ -310,6 +317,9 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
|||
print(colored(func_print, "green"), flush=True)
|
||||
current_tokens += _doc_tokens
|
||||
doc_contents += doc + "\n"
|
||||
_metadatas = results.get("metadatas")
|
||||
if isinstance(_metadatas, list) and isinstance(_metadatas[0][idx], dict):
|
||||
self._current_docs_in_context.append(results["metadatas"][0][idx].get("source", ""))
|
||||
self._doc_idx = idx
|
||||
self._doc_ids.append(results["ids"][0][idx])
|
||||
self._doc_contents.append(doc)
|
||||
|
@@ -329,7 +339,9 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
|||
elif task.upper() == "QA":
|
||||
message = PROMPT_QA.format(input_question=self.problem, input_context=doc_contents)
|
||||
elif task.upper() == "DEFAULT":
|
||||
message = PROMPT_DEFAULT.format(input_question=self.problem, input_context=doc_contents)
|
||||
message = PROMPT_DEFAULT.format(
|
||||
input_question=self.problem, input_context=doc_contents, input_sources=self._current_docs_in_context
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"task {task} is not implemented.")
|
||||
return message
|
||||
|
|
|
@@ -1,7 +1,7 @@
|
|||
import glob
|
||||
import os
|
||||
import re
|
||||
from typing import Callable, List, Union
|
||||
from typing import Callable, List, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import chromadb
|
||||
|
@@ -160,8 +160,14 @@ def split_files_to_chunks(
|
|||
"""Split a list of files into chunks of max_tokens."""
|
||||
|
||||
chunks = []
|
||||
sources = []
|
||||
|
||||
for file in files:
|
||||
if isinstance(file, tuple):
|
||||
url = file[1]
|
||||
file = file[0]
|
||||
else:
|
||||
url = None
|
||||
_, file_extension = os.path.splitext(file)
|
||||
file_extension = file_extension.lower()
|
||||
|
||||
|
@@ -179,11 +185,13 @@ def split_files_to_chunks(
|
|||
continue # Skip to the next file if no text is available
|
||||
|
||||
if custom_text_split_function is not None:
|
||||
chunks += custom_text_split_function(text)
|
||||
tmp_chunks = custom_text_split_function(text)
|
||||
else:
|
||||
chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
|
||||
tmp_chunks = split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
|
||||
chunks += tmp_chunks
|
||||
sources += [{"source": url if url else file}] * len(tmp_chunks)
|
||||
|
||||
return chunks
|
||||
return chunks, sources
|
||||
|
||||
|
||||
def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):
|
||||
|
@@ -267,7 +275,7 @@ def parse_html_to_markdown(html: str, url: str = None) -> str:
|
|||
return webpage_text
|
||||
|
||||
|
||||
def get_file_from_url(url: str, save_path: str = None):
|
||||
def get_file_from_url(url: str, save_path: str = None) -> Tuple[str, str]:
|
||||
"""Download a file from a URL."""
|
||||
if save_path is None:
|
||||
save_path = "tmp/chromadb"
|
||||
|
@@ -303,7 +311,7 @@ def get_file_from_url(url: str, save_path: str = None):
|
|||
with open(save_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
return save_path
|
||||
return save_path, url
|
||||
|
||||
|
||||
def is_url(string: str):
|
||||
|
@@ -383,12 +391,12 @@ def create_vector_db_from_dir(
|
|||
length = len(collection.get()["ids"])
|
||||
|
||||
if custom_text_split_function is not None:
|
||||
chunks = split_files_to_chunks(
|
||||
chunks, sources = split_files_to_chunks(
|
||||
get_files_from_dir(dir_path, custom_text_types, recursive),
|
||||
custom_text_split_function=custom_text_split_function,
|
||||
)
|
||||
else:
|
||||
chunks = split_files_to_chunks(
|
||||
chunks, sources = split_files_to_chunks(
|
||||
get_files_from_dir(dir_path, custom_text_types, recursive),
|
||||
max_tokens,
|
||||
chunk_mode,
|
||||
|
@@ -401,6 +409,7 @@ def create_vector_db_from_dir(
|
|||
collection.upsert(
|
||||
documents=chunks[i:end_idx],
|
||||
ids=[f"doc_{j+length}" for j in range(i, end_idx)], # unique for each doc
|
||||
metadatas=sources[i:end_idx],
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"{e}")
|
||||
|
|
|
@@ -69,7 +69,7 @@ class TestRetrieveUtils:
|
|||
def test_split_files_to_chunks(self):
|
||||
pdf_file_path = os.path.join(test_dir, "example.pdf")
|
||||
txt_file_path = os.path.join(test_dir, "example.txt")
|
||||
chunks = split_files_to_chunks([pdf_file_path, txt_file_path])
|
||||
chunks, _ = split_files_to_chunks([pdf_file_path, txt_file_path])
|
||||
assert all(
|
||||
isinstance(chunk, str) and "AutoGen is an advanced tool designed to assist developers" in chunk.strip()
|
||||
for chunk in chunks
|
||||
|
@@ -81,7 +81,7 @@ class TestRetrieveUtils:
|
|||
pdf_file_path = os.path.join(test_dir, "example.pdf")
|
||||
txt_file_path = os.path.join(test_dir, "example.txt")
|
||||
files = get_files_from_dir([pdf_file_path, txt_file_path])
|
||||
assert all(os.path.isfile(file) for file in files)
|
||||
assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files)
|
||||
files = get_files_from_dir(
|
||||
[
|
||||
pdf_file_path,
|
||||
|
@@ -91,7 +91,7 @@ class TestRetrieveUtils:
|
|||
],
|
||||
recursive=True,
|
||||
)
|
||||
assert all(os.path.isfile(file) for file in files)
|
||||
assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files)
|
||||
files = get_files_from_dir(
|
||||
[
|
||||
pdf_file_path,
|
||||
|
@@ -102,7 +102,7 @@ class TestRetrieveUtils:
|
|||
recursive=True,
|
||||
types=["pdf", "txt"],
|
||||
)
|
||||
assert all(os.path.isfile(file) for file in files)
|
||||
assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files)
|
||||
assert len(files) == 3
|
||||
|
||||
def test_is_url(self):
|
||||
|
@@ -243,7 +243,7 @@ class TestRetrieveUtils:
|
|||
pdf_file_path = os.path.join(test_dir, "example.pdf")
|
||||
txt_file_path = os.path.join(test_dir, "example.txt")
|
||||
word_file_path = os.path.join(test_dir, "example.docx")
|
||||
chunks = split_files_to_chunks([pdf_file_path, txt_file_path, word_file_path])
|
||||
chunks, _ = split_files_to_chunks([pdf_file_path, txt_file_path, word_file_path])
|
||||
assert all(
|
||||
isinstance(chunk, str) and "AutoGen is an advanced tool designed to assist developers" in chunk.strip()
|
||||
for chunk in chunks
|
||||
|
|
Loading…
Reference in New Issue