Add source to the answer for default prompt (#2289)

* Add source to the answer for default prompt

* Fix qdrant

* Fix tests

* Update docstring

* Fix check files

* Fix qdrant test error
This commit is contained in:
Li Jiang 2024-04-10 08:45:26 +08:00 committed by GitHub
parent 5292024839
commit 5a96dc2c29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 39 additions and 17 deletions

View File

@ -190,12 +190,12 @@ def create_qdrant_from_dir(
client.set_model(embedding_model)
if custom_text_split_function is not None:
chunks = split_files_to_chunks(
chunks, sources = split_files_to_chunks(
get_files_from_dir(dir_path, custom_text_types, recursive),
custom_text_split_function=custom_text_split_function,
)
else:
chunks = split_files_to_chunks(
chunks, sources = split_files_to_chunks(
get_files_from_dir(dir_path, custom_text_types, recursive), max_tokens, chunk_mode, must_break_at_empty_line
)
logger.info(f"Found {len(chunks)} chunks.")
@ -298,5 +298,6 @@ def query_qdrant(
data = {
"ids": [[result.id for result in sublist] for sublist in results],
"documents": [[result.document for result in sublist] for sublist in results],
"metadatas": [[result.metadata for result in sublist] for sublist in results],
}
return data

View File

@ -34,6 +34,10 @@ If user's intent is question answering, you must give as short an answer as poss
User's question is: {input_question}
Context is: {input_context}
The source of the context is: {input_sources}
If you can answer the question, in the end of your answer, add the source of the context in the format of `Sources: source1, source2, ...`.
"""
PROMPT_CODE = """You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the
@ -101,7 +105,8 @@ class RetrieveUserProxyAgent(UserProxyAgent):
following keys:
- `task` (Optional, str) - the task of the retrieve chat. Possible values are
"code", "qa" and "default". System prompt will be different for different tasks.
The default value is `default`, which supports both code and qa.
The default value is `default`, which supports both code and qa, and provides
source information in the end of the response.
- `client` (Optional, chromadb.Client) - the chromadb client. If key not provided, a
default client `chromadb.Client()` will be used. If you want to use other
vector db, extend this class and override the `retrieve_docs` function.
@ -243,6 +248,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
self._current_docs_in_context = [] # the ids of the current context sources
self._search_string = "" # the search string used in the current query
# update the termination message function
self._is_termination_msg = (
@ -290,6 +296,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
doc_contents = ""
self._current_docs_in_context = []
current_tokens = 0
_doc_idx = self._doc_idx
_tmp_retrieve_count = 0
@ -310,6 +317,9 @@ class RetrieveUserProxyAgent(UserProxyAgent):
print(colored(func_print, "green"), flush=True)
current_tokens += _doc_tokens
doc_contents += doc + "\n"
_metadatas = results.get("metadatas")
if isinstance(_metadatas, list) and isinstance(_metadatas[0][idx], dict):
self._current_docs_in_context.append(results["metadatas"][0][idx].get("source", ""))
self._doc_idx = idx
self._doc_ids.append(results["ids"][0][idx])
self._doc_contents.append(doc)
@ -329,7 +339,9 @@ class RetrieveUserProxyAgent(UserProxyAgent):
elif task.upper() == "QA":
message = PROMPT_QA.format(input_question=self.problem, input_context=doc_contents)
elif task.upper() == "DEFAULT":
message = PROMPT_DEFAULT.format(input_question=self.problem, input_context=doc_contents)
message = PROMPT_DEFAULT.format(
input_question=self.problem, input_context=doc_contents, input_sources=self._current_docs_in_context
)
else:
raise NotImplementedError(f"task {task} is not implemented.")
return message

View File

@ -1,7 +1,7 @@
import glob
import os
import re
from typing import Callable, List, Union
from typing import Callable, List, Tuple, Union
from urllib.parse import urlparse
import chromadb
@ -160,8 +160,14 @@ def split_files_to_chunks(
"""Split a list of files into chunks of max_tokens."""
chunks = []
sources = []
for file in files:
if isinstance(file, tuple):
url = file[1]
file = file[0]
else:
url = None
_, file_extension = os.path.splitext(file)
file_extension = file_extension.lower()
@ -179,11 +185,13 @@ def split_files_to_chunks(
continue # Skip to the next file if no text is available
if custom_text_split_function is not None:
chunks += custom_text_split_function(text)
tmp_chunks = custom_text_split_function(text)
else:
chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
tmp_chunks = split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
chunks += tmp_chunks
sources += [{"source": url if url else file}] * len(tmp_chunks)
return chunks
return chunks, sources
def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):
@ -267,7 +275,7 @@ def parse_html_to_markdown(html: str, url: str = None) -> str:
return webpage_text
def get_file_from_url(url: str, save_path: str = None):
def get_file_from_url(url: str, save_path: str = None) -> Tuple[str, str]:
"""Download a file from a URL."""
if save_path is None:
save_path = "tmp/chromadb"
@ -303,7 +311,7 @@ def get_file_from_url(url: str, save_path: str = None):
with open(save_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
return save_path
return save_path, url
def is_url(string: str):
@ -383,12 +391,12 @@ def create_vector_db_from_dir(
length = len(collection.get()["ids"])
if custom_text_split_function is not None:
chunks = split_files_to_chunks(
chunks, sources = split_files_to_chunks(
get_files_from_dir(dir_path, custom_text_types, recursive),
custom_text_split_function=custom_text_split_function,
)
else:
chunks = split_files_to_chunks(
chunks, sources = split_files_to_chunks(
get_files_from_dir(dir_path, custom_text_types, recursive),
max_tokens,
chunk_mode,
@ -401,6 +409,7 @@ def create_vector_db_from_dir(
collection.upsert(
documents=chunks[i:end_idx],
ids=[f"doc_{j+length}" for j in range(i, end_idx)], # unique for each doc
metadatas=sources[i:end_idx],
)
except ValueError as e:
logger.warning(f"{e}")

View File

@ -69,7 +69,7 @@ class TestRetrieveUtils:
def test_split_files_to_chunks(self):
pdf_file_path = os.path.join(test_dir, "example.pdf")
txt_file_path = os.path.join(test_dir, "example.txt")
chunks = split_files_to_chunks([pdf_file_path, txt_file_path])
chunks, _ = split_files_to_chunks([pdf_file_path, txt_file_path])
assert all(
isinstance(chunk, str) and "AutoGen is an advanced tool designed to assist developers" in chunk.strip()
for chunk in chunks
@ -81,7 +81,7 @@ class TestRetrieveUtils:
pdf_file_path = os.path.join(test_dir, "example.pdf")
txt_file_path = os.path.join(test_dir, "example.txt")
files = get_files_from_dir([pdf_file_path, txt_file_path])
assert all(os.path.isfile(file) for file in files)
assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files)
files = get_files_from_dir(
[
pdf_file_path,
@ -91,7 +91,7 @@ class TestRetrieveUtils:
],
recursive=True,
)
assert all(os.path.isfile(file) for file in files)
assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files)
files = get_files_from_dir(
[
pdf_file_path,
@ -102,7 +102,7 @@ class TestRetrieveUtils:
recursive=True,
types=["pdf", "txt"],
)
assert all(os.path.isfile(file) for file in files)
assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files)
assert len(files) == 3
def test_is_url(self):
@ -243,7 +243,7 @@ class TestRetrieveUtils:
pdf_file_path = os.path.join(test_dir, "example.pdf")
txt_file_path = os.path.join(test_dir, "example.txt")
word_file_path = os.path.join(test_dir, "example.docx")
chunks = split_files_to_chunks([pdf_file_path, txt_file_path, word_file_path])
chunks, _ = split_files_to_chunks([pdf_file_path, txt_file_path, word_file_path])
assert all(
isinstance(chunk, str) and "AutoGen is an advanced tool designed to assist developers" in chunk.strip()
for chunk in chunks