From c4e570393db9d2c2d3058e271ca2e46473bd8074 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Wed, 17 Apr 2024 16:30:05 +0800 Subject: [PATCH] Support setting vector_db as a param (#2313) * Added vectordb base and chromadb * Remove timer and unused functions * Added filter by distance * Added test utils * Fix format * Fix type hint of dict * Rename test * Add test chromadb * Fix test no chromadb * Add coverage * Don't skip test vectordb utils * Add types * Fix tests * Fix docs build error * Add types to base * Update base * Update utils * Update chromadb * Add get_docs_by_ids * Improve docstring * Update init params * Update init vector db * Add get all docs * Move chroma_results_to_query_results to utils * Add init vectordb * Convert format of results for old version * Improve type hints * Update get_context for new query results format * Fix typo * Improve init db * Update default folder * Update logger * Update init, add embedding func * Update distance_threshold * Fix logger name * Update qdrant * Fix init db * Update notebooks * Use kwargs to improve readability * Improve docstring of vectordb, add two attributes * Add db_config * Update gitignore * Update comments * Add source * Fix file downloaded from urls have the same name * Remove files added by mistake * Improve docstring * Update docstring Co-authored-by: Chi Wang * Update docstring * Update docstring --------- Co-authored-by: Chi Wang --- .gitignore | 1 + .../qdrant_retrieve_user_proxy_agent.py | 16 +- .../contrib/retrieve_user_proxy_agent.py | 205 +++- .../agentchat/contrib/vectordb/chromadb.py | 4 +- autogen/agentchat/contrib/vectordb/utils.py | 3 + autogen/retrieve_utils.py | 24 +- notebook/agentchat_RetrieveChat.ipynb | 694 +++++-------- notebook/agentchat_groupchat_RAG.ipynb | 953 +++++++++++------- notebook/agentchat_qdrant_RetrieveChat.ipynb | 751 +++++++------- 9 files changed, 1388 insertions(+), 1263 deletions(-) diff --git a/.gitignore b/.gitignore index e5e6ff013d..93a18c4943 100644 --- a/.gitignore +++ b/.gitignore @@ -183,6 +183,7 @@ test/agentchat/test_agent_scripts/* # test cache .cache_test .db +local_cache notebook/result.png diff --git a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py index c68ce809d8..1ece138963 100644 --- a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py @@ -1,17 +1,21 @@ -import logging from typing import Callable, Dict, List, Optional from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent +from autogen.agentchat.contrib.vectordb.utils import ( + chroma_results_to_query_results, + filter_results_by_distance, + get_logger, +) from autogen.retrieve_utils import TEXT_FORMATS, get_files_from_dir, split_files_to_chunks -logger = logging.getLogger(__name__) +logger = get_logger(__name__) try: import fastembed from qdrant_client import QdrantClient, models from qdrant_client.fastembed_common import QueryResponse except ImportError as e: - logging.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'") + logger.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'") raise e @@ -136,6 +140,11 @@ class QdrantRetrieveUserProxyAgent(RetrieveUserProxyAgent): collection_name=self._collection_name, embedding_model=self._embedding_model, ) + results["contents"] = results.pop("documents") + results = chroma_results_to_query_results(results, "distances") + results = filter_results_by_distance(results, self._distance_threshold) + + self._search_string = search_string self._results = results @@ -298,6 +307,7 @@ def query_qdrant( data = { "ids": [[result.id for result in sublist] for sublist in results], "documents": [[result.document for result in sublist] for sublist in results], + "distances": [[result.score for result in sublist] for sublist in results], "metadatas": [[result.metadata for result in sublist] for sublist in results], } return data diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 34dbe28d09..476c7c0739 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -1,3 +1,5 @@ +import hashlib +import os import re from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -7,15 +9,28 @@ try: import chromadb except ImportError: raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`") -from autogen import logger from autogen.agentchat import UserProxyAgent from autogen.agentchat.agent import Agent +from autogen.agentchat.contrib.vectordb.base import Document, QueryResults, VectorDB, VectorDBFactory +from autogen.agentchat.contrib.vectordb.utils import ( + chroma_results_to_query_results, + filter_results_by_distance, + get_logger, +) from autogen.code_utils import extract_code -from autogen.retrieve_utils import TEXT_FORMATS, create_vector_db_from_dir, query_vector_db +from autogen.retrieve_utils import ( + TEXT_FORMATS, + create_vector_db_from_dir, + get_files_from_dir, + query_vector_db, + split_files_to_chunks, +) from autogen.token_count_utils import count_token from ...formatting_utils import colored +logger = get_logger(__name__) + PROMPT_DEFAULT = """You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the context provided by the user. You should follow the following steps to answer a question: Step 1, you estimate the user's intent based on the question and context. The intent can be a code generation task or @@ -65,6 +80,8 @@ User's question is: {input_question} Context is: {input_context} """ +HASH_LENGTH = int(os.environ.get("HASH_LENGTH", 8)) + class RetrieveUserProxyAgent(UserProxyAgent): """(In preview) The Retrieval-Augmented User Proxy retrieves document chunks based on the embedding @@ -107,9 +124,17 @@ class RetrieveUserProxyAgent(UserProxyAgent): "code", "qa" and "default". System prompt will be different for different tasks. The default value is `default`, which supports both code and qa, and provides source information in the end of the response. + - `vector_db` (Optional, Union[str, VectorDB]) - the vector db for the retrieve chat. + If it's a string, it should be the type of the vector db, such as "chroma"; otherwise, + it should be an instance of the VectorDB protocol. Default is "chroma". + Set `None` to use the deprecated `client`. + - `db_config` (Optional, Dict) - the config for the vector db. Default is `{}`. Please make + sure you understand the config for the vector db you are using, otherwise, leave it as `{}`. + Only valid when `vector_db` is a string. - `client` (Optional, chromadb.Client) - the chromadb client. If key not provided, a default client `chromadb.Client()` will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function. + **Deprecated**: use `vector_db` instead. - `docs_path` (Optional, Union[str, List[str]]) - the path to the docs directory. It can also be the path to a single file, the url to a single file or a list of directories, files and urls. Default is None, which works only if the @@ -123,8 +148,11 @@ class RetrieveUserProxyAgent(UserProxyAgent): By default, "extra_docs" is set to false, starting document IDs from zero. This poses a risk as new documents might overwrite existing ones, potentially causing unintended loss or alteration of data in the collection. - - `collection_name` (Optional, str) - the name of the collection. - If key not provided, a default name `autogen-docs` will be used. + **Deprecated**: use `new_docs` when use `vector_db` instead of `client`. + - `new_docs` (Optional, bool) - when True, only adds new documents to the collection; + when False, updates existing documents and adds new ones. Default is True. + Document id is used to determine if a document is new or existing. By default, the + id is the hash value of the content. - `model` (Optional, str) - the model to use for the retrieve chat. If key not provided, a default model `gpt-4` will be used. - `chunk_token_size` (Optional, int) - the chunk token size for the retrieve chat. @@ -143,6 +171,7 @@ class RetrieveUserProxyAgent(UserProxyAgent): models can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended. + **Deprecated**: no need when use `vector_db` instead of `client`. - `embedding_function` (Optional, Callable) - the embedding function for creating the vector db. Default is None, SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding @@ -156,10 +185,14 @@ class RetrieveUserProxyAgent(UserProxyAgent): `Update Context` will be triggered. - `update_context` (Optional, bool) - if False, will not apply `Update Context` for interactive retrieval. Default is True. - - `get_or_create` (Optional, bool) - if True, will create/return a collection for the - retrieve chat. This is the same as that used in chromadb. - Default is False. Will raise ValueError if the collection already exists and - get_or_create is False. Will be set to True if docs_path is None. + - `collection_name` (Optional, str) - the name of the collection. + If key not provided, a default name `autogen-docs` will be used. + - `get_or_create` (Optional, bool) - Whether to get the collection if it exists. Default is True. + - `overwrite` (Optional, bool) - Whether to overwrite the collection if it exists. Default is False. + Case 1. if the collection does not exist, create the collection. + Case 2. the collection exists, if overwrite is True, it will overwrite the collection. + Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection, + otherwise it raise a ValueError. - `custom_token_count_function` (Optional, Callable) - a custom function to count the number of tokens in a string. The function should take (text:str, model:str) as input and return the @@ -176,6 +209,8 @@ class RetrieveUserProxyAgent(UserProxyAgent): included files and urls will be chunked regardless of their types. - `recursive` (Optional, bool) - whether to search documents recursively in the docs_path. Default is True. + - `distance_threshold` (Optional, float) - the threshold for the distance score, only + distance smaller than it will be returned. Will be ignored if < 0. Default is -1. `**kwargs` (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__). @@ -183,6 +218,7 @@ class RetrieveUserProxyAgent(UserProxyAgent): Example of overriding retrieve_docs - If you have set up a customized vector db, and it's not compatible with chromadb, you can easily plug in it with below code. + **Deprecated**: Use `vector_db` instead. You can extend VectorDB and pass it to the agent. ```python class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent): def query_vector_db( @@ -215,9 +251,12 @@ class RetrieveUserProxyAgent(UserProxyAgent): self._retrieve_config = {} if retrieve_config is None else retrieve_config self._task = self._retrieve_config.get("task", "default") + self._vector_db = self._retrieve_config.get("vector_db", "chroma") + self._db_config = self._retrieve_config.get("db_config", {}) self._client = self._retrieve_config.get("client", chromadb.Client()) self._docs_path = self._retrieve_config.get("docs_path", None) self._extra_docs = self._retrieve_config.get("extra_docs", False) + self._new_docs = self._retrieve_config.get("new_docs", True) self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs") if "docs_path" not in self._retrieve_config: logger.warning( @@ -236,6 +275,7 @@ class RetrieveUserProxyAgent(UserProxyAgent): self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper() self.update_context = self._retrieve_config.get("update_context", True) self._get_or_create = self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else True + self._overwrite = self._retrieve_config.get("overwrite", False) self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token) self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None) self._custom_text_types = self._retrieve_config.get("custom_text_types", TEXT_FORMATS) @@ -244,18 +284,95 @@ class RetrieveUserProxyAgent(UserProxyAgent): self._collection = True if self._docs_path is None else False # whether the collection is created self._ipython = get_ipython() self._doc_idx = -1 # the index of the current used doc - self._results = {} # the results of the current query + self._results = [] # the results of the current query self._intermediate_answers = set() # the intermediate answers self._doc_contents = [] # the contents of the current used doc self._doc_ids = [] # the ids of the current used doc self._current_docs_in_context = [] # the ids of the current context sources self._search_string = "" # the search string used in the current query + self._distance_threshold = self._retrieve_config.get("distance_threshold", -1) # update the termination message function self._is_termination_msg = ( self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg ) + if isinstance(self._vector_db, str): + if not isinstance(self._db_config, dict): + raise ValueError("`db_config` should be a dictionary.") + if "embedding_function" in self._retrieve_config: + self._db_config["embedding_function"] = self._embedding_function + self._vector_db = VectorDBFactory.create_vector_db(db_type=self._vector_db, **self._db_config) self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=2) + def _init_db(self): + if not self._vector_db: + return + + IS_TO_CHUNK = False # whether to chunk the raw files + if self._new_docs: + IS_TO_CHUNK = True + if not self._docs_path: + try: + self._vector_db.get_collection(self._collection_name) + logger.warning(f"`docs_path` is not provided. Use the existing collection `{self._collection_name}`.") + self._overwrite = False + self._get_or_create = True + IS_TO_CHUNK = False + except ValueError: + raise ValueError( + "`docs_path` is not provided. " + f"The collection `{self._collection_name}` doesn't exist either. " + "Please provide `docs_path` or create the collection first." + ) + elif self._get_or_create and not self._overwrite: + try: + self._vector_db.get_collection(self._collection_name) + logger.info(f"Use the existing collection `{self._collection_name}`.", color="green") + except ValueError: + IS_TO_CHUNK = True + else: + IS_TO_CHUNK = True + + self._vector_db.active_collection = self._vector_db.create_collection( + self._collection_name, overwrite=self._overwrite, get_or_create=self._get_or_create + ) + + docs = None + if IS_TO_CHUNK: + if self.custom_text_split_function is not None: + chunks, sources = split_files_to_chunks( + get_files_from_dir(self._docs_path, self._custom_text_types, self._recursive), + custom_text_split_function=self.custom_text_split_function, + ) + else: + chunks, sources = split_files_to_chunks( + get_files_from_dir(self._docs_path, self._custom_text_types, self._recursive), + self._max_tokens, + self._chunk_mode, + self._must_break_at_empty_line, + ) + logger.info(f"Found {len(chunks)} chunks.") + + if self._new_docs: + all_docs_ids = set( + [ + doc["id"] + for doc in self._vector_db.get_docs_by_ids(ids=None, collection_name=self._collection_name) + ] + ) + else: + all_docs_ids = set() + + chunk_ids = [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks] + chunk_ids_set = set(chunk_ids) + chunk_ids_set_idx = [chunk_ids.index(hash_value) for hash_value in chunk_ids_set] + docs = [ + Document(id=chunk_ids[idx], content=chunks[idx], metadata=sources[idx]) + for idx in chunk_ids_set_idx + if chunk_ids[idx] not in all_docs_ids + ] + + self._vector_db.insert_docs(docs=docs, collection_name=self._collection_name, upsert=True) + def _is_termination_msg_retrievechat(self, message): """Check if a message is a termination message. For code generation, terminate when no code block is detected. Currently only detect python code blocks. @@ -288,41 +405,42 @@ class RetrieveUserProxyAgent(UserProxyAgent): def _reset(self, intermediate=False): self._doc_idx = -1 # the index of the current used doc - self._results = {} # the results of the current query + self._results = [] # the results of the current query if not intermediate: self._intermediate_answers = set() # the intermediate answers self._doc_contents = [] # the contents of the current used doc self._doc_ids = [] # the ids of the current used doc - def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]): + def _get_context(self, results: QueryResults): doc_contents = "" self._current_docs_in_context = [] current_tokens = 0 _doc_idx = self._doc_idx _tmp_retrieve_count = 0 - for idx, doc in enumerate(results["documents"][0]): + for idx, doc in enumerate(results[0]): + doc = doc[0] if idx <= _doc_idx: continue - if results["ids"][0][idx] in self._doc_ids: + if doc["id"] in self._doc_ids: continue - _doc_tokens = self.custom_token_count_function(doc, self._model) + _doc_tokens = self.custom_token_count_function(doc["content"], self._model) if _doc_tokens > self._context_max_tokens: - func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context." + func_print = f"Skip doc_id {doc['id']} as it is too long to fit in the context." print(colored(func_print, "green"), flush=True) self._doc_idx = idx continue if current_tokens + _doc_tokens > self._context_max_tokens: break - func_print = f"Adding doc_id {results['ids'][0][idx]} to context." + func_print = f"Adding content of doc {doc['id']} to context." print(colored(func_print, "green"), flush=True) current_tokens += _doc_tokens - doc_contents += doc + "\n" - _metadatas = results.get("metadatas") - if isinstance(_metadatas, list) and isinstance(_metadatas[0][idx], dict): - self._current_docs_in_context.append(results["metadatas"][0][idx].get("source", "")) + doc_contents += doc["content"] + "\n" + _metadata = doc.get("metadata") + if isinstance(_metadata, dict): + self._current_docs_in_context.append(_metadata.get("source", "")) self._doc_idx = idx - self._doc_ids.append(results["ids"][0][idx]) - self._doc_contents.append(doc) + self._doc_ids.append(doc["id"]) + self._doc_contents.append(doc["content"]) _tmp_retrieve_count += 1 if _tmp_retrieve_count >= self.n_results: break @@ -416,21 +534,40 @@ class RetrieveUserProxyAgent(UserProxyAgent): def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""): """Retrieve docs based on the given problem and assign the results to the class property `_results`. - In case you want to customize the retrieval process, such as using a different vector db whose APIs are not - compatible with chromadb or filter results with metadata, you can override this function. Just keep the current - parameters and add your own parameters with default values, and keep the results in below type. - - Type of the results: Dict[str, List[List[Any]]], should have keys "ids" and "documents", "ids" for the ids of - the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer - to `chromadb.api.types.QueryResult` as an example. - ids: List[string] - documents: List[List[string]] + The retrieved docs should be type of `QueryResults` which is a list of tuples containing the document and + the distance. Args: problem (str): the problem to be solved. n_results (int): the number of results to be retrieved. Default is 20. search_string (str): only docs that contain an exact match of this string will be retrieved. Default is "". + Not used if the vector_db doesn't support it. + + Returns: + None. """ + if isinstance(self._vector_db, VectorDB): + if not self._collection or not self._get_or_create: + print("Trying to create collection.") + self._init_db() + self._collection = True + self._get_or_create = True + + kwargs = {} + if hasattr(self._vector_db, "type") and self._vector_db.type == "chroma": + kwargs["where_document"] = {"$contains": search_string} if search_string else None + results = self._vector_db.retrieve_docs( + queries=[problem], + n_results=n_results, + collection_name=self._collection_name, + distance_threshold=self._distance_threshold, + **kwargs, + ) + self._search_string = search_string + self._results = results + print("VectorDB returns doc_ids: ", [[r[0]["id"] for r in rr] for rr in results]) + return + if not self._collection or not self._get_or_create: print("Trying to create collection.") self._client = create_vector_db_from_dir( @@ -460,9 +597,13 @@ class RetrieveUserProxyAgent(UserProxyAgent): embedding_model=self._embedding_model, embedding_function=self._embedding_function, ) + results["contents"] = results.pop("documents") + results = chroma_results_to_query_results(results, "distances") + results = filter_results_by_distance(results, self._distance_threshold) + self._search_string = search_string self._results = results - print("doc_ids: ", results["ids"]) + print("doc_ids: ", [[r[0]["id"] for r in rr] for rr in results]) @staticmethod def message_generator(sender, recipient, context): diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index 6e571d58ab..3f1fbc86a4 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -24,7 +24,7 @@ class ChromaVectorDB(VectorDB): """ def __init__( - self, *, client=None, path: str = None, embedding_function: Callable = None, metadata: dict = None, **kwargs + self, *, client=None, path: str = "tmp/db", embedding_function: Callable = None, metadata: dict = None, **kwargs ) -> None: """ Initialize the vector database. @@ -32,7 +32,7 @@ class ChromaVectorDB(VectorDB): Args: client: chromadb.Client | The client object of the vector database. Default is None. If provided, it will use the client object directly and ignore other arguments. - path: str | The path to the vector database. Default is None. + path: str | The path to the vector database. Default is `tmp/db`. The default was `None` for version <=0.2.24. embedding_function: Callable | The embedding function used to generate the vector representation of the documents. Default is None, SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") will be used. metadata: dict | The metadata of the vector database. Default is None. If None, it will use this diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py index ae1ef12525..3dcf79f1f5 100644 --- a/autogen/agentchat/contrib/vectordb/utils.py +++ b/autogen/agentchat/contrib/vectordb/utils.py @@ -25,6 +25,9 @@ class ColoredLogger(logging.Logger): def critical(self, msg, *args, color="red", **kwargs): super().critical(colored(msg, color), *args, **kwargs) + def fatal(self, msg, *args, color="red", **kwargs): + super().fatal(colored(msg, color), *args, **kwargs) + def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger: logger = ColoredLogger(name, level) diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index e83f8a80f3..9393903ec8 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -1,4 +1,5 @@ import glob +import hashlib import os import re from typing import Callable, List, Tuple, Union @@ -156,7 +157,7 @@ def split_files_to_chunks( chunk_mode: str = "multi_lines", must_break_at_empty_line: bool = True, custom_text_split_function: Callable = None, -): +) -> Tuple[List[str], List[dict]]: """Split a list of files into chunks of max_tokens.""" chunks = [] @@ -275,15 +276,22 @@ def parse_html_to_markdown(html: str, url: str = None) -> str: return webpage_text +def _generate_file_name_from_url(url: str, max_length=255) -> str: + url_bytes = url.encode("utf-8") + hash = hashlib.blake2b(url_bytes).hexdigest() + parsed_url = urlparse(url) + file_name = os.path.basename(url) + file_name = f"{parsed_url.netloc}_{file_name}_{hash[:min(8, max_length-len(parsed_url.netloc)-len(file_name)-1)]}" + return file_name + + def get_file_from_url(url: str, save_path: str = None) -> Tuple[str, str]: """Download a file from a URL.""" if save_path is None: save_path = "tmp/chromadb" os.makedirs(save_path, exist_ok=True) if os.path.isdir(save_path): - filename = os.path.basename(url) - if filename == "": # "www.example.com/" - filename = url.split("/")[-2] + filename = _generate_file_name_from_url(url) save_path = os.path.join(save_path, filename) else: os.makedirs(os.path.dirname(save_path), exist_ok=True) @@ -327,7 +335,7 @@ def create_vector_db_from_dir( dir_path: Union[str, List[str]], max_tokens: int = 4000, client: API = None, - db_path: str = "/tmp/chromadb.db", + db_path: str = "tmp/chromadb.db", collection_name: str = "all-my-documents", get_or_create: bool = False, chunk_mode: str = "multi_lines", @@ -347,7 +355,7 @@ def create_vector_db_from_dir( dir_path (Union[str, List[str]]): the path to the directory, file, url or a list of them. max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000. client (Optional, API): the chromadb client. Default is None. - db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db". + db_path (Optional, str): the path to the chromadb. Default is "tmp/chromadb.db". The default was `/tmp/chromadb.db` for version <=0.2.24. collection_name (Optional, str): the name of the collection. Default is "all-my-documents". get_or_create (Optional, bool): Whether to get or create the collection. Default is False. If True, the collection will be returned if it already exists. Will raise ValueError if the collection already exists and get_or_create is False. @@ -420,7 +428,7 @@ def query_vector_db( query_texts: List[str], n_results: int = 10, client: API = None, - db_path: str = "/tmp/chromadb.db", + db_path: str = "tmp/chromadb.db", collection_name: str = "all-my-documents", search_string: str = "", embedding_model: str = "all-MiniLM-L6-v2", @@ -433,7 +441,7 @@ def query_vector_db( query_texts (List[str]): the list of strings which will be used to query the vector db. n_results (Optional, int): the number of results to return. Default is 10. client (Optional, API): the chromadb compatible client. Default is None, a chromadb client will be used. - db_path (Optional, str): the path to the vector db. Default is "/tmp/chromadb.db". + db_path (Optional, str): the path to the vector db. Default is "tmp/chromadb.db". The default was `/tmp/chromadb.db` for version <=0.2.24. collection_name (Optional, str): the name of the collection. Default is "all-my-documents". search_string (Optional, str): the search string. Only docs that contain an exact match of this string will be retrieved. Default is "". embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if diff --git a/notebook/agentchat_RetrieveChat.ipynb b/notebook/agentchat_RetrieveChat.ipynb index 0ff689a8ec..c0c681350f 100644 --- a/notebook/agentchat_RetrieveChat.ipynb +++ b/notebook/agentchat_RetrieveChat.ipynb @@ -48,14 +48,14 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "models to use: ['gpt-35-turbo']\n" + "models to use: ['gpt-35-turbo', 'gpt-35-turbo-0613']\n" ] } ], @@ -97,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -105,7 +105,7 @@ "output_type": "stream", "text": [ "Accepted file formats for `docs_path`:\n", - "['xml', 'htm', 'msg', 'docx', 'org', 'pptx', 'jsonl', 'txt', 'tsv', 'yml', 'json', 'md', 'pdf', 'xlsx', 'csv', 'html', 'log', 'yaml', 'doc', 'odt', 'rtf', 'ppt', 'epub', 'rst']\n" + "['ppt', 'jsonl', 'csv', 'yaml', 'rst', 'htm', 'pdf', 'tsv', 'doc', 'docx', 'pptx', 'msg', 'yml', 'xml', 'md', 'json', 'txt', 'epub', 'org', 'xlsx', 'log', 'html', 'odt', 'rtf']\n" ] } ], @@ -116,9 +116,18 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", + " torch.utils._pytree._register_pytree_node(\n" + ] + } + ], "source": [ "# 1. create an RetrieveAssistantAgent instance named \"assistant\"\n", "assistant = RetrieveAssistantAgent(\n", @@ -139,7 +148,7 @@ "# `chunk_token_size` is the chunk token size for the retrieve chat. By default, it is set to `max_tokens * 0.6`, here we set it to 2000.\n", "# `custom_text_types` is a list of file types to be processed. Default is `autogen.retrieve_utils.TEXT_FORMATS`.\n", "# This only applies to files under the directories in `docs_path`. Explicitly included files and urls will be chunked regardless of their types.\n", - "# In this example, we set it to [\"mdx\"] to only process markdown files. Since no mdx files are included in the `websit/docs`,\n", + "# In this example, we set it to [\"non-existent-type\"] to only process markdown files. Since no \"non-existent-type\" files are included in the `websit/docs`,\n", "# no files there will be processed. However, the explicitly included urls will still be processed.\n", "ragproxyagent = RetrieveUserProxyAgent(\n", " name=\"ragproxyagent\",\n", @@ -152,12 +161,12 @@ " \"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Research.md\",\n", " os.path.join(os.path.abspath(\"\"), \"..\", \"website\", \"docs\"),\n", " ],\n", - " \"custom_text_types\": [\"mdx\"],\n", + " \"custom_text_types\": [\"non-existent-type\"],\n", " \"chunk_token_size\": 2000,\n", " \"model\": config_list[0][\"model\"],\n", - " \"client\": chromadb.PersistentClient(path=\"/tmp/chromadb\"),\n", - " \"embedding_model\": \"all-mpnet-base-v2\",\n", - " \"get_or_create\": True, # set to False if you don't want to reuse an existing collection, but you'll need to remove the collection manually\n", + " # \"client\": chromadb.PersistentClient(path=\"/tmp/chromadb\"), # deprecated, use \"vector_db\" instead\n", + " \"vector_db\": \"chroma\", # to use the deprecated `client` parameter, set to None and uncomment the line above\n", + " \"overwrite\": False, # set to True if you want to overwrite an existing collection\n", " },\n", " code_execution_config=False, # set to False if you don't want to execute the code\n", ")" @@ -179,14 +188,14 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "INFO:autogen.retrieve_utils:Found 2 chunks.\n" + "2024-04-07 17:30:56,955 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - \u001b[32mUse the existing collection `autogen-docs`.\u001b[0m\n" ] }, { @@ -200,15 +209,16 @@ "name": "stderr", "output_type": "stream", "text": [ - "WARNING:chromadb.segment.impl.vector.local_persistent_hnsw:Number of requested results 20 is greater than number of elements in index 2, updating n_results = 2\n" + "2024-04-07 17:30:59,609 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - Found 2 chunks.\u001b[0m\n", + "Number of requested results 20 is greater than number of elements in index 2, updating n_results = 2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "doc_ids: [['doc_0']]\n", - "\u001b[32mAdding doc_id doc_0 to context.\u001b[0m\n", + "VectorDB returns doc_ids: [['bdfbc921']]\n", + "\u001b[32mAdding content of doc bdfbc921 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", @@ -226,6 +236,7 @@ "Context is: # Integrate - Spark\n", "\n", "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n", + "\n", "- Use Spark ML estimators for AutoML.\n", "- Use Spark to run training in parallel spark jobs.\n", "\n", @@ -240,6 +251,7 @@ "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n", "\n", "This function also accepts optional arguments `index_col` and `default_index_type`.\n", + "\n", "- `index_col` is the column name to use as the index, default is None.\n", "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n", "\n", @@ -248,10 +260,13 @@ "```python\n", "import pandas as pd\n", "from flaml.automl.spark.utils import to_pandas_on_spark\n", + "\n", "# Creating a dictionary\n", - "data = {\"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", - " \"Age_Years\": [20, 15, 10, 7, 25],\n", - " \"Price\": [100000, 200000, 300000, 240000, 120000]}\n", + "data = {\n", + " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", + " \"Age_Years\": [20, 15, 10, 7, 25],\n", + " \"Price\": [100000, 200000, 300000, 240000, 120000],\n", + "}\n", "\n", "# Creating a pandas DataFrame\n", "dataframe = pd.DataFrame(data)\n", @@ -264,8 +279,10 @@ "To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n", "\n", "Here is an example of how to use it:\n", + "\n", "```python\n", "from pyspark.ml.feature import VectorAssembler\n", + "\n", "columns = psdf.columns\n", "feature_cols = [col for col in columns if col != label]\n", "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", @@ -275,10 +292,13 @@ "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n", "\n", "### Estimators\n", + "\n", "#### Model List\n", + "\n", "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n", "\n", "#### Usage\n", + "\n", "First, prepare your data in the required format as described in the previous section.\n", "\n", "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n", @@ -287,6 +307,7 @@ "\n", "```python\n", "import flaml\n", + "\n", "# prepare your data in pandas-on-spark format as we previously mentioned\n", "\n", "automl = flaml.AutoML()\n", @@ -304,24 +325,25 @@ ")\n", "```\n", "\n", - "\n", "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n", "\n", "## Parallel Spark Jobs\n", + "\n", "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n", "\n", "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n", "\n", "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n", "\n", - "\n", "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n", - "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performs parallel tuning.\n", + "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n", "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n", "\n", "An example code snippet for using parallel Spark jobs:\n", + "\n", "```python\n", "import flaml\n", + "\n", "automl_experiment = flaml.AutoML()\n", "automl_settings = {\n", " \"time_budget\": 30,\n", @@ -329,7 +351,7 @@ " \"task\": \"regression\",\n", " \"n_concurrent_trials\": 2,\n", " \"use_spark\": True,\n", - " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", + " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", "}\n", "\n", "automl.fit(\n", @@ -339,51 +361,72 @@ ")\n", "```\n", "\n", - "\n", "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n", "\n", "\n", "\n", - "\n", "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "You can use FLAML's `lgbm_spark` estimator for classification tasks and activate Spark as the parallel backend during training by setting `use_spark` to `True`. Here is an example code snippet:\n", + "To perform a classification task using FLAML and use Spark to do parallel training for 30 seconds and force cancel jobs if the time limit is reached, you can follow these steps:\n", + "\n", + "1. First, convert your data into Spark dataframe format using `to_pandas_on_spark` function from `flaml.automl.spark.utils` module.\n", + "2. Then, format your data for use SparkML models by using `VectorAssembler`.\n", + "3. Define your AutoML settings, including the `metric`, `time_budget`, and `task`.\n", + "4. Use `AutoML` from `flaml` to run AutoML with SparkML models by setting `use_spark` to `true`, and `estimator_list` to a list of spark-based estimators, like `[\"lgbm_spark\"]`.\n", + "5. Set `n_concurrent_trials` to the desired number of parallel jobs and `force_cancel` to `True` to cancel the jobs if the time limit is reached.\n", + "\n", + "Here's an example code snippet for performing classification using FLAML and Spark:\n", "\n", "```python\n", - "import flaml\n", + "import pandas as pd\n", "from flaml.automl.spark.utils import to_pandas_on_spark\n", "from pyspark.ml.feature import VectorAssembler\n", + "import flaml\n", "\n", - "# Assuming you have a Spark DataFrame named 'df' that contains your data\n", - "dataframe = df.toPandas()\n", - "label = \"target\"\n", + "# Creating a dictionary\n", + "data = {\n", + " \"sepal_length\": [5.1, 4.9, 4.7, 4.6, 5.0],\n", + " \"sepal_width\": [3.5, 3.0, 3.2, 3.1, 3.6],\n", + " \"petal_length\": [1.4, 1.4, 1.3, 1.5, 1.4],\n", + " \"petal_width\": [0.2, 0.2, 0.2, 0.2, 0.2],\n", + " \"species\": [\"setosa\", \"setosa\", \"setosa\", \"setosa\", \"setosa\"]\n", + "}\n", + "\n", + "# Creating a pandas DataFrame\n", + "dataframe = pd.DataFrame(data)\n", + "label = \"species\"\n", + "\n", + "# Convert to pandas-on-spark dataframe\n", "psdf = to_pandas_on_spark(dataframe)\n", "\n", + "# Format data for SparkML models\n", "columns = psdf.columns\n", "feature_cols = [col for col in columns if col != label]\n", "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", "\n", - "# configure and run AutoML\n", - "automl = flaml.AutoML()\n", + "# Define AutoML settings\n", "settings = {\n", " \"time_budget\": 30,\n", " \"metric\": \"accuracy\",\n", - " \"estimator_list\": [\"lgbm_spark\"],\n", " \"task\": \"classification\",\n", - " \"n_jobs\": -1, # Use all available CPUs\n", - " \"use_spark\": True, # Use Spark as the parallel backend\n", - " \"force_cancel\": True # Halt Spark jobs that run for longer than the time budget\n", "}\n", + "\n", + "# Use AutoML with SparkML models and parallel jobs\n", + "automl = flaml.AutoML()\n", "automl.fit(\n", " dataframe=psdf,\n", " label=label,\n", + " estimator_list=[\"lgbm_spark\"],\n", + " use_spark=True,\n", + " n_concurrent_trials=2,\n", + " force_cancel=True,\n", " **settings,\n", ")\n", "```\n", "\n", - "Note that you should not use `use_spark` if you are working with Spark data, because SparkML models already run in parallel.\n", + "Note that the above code assumes the data is small enough to train within 30 seconds. If you have a larger dataset, you may need to increase the `time_budget` and adjust the number of parallel jobs accordingly.\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", @@ -403,56 +446,49 @@ "name": "stderr", "output_type": "stream", "text": [ - "WARNING:chromadb.segment.impl.vector.local_persistent_hnsw:Number of requested results 60 is greater than number of elements in index 2, updating n_results = 2\n", - "WARNING:chromadb.segment.impl.vector.local_persistent_hnsw:Number of requested results 100 is greater than number of elements in index 2, updating n_results = 2\n" + "Number of requested results 60 is greater than number of elements in index 2, updating n_results = 2\n", + "Number of requested results 100 is greater than number of elements in index 2, updating n_results = 2\n", + "Number of requested results 140 is greater than number of elements in index 2, updating n_results = 2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "doc_ids: [['doc_0']]\n", - "doc_ids: [['doc_0']]\n" + "VectorDB returns doc_ids: [['bdfbc921']]\n", + "VectorDB returns doc_ids: [['bdfbc921']]\n", + "VectorDB returns doc_ids: [['bdfbc921']]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "WARNING:chromadb.segment.impl.vector.local_persistent_hnsw:Number of requested results 140 is greater than number of elements in index 2, updating n_results = 2\n" + "Number of requested results 180 is greater than number of elements in index 2, updating n_results = 2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "doc_ids: [['doc_0']]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:chromadb.segment.impl.vector.local_persistent_hnsw:Number of requested results 180 is greater than number of elements in index 2, updating n_results = 2\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "doc_ids: [['doc_0']]\n", + "VectorDB returns doc_ids: [['bdfbc921']]\n", "\u001b[32mNo more context, will terminate.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "TERMINATE\n", "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", - "\n", - "TERMINATE\n", - "\n", "--------------------------------------------------------------------------------\n" ] + }, + { + "data": { + "text/plain": [ + "ChatResult(chat_id=None, chat_history=[{'content': 'TERMINATE', 'role': 'assistant'}], summary='', cost=({'total_cost': 0.007691, 'gpt-35-turbo': {'cost': 0.007691, 'prompt_tokens': 4242, 'completion_tokens': 664, 'total_tokens': 4906}}, {'total_cost': 0}), human_input=[])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -485,23 +521,23 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "WARNING:chromadb.segment.impl.vector.local_persistent_hnsw:Number of requested results 20 is greater than number of elements in index 2, updating n_results = 2\n" + "Number of requested results 20 is greater than number of elements in index 2, updating n_results = 2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "doc_ids: [['doc_0', 'doc_1']]\n", - "\u001b[32mAdding doc_id doc_0 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1 to context.\u001b[0m\n", + "VectorDB returns doc_ids: [['7968cf3c', 'bdfbc921']]\n", + "\u001b[32mAdding content of doc 7968cf3c to context.\u001b[0m\n", + "\u001b[32mAdding content of doc bdfbc921 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", @@ -516,9 +552,124 @@ "\n", "User's question is: Who is the author of FLAML?\n", "\n", - "Context is: # Integrate - Spark\n", + "Context is: # Research\n", + "\n", + "For technical details, please check our research publications.\n", + "\n", + "- [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{wang2021flaml,\n", + " title={FLAML: A Fast and Lightweight AutoML Library},\n", + " author={Chi Wang and Qingyun Wu and Markus Weimer and Erkang Zhu},\n", + " year={2021},\n", + " booktitle={MLSys},\n", + "}\n", + "```\n", + "\n", + "- [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{wu2021cfo,\n", + " title={Frugal Optimization for Cost-related Hyperparameters},\n", + " author={Qingyun Wu and Chi Wang and Silu Huang},\n", + " year={2021},\n", + " booktitle={AAAI},\n", + "}\n", + "```\n", + "\n", + "- [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{wang2021blendsearch,\n", + " title={Economical Hyperparameter Optimization With Blended Search Strategy},\n", + " author={Chi Wang and Qingyun Wu and Silu Huang and Amin Saied},\n", + " year={2021},\n", + " booktitle={ICLR},\n", + "}\n", + "```\n", + "\n", + "- [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{liuwang2021hpolm,\n", + " title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},\n", + " author={Susan Xueqing Liu and Chi Wang},\n", + " year={2021},\n", + " booktitle={ACL},\n", + "}\n", + "```\n", + "\n", + "- [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\n", + "\n", + "```bibtex\n", + "@inproceedings{wu2021chacha,\n", + " title={ChaCha for Online AutoML},\n", + " author={Qingyun Wu and Chi Wang and John Langford and Paul Mineiro and Marco Rossi},\n", + " year={2021},\n", + " booktitle={ICML},\n", + "}\n", + "```\n", + "\n", + "- [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\n", + "\n", + "```bibtex\n", + "@inproceedings{wuwang2021fairautoml,\n", + " title={Fair AutoML},\n", + " author={Qingyun Wu and Chi Wang},\n", + " year={2021},\n", + " booktitle={ArXiv preprint arXiv:2111.06495},\n", + "}\n", + "```\n", + "\n", + "- [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\n", + "\n", + "```bibtex\n", + "@inproceedings{kayaliwang2022default,\n", + " title={Mining Robust Default Configurations for Resource-constrained AutoML},\n", + " author={Moe Kayali and Chi Wang},\n", + " year={2022},\n", + " booktitle={ArXiv preprint arXiv:2202.09927},\n", + "}\n", + "```\n", + "\n", + "- [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\n", + "\n", + "```bibtex\n", + "@inproceedings{zhang2023targeted,\n", + " title={Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives},\n", + " author={Shaokun Zhang and Feiran Jia and Chi Wang and Qingyun Wu},\n", + " booktitle={International Conference on Learning Representations},\n", + " year={2023},\n", + " url={https://openreview.net/forum?id=0Ij9_q567Ma},\n", + "}\n", + "```\n", + "\n", + "- [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\n", + "\n", + "```bibtex\n", + "@inproceedings{wang2023EcoOptiGen,\n", + " title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},\n", + " author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},\n", + " year={2023},\n", + " booktitle={ArXiv preprint arXiv:2303.04673},\n", + "}\n", + "```\n", + "\n", + "- [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\n", + "\n", + "```bibtex\n", + "@inproceedings{wu2023empirical,\n", + " title={An Empirical Study on Challenging Math Problem Solving with GPT-4},\n", + " author={Yiran Wu and Feiran Jia and Shaokun Zhang and Hangyu Li and Erkang Zhu and Yue Wang and Yin Tat Lee and Richard Peng and Qingyun Wu and Chi Wang},\n", + " year={2023},\n", + " booktitle={ArXiv preprint arXiv:2306.01337},\n", + "}\n", + "```\n", + "# Integrate - Spark\n", "\n", "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n", + "\n", "- Use Spark ML estimators for AutoML.\n", "- Use Spark to run training in parallel spark jobs.\n", "\n", @@ -533,6 +684,7 @@ "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n", "\n", "This function also accepts optional arguments `index_col` and `default_index_type`.\n", + "\n", "- `index_col` is the column name to use as the index, default is None.\n", "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n", "\n", @@ -541,10 +693,13 @@ "```python\n", "import pandas as pd\n", "from flaml.automl.spark.utils import to_pandas_on_spark\n", + "\n", "# Creating a dictionary\n", - "data = {\"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", - " \"Age_Years\": [20, 15, 10, 7, 25],\n", - " \"Price\": [100000, 200000, 300000, 240000, 120000]}\n", + "data = {\n", + " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", + " \"Age_Years\": [20, 15, 10, 7, 25],\n", + " \"Price\": [100000, 200000, 300000, 240000, 120000],\n", + "}\n", "\n", "# Creating a pandas DataFrame\n", "dataframe = pd.DataFrame(data)\n", @@ -557,8 +712,10 @@ "To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n", "\n", "Here is an example of how to use it:\n", + "\n", "```python\n", "from pyspark.ml.feature import VectorAssembler\n", + "\n", "columns = psdf.columns\n", "feature_cols = [col for col in columns if col != label]\n", "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", @@ -568,10 +725,13 @@ "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n", "\n", "### Estimators\n", + "\n", "#### Model List\n", + "\n", "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n", "\n", "#### Usage\n", + "\n", "First, prepare your data in the required format as described in the previous section.\n", "\n", "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n", @@ -580,6 +740,7 @@ "\n", "```python\n", "import flaml\n", + "\n", "# prepare your data in pandas-on-spark format as we previously mentioned\n", "\n", "automl = flaml.AutoML()\n", @@ -597,24 +758,25 @@ ")\n", "```\n", "\n", - "\n", "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n", "\n", "## Parallel Spark Jobs\n", + "\n", "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n", "\n", "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n", "\n", "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n", "\n", - "\n", "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n", - "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performs parallel tuning.\n", + "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n", "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n", "\n", "An example code snippet for using parallel Spark jobs:\n", + "\n", "```python\n", "import flaml\n", + "\n", "automl_experiment = flaml.AutoML()\n", "automl_settings = {\n", " \"time_budget\": 30,\n", @@ -622,7 +784,7 @@ " \"task\": \"regression\",\n", " \"n_concurrent_trials\": 2,\n", " \"use_spark\": True,\n", - " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", + " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", "}\n", "\n", "automl.fit(\n", @@ -632,387 +794,27 @@ ")\n", "```\n", "\n", - "\n", "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n", "\n", - "# Research\n", - "\n", - "For technical details, please check our research publications.\n", - "\n", - "* [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2021flaml,\n", - " title={FLAML: A Fast and Lightweight AutoML Library},\n", - " author={Chi Wang and Qingyun Wu and Markus Weimer and Erkang Zhu},\n", - " year={2021},\n", - " booktitle={MLSys},\n", - "}\n", - "```\n", - "\n", - "* [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2021cfo,\n", - " title={Frugal Optimization for Cost-related Hyperparameters},\n", - " author={Qingyun Wu and Chi Wang and Silu Huang},\n", - " year={2021},\n", - " booktitle={AAAI},\n", - "}\n", - "```\n", - "\n", - "* [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2021blendsearch,\n", - " title={Economical Hyperparameter Optimization With Blended Search Strategy},\n", - " author={Chi Wang and Qingyun Wu and Silu Huang and Amin Saied},\n", - " year={2021},\n", - " booktitle={ICLR},\n", - "}\n", - "```\n", - "\n", - "* [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{liuwang2021hpolm,\n", - " title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},\n", - " author={Susan Xueqing Liu and Chi Wang},\n", - " year={2021},\n", - " booktitle={ACL},\n", - "}\n", - "```\n", - "\n", - "* [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2021chacha,\n", - " title={ChaCha for Online AutoML},\n", - " author={Qingyun Wu and Chi Wang and John Langford and Paul Mineiro and Marco Rossi},\n", - " year={2021},\n", - " booktitle={ICML},\n", - "}\n", - "```\n", - "\n", - "* [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\n", - "\n", - "```bibtex\n", - "@inproceedings{wuwang2021fairautoml,\n", - " title={Fair AutoML},\n", - " author={Qingyun Wu and Chi Wang},\n", - " year={2021},\n", - " booktitle={ArXiv preprint arXiv:2111.06495},\n", - "}\n", - "```\n", - "\n", - "* [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\n", - "\n", - "```bibtex\n", - "@inproceedings{kayaliwang2022default,\n", - " title={Mining Robust Default Configurations for Resource-constrained AutoML},\n", - " author={Moe Kayali and Chi Wang},\n", - " year={2022},\n", - " booktitle={ArXiv preprint arXiv:2202.09927},\n", - "}\n", - "```\n", - "\n", - "* [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\n", - "\n", - "```bibtex\n", - "@inproceedings{zhang2023targeted,\n", - " title={Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives},\n", - " author={Shaokun Zhang and Feiran Jia and Chi Wang and Qingyun Wu},\n", - " booktitle={International Conference on Learning Representations},\n", - " year={2023},\n", - " url={https://openreview.net/forum?id=0Ij9_q567Ma},\n", - "}\n", - "```\n", - "\n", - "* [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2023EcoOptiGen,\n", - " title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},\n", - " author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},\n", - " year={2023},\n", - " booktitle={ArXiv preprint arXiv:2303.04673},\n", - "}\n", - "```\n", - "\n", - "* [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2023empirical,\n", - " title={An Empirical Study on Challenging Math Problem Solving with GPT-4},\n", - " author={Yiran Wu and Feiran Jia and Shaokun Zhang and Hangyu Li and Erkang Zhu and Yue Wang and Yin Tat Lee and Richard Peng and Qingyun Wu and Chi Wang},\n", - " year={2023},\n", - " booktitle={ArXiv preprint arXiv:2306.01337},\n", - "}\n", - "```\n", - "\n", - "\n", - "\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[32mAdding doc_id doc_1 to context.\u001b[0m\n", - "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", - "\n", - "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", - "context provided by the user.\n", - "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n", - "For code generation, you must obey the following rules:\n", - "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n", - "Rule 2. You must follow the formats below to write your code:\n", - "```language\n", - "# your code\n", - "```\n", - "\n", - "User's question is: Who is the author of FLAML?\n", - "\n", - "Context is: # Integrate - Spark\n", - "\n", - "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n", - "- Use Spark ML estimators for AutoML.\n", - "- Use Spark to run training in parallel spark jobs.\n", - "\n", - "## Spark ML Estimators\n", - "\n", - "FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n", - "\n", - "### Data\n", - "\n", - "For Spark estimators, AutoML only consumes Spark data. FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\n", - "\n", - "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n", - "\n", - "This function also accepts optional arguments `index_col` and `default_index_type`.\n", - "- `index_col` is the column name to use as the index, default is None.\n", - "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n", - "\n", - "Here is an example code snippet for Spark Data:\n", - "\n", - "```python\n", - "import pandas as pd\n", - "from flaml.automl.spark.utils import to_pandas_on_spark\n", - "# Creating a dictionary\n", - "data = {\"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", - " \"Age_Years\": [20, 15, 10, 7, 25],\n", - " \"Price\": [100000, 200000, 300000, 240000, 120000]}\n", - "\n", - "# Creating a pandas DataFrame\n", - "dataframe = pd.DataFrame(data)\n", - "label = \"Price\"\n", - "\n", - "# Convert to pandas-on-spark dataframe\n", - "psdf = to_pandas_on_spark(dataframe)\n", - "```\n", - "\n", - "To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n", - "\n", - "Here is an example of how to use it:\n", - "```python\n", - "from pyspark.ml.feature import VectorAssembler\n", - "columns = psdf.columns\n", - "feature_cols = [col for col in columns if col != label]\n", - "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", - "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", - "```\n", - "\n", - "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n", - "\n", - "### Estimators\n", - "#### Model List\n", - "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n", - "\n", - "#### Usage\n", - "First, prepare your data in the required format as described in the previous section.\n", - "\n", - "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n", - "\n", - "Here is an example code snippet using SparkML models in AutoML:\n", - "\n", - "```python\n", - "import flaml\n", - "# prepare your data in pandas-on-spark format as we previously mentioned\n", - "\n", - "automl = flaml.AutoML()\n", - "settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"estimator_list\": [\"lgbm_spark\"], # this setting is optional\n", - " \"task\": \"regression\",\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=psdf,\n", - " label=label,\n", - " **settings,\n", - ")\n", - "```\n", - "\n", - "\n", - "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n", - "\n", - "## Parallel Spark Jobs\n", - "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n", - "\n", - "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n", - "\n", - "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n", - "\n", - "\n", - "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n", - "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performs parallel tuning.\n", - "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n", - "\n", - "An example code snippet for using parallel Spark jobs:\n", - "```python\n", - "import flaml\n", - "automl_experiment = flaml.AutoML()\n", - "automl_settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"task\": \"regression\",\n", - " \"n_concurrent_trials\": 2,\n", - " \"use_spark\": True,\n", - " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=dataframe,\n", - " label=label,\n", - " **automl_settings,\n", - ")\n", - "```\n", - "\n", - "\n", - "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n", - "\n", - "# Research\n", - "\n", - "For technical details, please check our research publications.\n", - "\n", - "* [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2021flaml,\n", - " title={FLAML: A Fast and Lightweight AutoML Library},\n", - " author={Chi Wang and Qingyun Wu and Markus Weimer and Erkang Zhu},\n", - " year={2021},\n", - " booktitle={MLSys},\n", - "}\n", - "```\n", - "\n", - "* [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2021cfo,\n", - " title={Frugal Optimization for Cost-related Hyperparameters},\n", - " author={Qingyun Wu and Chi Wang and Silu Huang},\n", - " year={2021},\n", - " booktitle={AAAI},\n", - "}\n", - "```\n", - "\n", - "* [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2021blendsearch,\n", - " title={Economical Hyperparameter Optimization With Blended Search Strategy},\n", - " author={Chi Wang and Qingyun Wu and Silu Huang and Amin Saied},\n", - " year={2021},\n", - " booktitle={ICLR},\n", - "}\n", - "```\n", - "\n", - "* [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{liuwang2021hpolm,\n", - " title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},\n", - " author={Susan Xueqing Liu and Chi Wang},\n", - " year={2021},\n", - " booktitle={ACL},\n", - "}\n", - "```\n", - "\n", - "* [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2021chacha,\n", - " title={ChaCha for Online AutoML},\n", - " author={Qingyun Wu and Chi Wang and John Langford and Paul Mineiro and Marco Rossi},\n", - " year={2021},\n", - " booktitle={ICML},\n", - "}\n", - "```\n", - "\n", - "* [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\n", - "\n", - "```bibtex\n", - "@inproceedings{wuwang2021fairautoml,\n", - " title={Fair AutoML},\n", - " author={Qingyun Wu and Chi Wang},\n", - " year={2021},\n", - " booktitle={ArXiv preprint arXiv:2111.06495},\n", - "}\n", - "```\n", - "\n", - "* [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\n", - "\n", - "```bibtex\n", - "@inproceedings{kayaliwang2022default,\n", - " title={Mining Robust Default Configurations for Resource-constrained AutoML},\n", - " author={Moe Kayali and Chi Wang},\n", - " year={2022},\n", - " booktitle={ArXiv preprint arXiv:2202.09927},\n", - "}\n", - "```\n", - "\n", - "* [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\n", - "\n", - "```bibtex\n", - "@inproceedings{zhang2023targeted,\n", - " title={Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives},\n", - " author={Shaokun Zhang and Feiran Jia and Chi Wang and Qingyun Wu},\n", - " booktitle={International Conference on Learning Representations},\n", - " year={2023},\n", - " url={https://openreview.net/forum?id=0Ij9_q567Ma},\n", - "}\n", - "```\n", - "\n", - "* [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2023EcoOptiGen,\n", - " title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},\n", - " author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},\n", - " year={2023},\n", - " booktitle={ArXiv preprint arXiv:2303.04673},\n", - "}\n", - "```\n", - "\n", - "* [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2023empirical,\n", - " title={An Empirical Study on Challenging Math Problem Solving with GPT-4},\n", - " author={Yiran Wu and Feiran Jia and Shaokun Zhang and Hangyu Li and Erkang Zhu and Yue Wang and Yin Tat Lee and Richard Peng and Qingyun Wu and Chi Wang},\n", - " year={2023},\n", - " booktitle={ArXiv preprint arXiv:2306.01337},\n", - "}\n", - "```\n", - "\n", - "\n", "\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "The authors of FLAML are Chi Wang, Qingyun Wu, Markus Weimer, and Erkang Zhu.\n", + "The author of FLAML is Chi Wang, along with several co-authors for various publications related to FLAML.\n", "\n", "--------------------------------------------------------------------------------\n" ] + }, + { + "data": { + "text/plain": [ + "ChatResult(chat_id=None, chat_history=[{'content': 'You\\'re a retrieve augmented coding assistant. You answer user\\'s questions based on your own knowledge and the\\ncontext provided by the user.\\nIf you can\\'t answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\\nFor code generation, you must obey the following rules:\\nRule 1. You MUST NOT install any packages because all the packages needed are already installed.\\nRule 2. You must follow the formats below to write your code:\\n```language\\n# your code\\n```\\n\\nUser\\'s question is: Who is the author of FLAML?\\n\\nContext is: # Research\\n\\nFor technical details, please check our research publications.\\n\\n- [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\\n\\n```bibtex\\n@inproceedings{wang2021flaml,\\n title={FLAML: A Fast and Lightweight AutoML Library},\\n author={Chi Wang and Qingyun Wu and Markus Weimer and Erkang Zhu},\\n year={2021},\\n booktitle={MLSys},\\n}\\n```\\n\\n- [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\\n\\n```bibtex\\n@inproceedings{wu2021cfo,\\n title={Frugal Optimization for Cost-related Hyperparameters},\\n author={Qingyun Wu and Chi Wang and Silu Huang},\\n year={2021},\\n booktitle={AAAI},\\n}\\n```\\n\\n- [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\\n\\n```bibtex\\n@inproceedings{wang2021blendsearch,\\n title={Economical Hyperparameter Optimization With Blended Search Strategy},\\n author={Chi Wang and Qingyun Wu and Silu Huang and Amin Saied},\\n year={2021},\\n booktitle={ICLR},\\n}\\n```\\n\\n- [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\\n\\n```bibtex\\n@inproceedings{liuwang2021hpolm,\\n title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},\\n author={Susan Xueqing Liu and Chi Wang},\\n year={2021},\\n booktitle={ACL},\\n}\\n```\\n\\n- [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\\n\\n```bibtex\\n@inproceedings{wu2021chacha,\\n title={ChaCha for Online AutoML},\\n author={Qingyun Wu and Chi Wang and John Langford and Paul Mineiro and Marco Rossi},\\n year={2021},\\n booktitle={ICML},\\n}\\n```\\n\\n- [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\\n\\n```bibtex\\n@inproceedings{wuwang2021fairautoml,\\n title={Fair AutoML},\\n author={Qingyun Wu and Chi Wang},\\n year={2021},\\n booktitle={ArXiv preprint arXiv:2111.06495},\\n}\\n```\\n\\n- [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\\n\\n```bibtex\\n@inproceedings{kayaliwang2022default,\\n title={Mining Robust Default Configurations for Resource-constrained AutoML},\\n author={Moe Kayali and Chi Wang},\\n year={2022},\\n booktitle={ArXiv preprint arXiv:2202.09927},\\n}\\n```\\n\\n- [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\\n\\n```bibtex\\n@inproceedings{zhang2023targeted,\\n title={Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives},\\n author={Shaokun Zhang and Feiran Jia and Chi Wang and Qingyun Wu},\\n booktitle={International Conference on Learning Representations},\\n year={2023},\\n url={https://openreview.net/forum?id=0Ij9_q567Ma},\\n}\\n```\\n\\n- [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\\n\\n```bibtex\\n@inproceedings{wang2023EcoOptiGen,\\n title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},\\n author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},\\n year={2023},\\n booktitle={ArXiv preprint arXiv:2303.04673},\\n}\\n```\\n\\n- [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\\n\\n```bibtex\\n@inproceedings{wu2023empirical,\\n title={An Empirical Study on Challenging Math Problem Solving with GPT-4},\\n author={Yiran Wu and Feiran Jia and Shaokun Zhang and Hangyu Li and Erkang Zhu and Yue Wang and Yin Tat Lee and Richard Peng and Qingyun Wu and Chi Wang},\\n year={2023},\\n booktitle={ArXiv preprint arXiv:2306.01337},\\n}\\n```\\n# Integrate - Spark\\n\\nFLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\\n\\n- Use Spark ML estimators for AutoML.\\n- Use Spark to run training in parallel spark jobs.\\n\\n## Spark ML Estimators\\n\\nFLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\\n\\n### Data\\n\\nFor Spark estimators, AutoML only consumes Spark data. FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\\n\\nThis utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\\n\\nThis function also accepts optional arguments `index_col` and `default_index_type`.\\n\\n- `index_col` is the column name to use as the index, default is None.\\n- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\\n\\nHere is an example code snippet for Spark Data:\\n\\n```python\\nimport pandas as pd\\nfrom flaml.automl.spark.utils import to_pandas_on_spark\\n\\n# Creating a dictionary\\ndata = {\\n \"Square_Feet\": [800, 1200, 1800, 1500, 850],\\n \"Age_Years\": [20, 15, 10, 7, 25],\\n \"Price\": [100000, 200000, 300000, 240000, 120000],\\n}\\n\\n# Creating a pandas DataFrame\\ndataframe = pd.DataFrame(data)\\nlabel = \"Price\"\\n\\n# Convert to pandas-on-spark dataframe\\npsdf = to_pandas_on_spark(dataframe)\\n```\\n\\nTo use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\\n\\nHere is an example of how to use it:\\n\\n```python\\nfrom pyspark.ml.feature import VectorAssembler\\n\\ncolumns = psdf.columns\\nfeature_cols = [col for col in columns if col != label]\\nfeaturizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\\npsdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\\n```\\n\\nLater in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\\n\\n### Estimators\\n\\n#### Model List\\n\\n- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\\n\\n#### Usage\\n\\nFirst, prepare your data in the required format as described in the previous section.\\n\\nBy including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven\\'t specified them.\\n\\nHere is an example code snippet using SparkML models in AutoML:\\n\\n```python\\nimport flaml\\n\\n# prepare your data in pandas-on-spark format as we previously mentioned\\n\\nautoml = flaml.AutoML()\\nsettings = {\\n \"time_budget\": 30,\\n \"metric\": \"r2\",\\n \"estimator_list\": [\"lgbm_spark\"], # this setting is optional\\n \"task\": \"regression\",\\n}\\n\\nautoml.fit(\\n dataframe=psdf,\\n label=label,\\n **settings,\\n)\\n```\\n\\n[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\\n\\n## Parallel Spark Jobs\\n\\nYou can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\\n\\nPlease note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\\n\\nAll the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\\n\\n- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\\n- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\\n- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\\n\\nAn example code snippet for using parallel Spark jobs:\\n\\n```python\\nimport flaml\\n\\nautoml_experiment = flaml.AutoML()\\nautoml_settings = {\\n \"time_budget\": 30,\\n \"metric\": \"r2\",\\n \"task\": \"regression\",\\n \"n_concurrent_trials\": 2,\\n \"use_spark\": True,\\n \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\\n}\\n\\nautoml.fit(\\n dataframe=dataframe,\\n label=label,\\n **automl_settings,\\n)\\n```\\n\\n[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\\n\\n', 'role': 'assistant'}, {'content': 'The author of FLAML is Chi Wang, along with several co-authors for various publications related to FLAML.', 'role': 'user'}], summary='The author of FLAML is Chi Wang, along with several co-authors for various publications related to FLAML.', cost=({'total_cost': 0.004711, 'gpt-35-turbo': {'cost': 0.004711, 'prompt_tokens': 3110, 'completion_tokens': 23, 'total_tokens': 3133}}, {'total_cost': 0}), human_input=[])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ diff --git a/notebook/agentchat_groupchat_RAG.ipynb b/notebook/agentchat_groupchat_RAG.ipynb index 35ab96909f..1057deabf9 100644 --- a/notebook/agentchat_groupchat_RAG.ipynb +++ b/notebook/agentchat_groupchat_RAG.ipynb @@ -42,7 +42,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "LLM models: ['gpt-4-1106-preview', 'gpt-4-turbo-preview', 'gpt-4-0613', 'gpt-35-turbo-0613', 'gpt-35-turbo-1106']\n" + "LLM models: ['gpt4-1106-preview', 'gpt-35-turbo', 'gpt-35-turbo-0613']\n" ] } ], @@ -77,12 +77,23 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", + " torch.utils._pytree._register_pytree_node(\n" + ] + } + ], "source": [ "def termination_msg(x):\n", " return isinstance(x, dict) and \"TERMINATE\" == str(x.get(\"content\", \"\"))[-9:].upper()\n", "\n", "\n", + "llm_config = {\"config_list\": config_list, \"timeout\": 60, \"temperature\": 0.8, \"seed\": 1234}\n", + "\n", "boss = autogen.UserProxyAgent(\n", " name=\"Boss\",\n", " is_termination_msg=termination_msg,\n", @@ -96,13 +107,13 @@ " name=\"Boss_Assistant\",\n", " is_termination_msg=termination_msg,\n", " human_input_mode=\"NEVER\",\n", + " default_auto_reply=\"Reply `TERMINATE` if the task is done.\",\n", " max_consecutive_auto_reply=3,\n", " retrieve_config={\n", " \"task\": \"code\",\n", " \"docs_path\": \"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Examples/Integrate%20-%20Spark.md\",\n", " \"chunk_token_size\": 1000,\n", " \"model\": config_list[0][\"model\"],\n", - " \"client\": chromadb.PersistentClient(path=\"/tmp/chromadb\"),\n", " \"collection_name\": \"groupchat\",\n", " \"get_or_create\": True,\n", " },\n", @@ -114,7 +125,7 @@ " name=\"Senior_Python_Engineer\",\n", " is_termination_msg=termination_msg,\n", " system_message=\"You are a senior python engineer, you provide python code to answer questions. Reply `TERMINATE` in the end when everything is done.\",\n", - " llm_config={\"config_list\": config_list, \"timeout\": 60, \"temperature\": 0},\n", + " llm_config=llm_config,\n", " description=\"Senior Python Engineer who can write code to solve problems and answer questions.\",\n", ")\n", "\n", @@ -122,7 +133,7 @@ " name=\"Product_Manager\",\n", " is_termination_msg=termination_msg,\n", " system_message=\"You are a product manager. Reply `TERMINATE` in the end when everything is done.\",\n", - " llm_config={\"config_list\": config_list, \"timeout\": 60, \"temperature\": 0},\n", + " llm_config=llm_config,\n", " description=\"Product Manager who can design and plan the project.\",\n", ")\n", "\n", @@ -130,7 +141,7 @@ " name=\"Code_Reviewer\",\n", " is_termination_msg=termination_msg,\n", " system_message=\"You are a code reviewer. Reply `TERMINATE` in the end when everything is done.\",\n", - " llm_config={\"config_list\": config_list, \"timeout\": 60, \"temperature\": 0},\n", + " llm_config=llm_config,\n", " description=\"Code Reviewer who can review the code.\",\n", ")\n", "\n", @@ -150,9 +161,7 @@ " groupchat = autogen.GroupChat(\n", " agents=[boss_aid, pm, coder, reviewer], messages=[], max_round=12, speaker_selection_method=\"round_robin\"\n", " )\n", - " manager = autogen.GroupChatManager(\n", - " groupchat=groupchat, llm_config={\"config_list\": config_list, \"timeout\": 60, \"temperature\": 0}\n", - " )\n", + " manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)\n", "\n", " # Start chatting with boss_aid as this is the user proxy agent.\n", " boss_aid.initiate_chat(\n", @@ -172,9 +181,7 @@ " speaker_selection_method=\"auto\",\n", " allow_repeat_speaker=False,\n", " )\n", - " manager = autogen.GroupChatManager(\n", - " groupchat=groupchat, llm_config={\"config_list\": config_list, \"timeout\": 60, \"temperature\": 0}\n", - " )\n", + " manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)\n", "\n", " # Start chatting with the boss as this is the user proxy agent.\n", " boss.initiate_chat(\n", @@ -226,9 +233,7 @@ " allow_repeat_speaker=False,\n", " )\n", "\n", - " manager = autogen.GroupChatManager(\n", - " groupchat=groupchat, llm_config={\"config_list\": config_list, \"timeout\": 60, \"temperature\": 0}\n", - " )\n", + " manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)\n", "\n", " # Start chatting with the boss as this is the user proxy agent.\n", " boss.initiate_chat(\n", @@ -270,49 +275,129 @@ "text": [ "\u001b[33mSenior_Python_Engineer\u001b[0m (to chat_manager):\n", "\n", - "To use Apache Spark for parallel training in FLAML, you need to use the `flaml.tune.run` function. Here is a sample code:\n", + "To use Spark for parallel training in FLAML (Fast and Lightweight AutoML), you would need to set up a Spark cluster and utilize the `spark` backend for joblib, which FLAML uses internally for parallel training. Here’s an example of how you might set up and use Spark with FLAML for AutoML tasks:\n", + "\n", + "Firstly, ensure that you have the Spark cluster set up and the `pyspark` and `joblib-spark` packages installed in your environment. You can install the required packages using pip if they are not already installed:\n", "\n", "```python\n", - "from flaml import tune\n", - "\n", - "# Define your training function\n", - "def training_function(config):\n", - " # your training code here\n", - " pass\n", - "\n", - "# Define your search space\n", - "search_space = {\n", - " \"lr\": tune.loguniform(1e-4, 1e-1),\n", - " \"momentum\": tune.uniform(0.1, 0.9),\n", - "}\n", - "\n", - "# Use SparkTrials for parallelization\n", - "from ray.tune import SparkTrials\n", - "\n", - "spark_trials = SparkTrials(parallelism=2)\n", - "\n", - "analysis = tune.run(\n", - " training_function,\n", - " config=search_space,\n", - " num_samples=10,\n", - " scheduler=tune.schedulers.FIFOScheduler(),\n", - " progress_reporter=tune.JupyterNotebookReporter(overwrite=True),\n", - " trial_executor=spark_trials,\n", - ")\n", - "\n", - "print(\"Best config: \", analysis.get_best_config(metric=\"accuracy\", mode=\"max\"))\n", - "\n", - "# Get a dataframe for analyzing trial results.\n", - "df = analysis.results_df\n", + "!pip install flaml pyspark joblib-spark\n", "```\n", "\n", - "In this code, `training_function` is your training function, which should take a `config` argument. This `config` argument is a dictionary that includes hyperparameters for your model. The `search_space` is a dictionary that defines the search space for your hyperparameters.\n", + "Here's a sample code snippet that demonstrates how to use FLAML with Spark for parallel training:\n", "\n", - "The `tune.run` function is used to start the hyperparameter tuning. The `config` argument is your search space, `num_samples` is the number of times to sample from the search space, and `scheduler` is the scheduler for the trials. The `trial_executor` argument is set to `spark_trials` to use Spark for parallelization.\n", + "```python\n", + "from flaml import AutoML\n", + "from pyspark.sql import SparkSession\n", + "from sklearn.datasets import load_digits\n", + "from joblibspark import register_spark\n", "\n", - "The `analysis.get_best_config` function is used to get the best hyperparameters found during the tuning. The `analysis.results_df` gives a dataframe that contains the results of all trials.\n", + "# Initialize a Spark session\n", + "spark = SparkSession.builder \\\n", + " .master(\"local[*]\") \\\n", + " .appName(\"FLAML_Spark_Example\") \\\n", + " .getOrCreate()\n", "\n", - "Please note that you need to have Apache Spark and Ray installed and properly configured in your environment to run this code.\n", + "# Register the joblib spark backend\n", + "register_spark() # This registers the backend for parallel processing\n", + "\n", + "# Load sample data\n", + "X, y = load_digits(return_X_y=True)\n", + "\n", + "# Initialize an AutoML instance\n", + "automl = AutoML()\n", + "\n", + "# Define the settings for the AutoML run\n", + "settings = {\n", + " \"time_budget\": 60, # Total running time in seconds\n", + " \"metric\": 'accuracy', # Primary metric for evaluation\n", + " \"task\": 'classification', # Task type\n", + " \"n_jobs\": -1, # Number of jobs to run in parallel (use -1 for all)\n", + " \"estimator_list\": ['lgbm', 'rf', 'xgboost'], # List of estimators to consider\n", + " \"log_file_name\": \"flaml_log.txt\", # Log file name\n", + "}\n", + "\n", + "# Run the AutoML search with Spark backend\n", + "automl.fit(X_train=X, y_train=y, **settings)\n", + "\n", + "# Output the best model and its performance\n", + "print(f\"Best ML model: {automl.model}\")\n", + "print(f\"Best ML model's accuracy: {automl.best_loss}\")\n", + "\n", + "# Stop the Spark session\n", + "spark.stop()\n", + "```\n", + "\n", + "The `register_spark()` function from `joblib-spark` is used to register the Spark backend with joblib, which is utilized for parallel training within FLAML. The `n_jobs=-1` parameter tells FLAML to use all available Spark executors for parallel training.\n", + "\n", + "Please note that the actual process of setting up a Spark cluster can be complex and might involve additional steps such as configuring Spark workers, allocating resources, and more, which are beyond the scope of this code snippet.\n", + "\n", + "If you encounter any issues or need to adjust configurations for your specific Spark setup, please refer to the Spark and FLAML documentation for more details.\n", + "\n", + "When you run the code, ensure that your Spark cluster is properly configured and accessible from your Python environment. Adjust the `.master(\"local[*]\")` to point to your Spark master's URL if you are running a cluster that is not local.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "To use Spark for parallel training in FLAML (Fast and Lightweight AutoML), you would need to set up a Spark cluster and utilize the `spark` backend for joblib, which FLAML uses internally for parallel training. Here’s an example of how you might set up and use Spark with FLAML for AutoML tasks:\n", + "\n", + "Firstly, ensure that you have the Spark cluster set up and the `pyspark` and `joblib-spark` packages installed in your environment. You can install the required packages using pip if they are not already installed:\n", + "\n", + "```python\n", + "!pip install flaml pyspark joblib-spark\n", + "```\n", + "\n", + "Here's a sample code snippet that demonstrates how to use FLAML with Spark for parallel training:\n", + "\n", + "```python\n", + "from flaml import AutoML\n", + "from pyspark.sql import SparkSession\n", + "from sklearn.datasets import load_digits\n", + "from joblibspark import register_spark\n", + "\n", + "# Initialize a Spark session\n", + "spark = SparkSession.builder \\\n", + " .master(\"local[*]\") \\\n", + " .appName(\"FLAML_Spark_Example\") \\\n", + " .getOrCreate()\n", + "\n", + "# Register the joblib spark backend\n", + "register_spark() # This registers the backend for parallel processing\n", + "\n", + "# Load sample data\n", + "X, y = load_digits(return_X_y=True)\n", + "\n", + "# Initialize an AutoML instance\n", + "automl = AutoML()\n", + "\n", + "# Define the settings for the AutoML run\n", + "settings = {\n", + " \"time_budget\": 60, # Total running time in seconds\n", + " \"metric\": 'accuracy', # Primary metric for evaluation\n", + " \"task\": 'classification', # Task type\n", + " \"n_jobs\": -1, # Number of jobs to run in parallel (use -1 for all)\n", + " \"estimator_list\": ['lgbm', 'rf', 'xgboost'], # List of estimators to consider\n", + " \"log_file_name\": \"flaml_log.txt\", # Log file name\n", + "}\n", + "\n", + "# Run the AutoML search with Spark backend\n", + "automl.fit(X_train=X, y_train=y, **settings)\n", + "\n", + "# Output the best model and its performance\n", + "print(f\"Best ML model: {automl.model}\")\n", + "print(f\"Best ML model's accuracy: {automl.best_loss}\")\n", + "\n", + "# Stop the Spark session\n", + "spark.stop()\n", + "```\n", + "\n", + "The `register_spark()` function from `joblib-spark` is used to register the Spark backend with joblib, which is utilized for parallel training within FLAML. The `n_jobs=-1` parameter tells FLAML to use all available Spark executors for parallel training.\n", + "\n", + "Please note that the actual process of setting up a Spark cluster can be complex and might involve additional steps such as configuring Spark workers, allocating resources, and more, which are beyond the scope of this code snippet.\n", + "\n", + "If you encounter any issues or need to adjust configurations for your specific Spark setup, please refer to the Spark and FLAML documentation for more details.\n", + "\n", + "When you run the code, ensure that your Spark cluster is properly configured and accessible from your Python environment. Adjust the `.master(\"local[*]\")` to point to your Spark master's URL if you are running a cluster that is not local.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mCode_Reviewer\u001b[0m (to chat_manager):\n", "\n", "TERMINATE\n", "\n", @@ -335,17 +420,38 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 4, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-07 18:26:04,562 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - \u001b[32mUse the existing collection `groupchat`.\u001b[0m\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "doc_ids: [['doc_0', 'doc_1', 'doc_122']]\n", - "\u001b[32mAdding doc_id doc_0 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_122 to context.\u001b[0m\n", + "Trying to create collection.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-07 18:26:05,485 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - Found 1 chunks.\u001b[0m\n", + "Number of requested results 3 is greater than number of elements in index 1, updating n_results = 1\n", + "Model gpt4-1106-preview not found. Using cl100k_base encoding.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "VectorDB returns doc_ids: [['bdfbc921']]\n", + "\u001b[32mAdding content of doc bdfbc921 to context.\u001b[0m\n", "\u001b[33mBoss_Assistant\u001b[0m (to chat_manager):\n", "\n", "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", @@ -363,6 +469,7 @@ "Context is: # Integrate - Spark\n", "\n", "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n", + "\n", "- Use Spark ML estimators for AutoML.\n", "- Use Spark to run training in parallel spark jobs.\n", "\n", @@ -377,6 +484,7 @@ "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n", "\n", "This function also accepts optional arguments `index_col` and `default_index_type`.\n", + "\n", "- `index_col` is the column name to use as the index, default is None.\n", "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n", "\n", @@ -385,10 +493,13 @@ "```python\n", "import pandas as pd\n", "from flaml.automl.spark.utils import to_pandas_on_spark\n", + "\n", "# Creating a dictionary\n", - "data = {\"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", - " \"Age_Years\": [20, 15, 10, 7, 25],\n", - " \"Price\": [100000, 200000, 300000, 240000, 120000]}\n", + "data = {\n", + " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", + " \"Age_Years\": [20, 15, 10, 7, 25],\n", + " \"Price\": [100000, 200000, 300000, 240000, 120000],\n", + "}\n", "\n", "# Creating a pandas DataFrame\n", "dataframe = pd.DataFrame(data)\n", @@ -401,8 +512,10 @@ "To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n", "\n", "Here is an example of how to use it:\n", + "\n", "```python\n", "from pyspark.ml.feature import VectorAssembler\n", + "\n", "columns = psdf.columns\n", "feature_cols = [col for col in columns if col != label]\n", "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", @@ -412,10 +525,13 @@ "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n", "\n", "### Estimators\n", + "\n", "#### Model List\n", + "\n", "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n", "\n", "#### Usage\n", + "\n", "First, prepare your data in the required format as described in the previous section.\n", "\n", "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n", @@ -424,6 +540,7 @@ "\n", "```python\n", "import flaml\n", + "\n", "# prepare your data in pandas-on-spark format as we previously mentioned\n", "\n", "automl = flaml.AutoML()\n", @@ -441,24 +558,25 @@ ")\n", "```\n", "\n", - "\n", "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n", "\n", "## Parallel Spark Jobs\n", + "\n", "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n", "\n", "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n", "\n", "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n", "\n", - "\n", "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n", "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n", "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n", "\n", "An example code snippet for using parallel Spark jobs:\n", + "\n", "```python\n", "import flaml\n", + "\n", "automl_experiment = flaml.AutoML()\n", "automl_settings = {\n", " \"time_budget\": 30,\n", @@ -466,7 +584,7 @@ " \"task\": \"regression\",\n", " \"n_concurrent_trials\": 2,\n", " \"use_spark\": True,\n", - " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", + " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", "}\n", "\n", "automl.fit(\n", @@ -476,283 +594,206 @@ ")\n", "```\n", "\n", - "\n", "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n", "\n", - "2684,4/26/2011,2,0,4,17,0,2,1,1,0.68,0.6364,0.61,0.3582,521\n", - "2685,4/26/2011,2,0,4,18,0,2,1,1,0.68,0.6364,0.65,0.4478,528\n", - "2686,4/26/2011,2,0,4,19,0,2,1,1,0.64,0.6061,0.73,0.4179,328\n", - "2687,4/26/2011,2,0,4,20,0,2,1,1,0.64,0.6061,0.73,0.3582,234\n", - "2688,4/26/2011,2,0,4,21,0,2,1,1,0.62,0.5909,0.78,0.2836,195\n", - "2689,4/26/2011,2,0,4,22,0,2,1,2,0.6,0.5606,0.83,0.194,148\n", - "2690,4/26/2011,2,0,4,23,0,2,1,2,0.6,0.5606,0.83,0.2239,78\n", - "2691,4/27/2011,2,0,4,0,0,3,1,1,0.6,0.5606,0.83,0.2239,27\n", - "2692,4/27/2011,2,0,4,1,0,3,1,1,0.6,0.5606,0.83,0.2537,17\n", - "2693,4/27/2011,2,0,4,2,0,3,1,1,0.58,0.5455,0.88,0.2537,5\n", - "2694,4/27/2011,2,0,4,3,0,3,1,2,0.58,0.5455,0.88,0.2836,7\n", - "2695,4/27/2011,2,0,4,4,0,3,1,1,0.56,0.5303,0.94,0.2239,6\n", - "2696,4/27/2011,2,0,4,5,0,3,1,2,0.56,0.5303,0.94,0.2537,17\n", - "2697,4/27/2011,2,0,4,6,0,3,1,1,0.56,0.5303,0.94,0.2537,84\n", - "2698,4/27/2011,2,0,4,7,0,3,1,2,0.58,0.5455,0.88,0.2836,246\n", - "2699,4/27/2011,2,0,4,8,0,3,1,2,0.58,0.5455,0.88,0.3284,444\n", - "2700,4/27/2011,2,0,4,9,0,3,1,2,0.6,0.5455,0.88,0.4179,181\n", - "2701,4/27/2011,2,0,4,10,0,3,1,2,0.62,0.5758,0.83,0.2836,92\n", - "2702,4/27/2011,2,0,4,11,0,3,1,2,0.64,0.5909,0.78,0.2836,156\n", - "2703,4/27/2011,2,0,4,12,0,3,1,1,0.66,0.6061,0.78,0.3284,173\n", - "2704,4/27/2011,2,0,4,13,0,3,1,1,0.64,0.5909,0.78,0.2985,150\n", - "2705,4/27/2011,2,0,4,14,0,3,1,1,0.68,0.6364,0.74,0.2836,148\n", "\n", "\n", - "\n", - "--------------------------------------------------------------------------------\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33mProduct_Manager\u001b[0m (to chat_manager):\n", - "\n", - "To use Spark for parallel training in FLAML, you can follow these steps:\n", - "\n", - "1. Prepare your data in the required format using the `to_pandas_on_spark` function from the `flaml.automl.spark.utils` module. This function converts your data into a pandas-on-spark dataframe, which is required by Spark estimators. Here is an example code snippet:\n", - "\n", - "```python\n", - "import pandas as pd\n", - "from flaml.automl.spark.utils import to_pandas_on_spark\n", - "\n", - "# Creating a dictionary\n", - "data = {\n", - " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", - " \"Age_Years\": [20, 15, 10, 7, 25],\n", - " \"Price\": [100000, 200000, 300000, 240000, 120000]\n", - "}\n", - "\n", - "# Creating a pandas DataFrame\n", - "dataframe = pd.DataFrame(data)\n", - "label = \"Price\"\n", - "\n", - "# Convert to pandas-on-spark dataframe\n", - "psdf = to_pandas_on_spark(dataframe)\n", - "```\n", - "\n", - "2. Format your data appropriately for Spark ML models. Use the `VectorAssembler` from `pyspark.ml.feature` to merge all feature columns into a single vector column. Here is an example:\n", - "\n", - "```python\n", - "from pyspark.ml.feature import VectorAssembler\n", - "\n", - "columns = psdf.columns\n", - "feature_cols = [col for col in columns if col != label]\n", - "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", - "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", - "```\n", - "\n", - "3. Use the Spark ML models in FLAML's AutoML. Include the models you want to try in the `estimator_list` argument to `flaml.AutoML()`. FLAML will start trying configurations for these models. Here is an example code snippet:\n", - "\n", - "```python\n", - "import flaml\n", - "\n", - "automl = flaml.AutoML()\n", - "settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"estimator_list\": [\"lgbm_spark\"],\n", - " \"task\": \"regression\"\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=psdf,\n", - " label=label,\n", - " **settings\n", - ")\n", - "```\n", - "\n", - "4. To enable parallel Spark jobs during parallel tuning, set the `use_spark` parameter to `True`. FLAML will dispatch your job to the distributed Spark backend using `joblib-spark`. Here is an example code snippet:\n", - "\n", - "```python\n", - "import flaml\n", - "\n", - "automl_experiment = flaml.AutoML()\n", - "automl_settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"task\": \"regression\",\n", - " \"n_concurrent_trials\": 2,\n", - " \"use_spark\": True,\n", - " \"force_cancel\": True\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=dataframe,\n", - " label=label,\n", - " **automl_settings\n", - ")\n", - "```\n", - "\n", - "Please note that you should not set `use_spark` to `True` when applying AutoML and Tuning for Spark Data, as SparkML models will be used for Spark Data in AutoML and Tuning.\n", - "\n", - "Let me know if you need anything else.\n", - "\n", - "--------------------------------------------------------------------------------\n", - "To use Spark for parallel training in FLAML, you can follow these steps:\n", - "\n", - "1. Prepare your data in the required format using the `to_pandas_on_spark` function from the `flaml.automl.spark.utils` module. This function converts your data into a pandas-on-spark dataframe, which is required by Spark estimators. Here is an example code snippet:\n", - "\n", - "```python\n", - "import pandas as pd\n", - "from flaml.automl.spark.utils import to_pandas_on_spark\n", - "\n", - "# Creating a dictionary\n", - "data = {\n", - " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", - " \"Age_Years\": [20, 15, 10, 7, 25],\n", - " \"Price\": [100000, 200000, 300000, 240000, 120000]\n", - "}\n", - "\n", - "# Creating a pandas DataFrame\n", - "dataframe = pd.DataFrame(data)\n", - "label = \"Price\"\n", - "\n", - "# Convert to pandas-on-spark dataframe\n", - "psdf = to_pandas_on_spark(dataframe)\n", - "```\n", - "\n", - "2. Format your data appropriately for Spark ML models. Use the `VectorAssembler` from `pyspark.ml.feature` to merge all feature columns into a single vector column. Here is an example:\n", - "\n", - "```python\n", - "from pyspark.ml.feature import VectorAssembler\n", - "\n", - "columns = psdf.columns\n", - "feature_cols = [col for col in columns if col != label]\n", - "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", - "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", - "```\n", - "\n", - "3. Use the Spark ML models in FLAML's AutoML. Include the models you want to try in the `estimator_list` argument to `flaml.AutoML()`. FLAML will start trying configurations for these models. Here is an example code snippet:\n", - "\n", - "```python\n", - "import flaml\n", - "\n", - "automl = flaml.AutoML()\n", - "settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"estimator_list\": [\"lgbm_spark\"],\n", - " \"task\": \"regression\"\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=psdf,\n", - " label=label,\n", - " **settings\n", - ")\n", - "```\n", - "\n", - "4. To enable parallel Spark jobs during parallel tuning, set the `use_spark` parameter to `True`. FLAML will dispatch your job to the distributed Spark backend using `joblib-spark`. Here is an example code snippet:\n", - "\n", - "```python\n", - "import flaml\n", - "\n", - "automl_experiment = flaml.AutoML()\n", - "automl_settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"task\": \"regression\",\n", - " \"n_concurrent_trials\": 2,\n", - " \"use_spark\": True,\n", - " \"force_cancel\": True\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=dataframe,\n", - " label=label,\n", - " **automl_settings\n", - ")\n", - "```\n", - "\n", - "Please note that you should not set `use_spark` to `True` when applying AutoML and Tuning for Spark Data, as SparkML models will be used for Spark Data in AutoML and Tuning.\n", - "\n", - "Let me know if you need anything else.\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33mSenior_Python_Engineer\u001b[0m (to chat_manager):\n", - "\n", - "Here is the sample code to use Spark for parallel training in FLAML:\n", - "\n", - "```python\n", - "import pandas as pd\n", - "from flaml.automl.spark.utils import to_pandas_on_spark\n", - "from pyspark.ml.feature import VectorAssembler\n", - "import flaml\n", - "\n", - "# Step 1: Prepare your data in the required format\n", - "data = {\n", - " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", - " \"Age_Years\": [20, 15, 10, 7, 25],\n", - " \"Price\": [100000, 200000, 300000, 240000, 120000]\n", - "}\n", - "\n", - "dataframe = pd.DataFrame(data)\n", - "label = \"Price\"\n", - "\n", - "psdf = to_pandas_on_spark(dataframe)\n", - "\n", - "# Step 2: Format your data for Spark ML models\n", - "columns = psdf.columns\n", - "feature_cols = [col for col in columns if col != label]\n", - "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", - "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", - "\n", - "# Step 3: Use Spark ML models in FLAML's AutoML\n", - "automl = flaml.AutoML()\n", - "settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"estimator_list\": [\"lgbm_spark\"],\n", - " \"task\": \"regression\"\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=psdf,\n", - " label=label,\n", - " **settings\n", - ")\n", - "\n", - "# Step 4: Enable parallel Spark jobs during parallel tuning\n", - "automl_experiment = flaml.AutoML()\n", - "automl_settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"task\": \"regression\",\n", - " \"n_concurrent_trials\": 2,\n", - " \"use_spark\": True,\n", - " \"force_cancel\": True\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=dataframe,\n", - " label=label,\n", - " **automl_settings\n", - ")\n", - "```\n", - "\n", - "Let me know if you need anything else.\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33mCode_Reviewer\u001b[0m (to chat_manager):\n", - "\n", - "The code you provided is correct and follows the guidelines for using Spark for parallel training in FLAML. It includes the necessary steps to prepare the data, format it for Spark ML models, and use Spark ML models in FLAML's AutoML. It also demonstrates how to enable parallel Spark jobs during parallel tuning.\n", - "\n", - "Great job! You can now terminate the conversation.\n", - "\n", "--------------------------------------------------------------------------------\n", "\u001b[33mBoss_Assistant\u001b[0m (to chat_manager):\n", "\n", + "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", + "context provided by the user.\n", + "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n", + "For code generation, you must obey the following rules:\n", + "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n", + "Rule 2. You must follow the formats below to write your code:\n", + "```language\n", + "# your code\n", + "```\n", + "\n", + "User's question is: How to use spark for parallel training in FLAML? Give me sample code.\n", + "\n", + "Context is: # Integrate - Spark\n", + "\n", + "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n", + "\n", + "- Use Spark ML estimators for AutoML.\n", + "- Use Spark to run training in parallel spark jobs.\n", + "\n", + "## Spark ML Estimators\n", + "\n", + "FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n", + "\n", + "### Data\n", + "\n", + "For Spark estimators, AutoML only consumes Spark data. FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\n", + "\n", + "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n", + "\n", + "This function also accepts optional arguments `index_col` and `default_index_type`.\n", + "\n", + "- `index_col` is the column name to use as the index, default is None.\n", + "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n", + "\n", + "Here is an example code snippet for Spark Data:\n", + "\n", + "```python\n", + "import pandas as pd\n", + "from flaml.automl.spark.utils import to_pandas_on_spark\n", + "\n", + "# Creating a dictionary\n", + "data = {\n", + " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", + " \"Age_Years\": [20, 15, 10, 7, 25],\n", + " \"Price\": [100000, 200000, 300000, 240000, 120000],\n", + "}\n", + "\n", + "# Creating a pandas DataFrame\n", + "dataframe = pd.DataFrame(data)\n", + "label = \"Price\"\n", + "\n", + "# Convert to pandas-on-spark dataframe\n", + "psdf = to_pandas_on_spark(dataframe)\n", + "```\n", + "\n", + "To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n", + "\n", + "Here is an example of how to use it:\n", + "\n", + "```python\n", + "from pyspark.ml.feature import VectorAssembler\n", + "\n", + "columns = psdf.columns\n", + "feature_cols = [col for col in columns if col != label]\n", + "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", + "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", + "```\n", + "\n", + "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n", + "\n", + "### Estimators\n", + "\n", + "#### Model List\n", + "\n", + "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n", + "\n", + "#### Usage\n", + "\n", + "First, prepare your data in the required format as described in the previous section.\n", + "\n", + "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n", + "\n", + "Here is an example code snippet using SparkML models in AutoML:\n", + "\n", + "```python\n", + "import flaml\n", + "\n", + "# prepare your data in pandas-on-spark format as we previously mentioned\n", + "\n", + "automl = flaml.AutoML()\n", + "settings = {\n", + " \"time_budget\": 30,\n", + " \"metric\": \"r2\",\n", + " \"estimator_list\": [\"lgbm_spark\"], # this setting is optional\n", + " \"task\": \"regression\",\n", + "}\n", + "\n", + "automl.fit(\n", + " dataframe=psdf,\n", + " label=label,\n", + " **settings,\n", + ")\n", + "```\n", + "\n", + "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n", + "\n", + "## Parallel Spark Jobs\n", + "\n", + "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n", + "\n", + "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n", + "\n", + "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n", + "\n", + "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n", + "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n", + "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n", + "\n", + "An example code snippet for using parallel Spark jobs:\n", + "\n", + "```python\n", + "import flaml\n", + "\n", + "automl_experiment = flaml.AutoML()\n", + "automl_settings = {\n", + " \"time_budget\": 30,\n", + " \"metric\": \"r2\",\n", + " \"task\": \"regression\",\n", + " \"n_concurrent_trials\": 2,\n", + " \"use_spark\": True,\n", + " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", + "}\n", + "\n", + "automl.fit(\n", + " dataframe=dataframe,\n", + " label=label,\n", + " **automl_settings,\n", + ")\n", + "```\n", + "\n", + "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n", + "\n", "\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[33mProduct_Manager\u001b[0m (to chat_manager):\n", "\n", + "```python\n", + "from flaml.automl import AutoML\n", + "from flaml.automl.spark.utils import to_pandas_on_spark\n", + "from pyspark.ml.feature import VectorAssembler\n", + "import pandas as pd\n", + "\n", + "# Sample data in a dictionary\n", + "data = {\n", + " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", + " \"Age_Years\": [20, 15, 10, 7, 25],\n", + " \"Price\": [100000, 200000, 300000, 240000, 120000],\n", + "}\n", + "\n", + "# Convert dictionary to pandas DataFrame\n", + "dataframe = pd.DataFrame(data)\n", + "label = \"Price\"\n", + "\n", + "# Convert pandas DataFrame to pandas-on-spark DataFrame\n", + "psdf = to_pandas_on_spark(dataframe)\n", + "\n", + "# Use VectorAssembler to merge feature columns into a single vector column\n", + "feature_cols = [col for col in psdf.columns if col != label]\n", + "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", + "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\", label]\n", + "\n", + "# Initialize AutoML instance\n", + "automl = AutoML()\n", + "\n", + "# AutoML settings\n", + "automl_settings = {\n", + " \"time_budget\": 30, # Total running time in seconds\n", + " \"metric\": \"r2\", # Evaluation metric\n", + " \"task\": \"regression\",\n", + " \"n_concurrent_trials\": 2, # Number of concurrent Spark jobs\n", + " \"use_spark\": True, # Enable Spark for parallel training\n", + " \"force_cancel\": True, # Force cancel Spark jobs if they exceed the time budget\n", + " \"estimator_list\": [\"lgbm_spark\"] # Optional: Specific estimator to use\n", + "}\n", + "\n", + "# Run AutoML fit with pandas-on-spark dataframe\n", + "automl.fit(\n", + " dataframe=psdf,\n", + " label=label,\n", + " **automl_settings,\n", + ")\n", + "```\n", "TERMINATE\n", "\n", "--------------------------------------------------------------------------------\n" @@ -775,7 +816,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -787,28 +828,35 @@ "How to use spark for parallel training in FLAML? Give me sample code.\n", "\n", "--------------------------------------------------------------------------------\n", - "How to use spark for parallel training in FLAML? Give me sample code.\n", - "\n", - "--------------------------------------------------------------------------------\n", "\u001b[33mProduct_Manager\u001b[0m (to chat_manager):\n", "\n", - "\u001b[32m***** Suggested function Call: retrieve_content *****\u001b[0m\n", + "\u001b[32m***** Suggested function call: retrieve_content *****\u001b[0m\n", "Arguments: \n", - "{\n", - " \"message\": \"How to use spark for parallel training in FLAML? Give me sample code.\"\n", - "}\n", + "{\"message\":\"using Apache Spark for parallel training in FLAML with sample code\"}\n", "\u001b[32m*****************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[35m\n", - ">>>>>>>> EXECUTING FUNCTION retrieve_content...\u001b[0m\n", - "doc_ids: [['doc_0', 'doc_1', 'doc_122']]\n", - "\u001b[32mAdding doc_id doc_0 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_1 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id doc_122 to context.\u001b[0m\n", + ">>>>>>>> EXECUTING FUNCTION retrieve_content...\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Number of requested results 3 is greater than number of elements in index 1, updating n_results = 1\n", + "Model gpt4-1106-preview not found. Using cl100k_base encoding.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "VectorDB returns doc_ids: [['bdfbc921']]\n", + "\u001b[32mAdding content of doc bdfbc921 to context.\u001b[0m\n", "\u001b[33mBoss\u001b[0m (to chat_manager):\n", "\n", - "\u001b[32m***** Response from calling function \"retrieve_content\" *****\u001b[0m\n", + "\u001b[32m***** Response from calling function (retrieve_content) *****\u001b[0m\n", "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", "context provided by the user.\n", "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n", @@ -819,11 +867,12 @@ "# your code\n", "```\n", "\n", - "User's question is: How to use spark for parallel training in FLAML? Give me sample code.\n", + "User's question is: using Apache Spark for parallel training in FLAML with sample code\n", "\n", "Context is: # Integrate - Spark\n", "\n", "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n", + "\n", "- Use Spark ML estimators for AutoML.\n", "- Use Spark to run training in parallel spark jobs.\n", "\n", @@ -838,6 +887,7 @@ "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n", "\n", "This function also accepts optional arguments `index_col` and `default_index_type`.\n", + "\n", "- `index_col` is the column name to use as the index, default is None.\n", "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n", "\n", @@ -846,10 +896,13 @@ "```python\n", "import pandas as pd\n", "from flaml.automl.spark.utils import to_pandas_on_spark\n", + "\n", "# Creating a dictionary\n", - "data = {\"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", - " \"Age_Years\": [20, 15, 10, 7, 25],\n", - " \"Price\": [100000, 200000, 300000, 240000, 120000]}\n", + "data = {\n", + " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", + " \"Age_Years\": [20, 15, 10, 7, 25],\n", + " \"Price\": [100000, 200000, 300000, 240000, 120000],\n", + "}\n", "\n", "# Creating a pandas DataFrame\n", "dataframe = pd.DataFrame(data)\n", @@ -862,8 +915,10 @@ "To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n", "\n", "Here is an example of how to use it:\n", + "\n", "```python\n", "from pyspark.ml.feature import VectorAssembler\n", + "\n", "columns = psdf.columns\n", "feature_cols = [col for col in columns if col != label]\n", "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", @@ -873,10 +928,13 @@ "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n", "\n", "### Estimators\n", + "\n", "#### Model List\n", + "\n", "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n", "\n", "#### Usage\n", + "\n", "First, prepare your data in the required format as described in the previous section.\n", "\n", "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n", @@ -885,6 +943,7 @@ "\n", "```python\n", "import flaml\n", + "\n", "# prepare your data in pandas-on-spark format as we previously mentioned\n", "\n", "automl = flaml.AutoML()\n", @@ -902,24 +961,25 @@ ")\n", "```\n", "\n", - "\n", "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n", "\n", "## Parallel Spark Jobs\n", + "\n", "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n", "\n", "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n", "\n", "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n", "\n", - "\n", "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n", "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n", "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n", "\n", "An example code snippet for using parallel Spark jobs:\n", + "\n", "```python\n", "import flaml\n", + "\n", "automl_experiment = flaml.AutoML()\n", "automl_settings = {\n", " \"time_budget\": 30,\n", @@ -927,7 +987,7 @@ " \"task\": \"regression\",\n", " \"n_concurrent_trials\": 2,\n", " \"use_spark\": True,\n", - " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", + " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", "}\n", "\n", "automl.fit(\n", @@ -937,41 +997,50 @@ ")\n", "```\n", "\n", - "\n", "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n", "\n", - "2684,4/26/2011,2,0,4,17,0,2,1,1,0.68,0.6364,0.61,0.3582,521\n", - "2685,4/26/2011,2,0,4,18,0,2,1,1,0.68,0.6364,0.65,0.4478,528\n", - "2686,4/26/2011,2,0,4,19,0,2,1,1,0.64,0.6061,0.73,0.4179,328\n", - "2687,4/26/2011,2,0,4,20,0,2,1,1,0.64,0.6061,0.73,0.3582,234\n", - "2688,4/26/2011,2,0,4,21,0,2,1,1,0.62,0.5909,0.78,0.2836,195\n", - "2689,4/26/2011,2,0,4,22,0,2,1,2,0.6,0.5606,0.83,0.194,148\n", - "2690,4/26/2011,2,0,4,23,0,2,1,2,0.6,0.5606,0.83,0.2239,78\n", - "2691,4/27/2011,2,0,4,0,0,3,1,1,0.6,0.5606,0.83,0.2239,27\n", - "2692,4/27/2011,2,0,4,1,0,3,1,1,0.6,0.5606,0.83,0.2537,17\n", - "2693,4/27/2011,2,0,4,2,0,3,1,1,0.58,0.5455,0.88,0.2537,5\n", - "2694,4/27/2011,2,0,4,3,0,3,1,2,0.58,0.5455,0.88,0.2836,7\n", - "2695,4/27/2011,2,0,4,4,0,3,1,1,0.56,0.5303,0.94,0.2239,6\n", - "2696,4/27/2011,2,0,4,5,0,3,1,2,0.56,0.5303,0.94,0.2537,17\n", - "2697,4/27/2011,2,0,4,6,0,3,1,1,0.56,0.5303,0.94,0.2537,84\n", - "2698,4/27/2011,2,0,4,7,0,3,1,2,0.58,0.5455,0.88,0.2836,246\n", - "2699,4/27/2011,2,0,4,8,0,3,1,2,0.58,0.5455,0.88,0.3284,444\n", - "2700,4/27/2011,2,0,4,9,0,3,1,2,0.6,0.5455,0.88,0.4179,181\n", - "2701,4/27/2011,2,0,4,10,0,3,1,2,0.62,0.5758,0.83,0.2836,92\n", - "2702,4/27/2011,2,0,4,11,0,3,1,2,0.64,0.5909,0.78,0.2836,156\n", - "2703,4/27/2011,2,0,4,12,0,3,1,1,0.66,0.6061,0.78,0.3284,173\n", - "2704,4/27/2011,2,0,4,13,0,3,1,1,0.64,0.5909,0.78,0.2985,150\n", - "2705,4/27/2011,2,0,4,14,0,3,1,1,0.68,0.6364,0.74,0.2836,148\n", - "\n", "\n", "\u001b[32m*************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", - "\u001b[33mProduct_Manager\u001b[0m (to chat_manager):\n", + "\u001b[33mBoss\u001b[0m (to chat_manager):\n", "\n", - "To use Spark for parallel training in FLAML, you can follow these steps:\n", + "\u001b[32m***** Response from calling function (retrieve_content) *****\u001b[0m\n", + "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", + "context provided by the user.\n", + "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n", + "For code generation, you must obey the following rules:\n", + "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n", + "Rule 2. You must follow the formats below to write your code:\n", + "```language\n", + "# your code\n", + "```\n", "\n", - "1. Prepare your data in the required format using Spark data. You can use the `to_pandas_on_spark` function from the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark dataframe.\n", + "User's question is: using Apache Spark for parallel training in FLAML with sample code\n", + "\n", + "Context is: # Integrate - Spark\n", + "\n", + "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n", + "\n", + "- Use Spark ML estimators for AutoML.\n", + "- Use Spark to run training in parallel spark jobs.\n", + "\n", + "## Spark ML Estimators\n", + "\n", + "FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n", + "\n", + "### Data\n", + "\n", + "For Spark estimators, AutoML only consumes Spark data. FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\n", + "\n", + "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n", + "\n", + "This function also accepts optional arguments `index_col` and `default_index_type`.\n", + "\n", + "- `index_col` is the column name to use as the index, default is None.\n", + "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n", + "\n", + "Here is an example code snippet for Spark Data:\n", "\n", "```python\n", "import pandas as pd\n", @@ -981,7 +1050,7 @@ "data = {\n", " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", " \"Age_Years\": [20, 15, 10, 7, 25],\n", - " \"Price\": [100000, 200000, 300000, 240000, 120000]\n", + " \"Price\": [100000, 200000, 300000, 240000, 120000],\n", "}\n", "\n", "# Creating a pandas DataFrame\n", @@ -992,16 +1061,45 @@ "psdf = to_pandas_on_spark(dataframe)\n", "```\n", "\n", - "2. Use the Spark ML estimators provided by FLAML. You can include the models you want to try in the `estimator_list` argument of the `flaml.AutoML` class. FLAML will start trying configurations for these models.\n", + "To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n", + "\n", + "Here is an example of how to use it:\n", + "\n", + "```python\n", + "from pyspark.ml.feature import VectorAssembler\n", + "\n", + "columns = psdf.columns\n", + "feature_cols = [col for col in columns if col != label]\n", + "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", + "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", + "```\n", + "\n", + "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n", + "\n", + "### Estimators\n", + "\n", + "#### Model List\n", + "\n", + "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n", + "\n", + "#### Usage\n", + "\n", + "First, prepare your data in the required format as described in the previous section.\n", + "\n", + "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n", + "\n", + "Here is an example code snippet using SparkML models in AutoML:\n", "\n", "```python\n", "import flaml\n", "\n", + "# prepare your data in pandas-on-spark format as we previously mentioned\n", + "\n", "automl = flaml.AutoML()\n", "settings = {\n", " \"time_budget\": 30,\n", " \"metric\": \"r2\",\n", - " \"estimator_list\": [\"lgbm_spark\"], # Optional: specify the Spark estimator\n", + " \"estimator_list\": [\"lgbm_spark\"], # this setting is optional\n", " \"task\": \"regression\",\n", "}\n", "\n", @@ -1012,22 +1110,109 @@ ")\n", "```\n", "\n", - "3. Enable parallel Spark jobs by setting the `use_spark` parameter to `True` in the `fit` method. This will dispatch the job to the distributed Spark backend using `joblib-spark`.\n", + "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n", + "\n", + "## Parallel Spark Jobs\n", + "\n", + "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n", + "\n", + "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n", + "\n", + "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n", + "\n", + "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n", + "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n", + "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n", + "\n", + "An example code snippet for using parallel Spark jobs:\n", "\n", "```python\n", + "import flaml\n", + "\n", + "automl_experiment = flaml.AutoML()\n", + "automl_settings = {\n", + " \"time_budget\": 30,\n", + " \"metric\": \"r2\",\n", + " \"task\": \"regression\",\n", + " \"n_concurrent_trials\": 2,\n", + " \"use_spark\": True,\n", + " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", + "}\n", + "\n", "automl.fit(\n", - " dataframe=psdf,\n", + " dataframe=dataframe,\n", " label=label,\n", - " use_spark=True,\n", + " **automl_settings,\n", ")\n", "```\n", "\n", - "Note: Make sure you have Spark installed and configured properly before running the code.\n", + "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n", "\n", - "Please let me know if you need any further assistance.\n", + "\n", + "\u001b[32m*************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", - "\u001b[33mSenior_Python_Engineer\u001b[0m (to chat_manager):\n", + "\u001b[33mProduct_Manager\u001b[0m (to chat_manager):\n", + "\n", + "To use Apache Spark for parallel training in FLAML, you can follow these steps:\n", + "\n", + "1. Ensure your data is in the required pandas-on-spark format.\n", + "2. Use Spark ML estimators by including them in the `estimator_list`.\n", + "3. Set `use_spark` to `True` for parallel tuning.\n", + "\n", + "Here's a sample code demonstrating how to use Spark for parallel training in FLAML:\n", + "\n", + "```python\n", + "import flaml\n", + "from flaml.automl.spark.utils import to_pandas_on_spark\n", + "import pandas as pd\n", + "from pyspark.ml.feature import VectorAssembler\n", + "\n", + "# Sample data in a pandas DataFrame\n", + "data = {\n", + " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", + " \"Age_Years\": [20, 15, 10, 7, 25],\n", + " \"Price\": [100000, 200000, 300000, 240000, 120000],\n", + "}\n", + "label = \"Price\"\n", + "\n", + "# Creating a pandas DataFrame\n", + "dataframe = pd.DataFrame(data)\n", + "\n", + "# Convert to pandas-on-spark dataframe\n", + "psdf = to_pandas_on_spark(dataframe)\n", + "\n", + "# Prepare features using VectorAssembler\n", + "columns = psdf.columns\n", + "feature_cols = [col for col in columns if col != label]\n", + "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", + "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", + "\n", + "# Initialize AutoML\n", + "automl = flaml.AutoML()\n", + "\n", + "# Configure settings for AutoML\n", + "settings = {\n", + " \"time_budget\": 30, # time budget in seconds\n", + " \"metric\": \"r2\",\n", + " \"estimator_list\": [\"lgbm_spark\"], # using Spark ML estimators\n", + " \"task\": \"regression\",\n", + " \"n_concurrent_trials\": 2, # number of parallel trials\n", + " \"use_spark\": True, # enable parallel training using Spark\n", + " \"force_cancel\": True, # force cancel Spark jobs if time_budget is exceeded\n", + "}\n", + "\n", + "# Start the training\n", + "automl.fit(dataframe=psdf, label=label, **settings)\n", + "```\n", + "\n", + "In this code snippet:\n", + "- The `to_pandas_on_spark` function is used to convert the pandas DataFrame to a pandas-on-spark DataFrame.\n", + "- `VectorAssembler` is used to transform feature columns into a single vector column.\n", + "- The `AutoML` object is created, and settings are configured for the AutoML run, including setting `use_spark` to `True` for parallel training.\n", + "- The `fit` method is called to start the automated machine learning process.\n", + "\n", + "By using these settings, FLAML will train the models in parallel using Spark, which can accelerate the training process on large models and datasets.\n", "\n", "TERMINATE\n", "\n", diff --git a/notebook/agentchat_qdrant_RetrieveChat.ipynb b/notebook/agentchat_qdrant_RetrieveChat.ipynb index 4a040a5f49..43272e2793 100644 --- a/notebook/agentchat_qdrant_RetrieveChat.ipynb +++ b/notebook/agentchat_qdrant_RetrieveChat.ipynb @@ -31,158 +31,166 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: pyautogen>=0.2.3 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pyautogen[retrievechat]>=0.2.3) (0.2.3)\n", - "Requirement already satisfied: flaml[automl] in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (2.1.1)\n", - "Requirement already satisfied: qdrant_client[fastembed] in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (1.7.0)\n", - "Requirement already satisfied: diskcache in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (5.6.3)\n", - "Requirement already satisfied: openai>=1.3 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (1.6.1)\n", - "Requirement already satisfied: pydantic<3,>=1.10 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (2.5.3)\n", - "Requirement already satisfied: python-dotenv in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (1.0.0)\n", - "Requirement already satisfied: termcolor in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (2.4.0)\n", - "Requirement already satisfied: tiktoken in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (0.5.2)\n", - "Requirement already satisfied: NumPy>=1.17.0rc1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from flaml[automl]) (1.26.2)\n", - "Requirement already satisfied: lightgbm>=2.3.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from flaml[automl]) (4.2.0)\n", - "Requirement already satisfied: xgboost>=0.90 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from flaml[automl]) (2.0.3)\n", - "Requirement already satisfied: scipy>=1.4.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from flaml[automl]) (1.11.4)\n", - "Requirement already satisfied: pandas>=1.1.4 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from flaml[automl]) (2.1.4)\n", - "Requirement already satisfied: scikit-learn>=0.24 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from flaml[automl]) (1.3.2)\n", - "Requirement already satisfied: fastembed==0.1.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from qdrant_client[fastembed]) (0.1.1)\n", - "Requirement already satisfied: grpcio>=1.41.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from qdrant_client[fastembed]) (1.60.0)\n", - "Requirement already satisfied: grpcio-tools>=1.41.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from qdrant_client[fastembed]) (1.60.0)\n", - "Requirement already satisfied: httpx>=0.14.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from httpx[http2]>=0.14.0->qdrant_client[fastembed]) (0.26.0)\n", - "Requirement already satisfied: portalocker<3.0.0,>=2.7.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from qdrant_client[fastembed]) (2.8.2)\n", - "Requirement already satisfied: urllib3<2.0.0,>=1.26.14 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from qdrant_client[fastembed]) (1.26.18)\n", - "Requirement already satisfied: onnx<2.0,>=1.11 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from fastembed==0.1.1->qdrant_client[fastembed]) (1.15.0)\n", - "Requirement already satisfied: onnxruntime<2.0,>=1.15 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from fastembed==0.1.1->qdrant_client[fastembed]) (1.16.3)\n", - "Requirement already satisfied: requests<3.0,>=2.31 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from fastembed==0.1.1->qdrant_client[fastembed]) (2.31.0)\n", - "Requirement already satisfied: tokenizers<0.14,>=0.13 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from fastembed==0.1.1->qdrant_client[fastembed]) (0.13.3)\n", - "Requirement already satisfied: tqdm<5.0,>=4.65 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from fastembed==0.1.1->qdrant_client[fastembed]) (4.66.1)\n", - "Requirement already satisfied: chromadb in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pyautogen[retrievechat]>=0.2.3) (0.4.21)\n", - "Requirement already satisfied: ipython in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pyautogen[retrievechat]>=0.2.3) (8.19.0)\n", - "Requirement already satisfied: pypdf in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pyautogen[retrievechat]>=0.2.3) (3.17.4)\n", - "Requirement already satisfied: sentence-transformers in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pyautogen[retrievechat]>=0.2.3) (2.2.2)\n", - "Requirement already satisfied: protobuf<5.0dev,>=4.21.6 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from grpcio-tools>=1.41.0->qdrant_client[fastembed]) (4.25.1)\n", - "Requirement already satisfied: setuptools in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from grpcio-tools>=1.41.0->qdrant_client[fastembed]) (65.5.0)\n", - "Requirement already satisfied: anyio in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from httpx>=0.14.0->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (4.2.0)\n", - "Requirement already satisfied: certifi in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from httpx>=0.14.0->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (2023.11.17)\n", - "Requirement already satisfied: httpcore==1.* in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from httpx>=0.14.0->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (1.0.2)\n", - "Requirement already satisfied: idna in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from httpx>=0.14.0->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (3.6)\n", - "Requirement already satisfied: sniffio in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from httpx>=0.14.0->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (1.3.0)\n", - "Requirement already satisfied: h11<0.15,>=0.13 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from httpcore==1.*->httpx>=0.14.0->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (0.14.0)\n", - "Requirement already satisfied: h2<5,>=3 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from httpx[http2]>=0.14.0->qdrant_client[fastembed]) (4.1.0)\n", - "Requirement already satisfied: distro<2,>=1.7.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from openai>=1.3->pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (1.9.0)\n", - "Requirement already satisfied: typing-extensions<5,>=4.7 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from openai>=1.3->pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (4.9.0)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pandas>=1.1.4->flaml[automl]) (2.8.2)\n", - "Requirement already satisfied: pytz>=2020.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pandas>=1.1.4->flaml[automl]) (2023.3.post1)\n", - "Requirement already satisfied: tzdata>=2022.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pandas>=1.1.4->flaml[automl]) (2023.4)\n", - "Requirement already satisfied: annotated-types>=0.4.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pydantic<3,>=1.10->pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (0.6.0)\n", - "Requirement already satisfied: pydantic-core==2.14.6 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pydantic<3,>=1.10->pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (2.14.6)\n", - "Requirement already satisfied: joblib>=1.1.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from scikit-learn>=0.24->flaml[automl]) (1.3.2)\n", - "Requirement already satisfied: threadpoolctl>=2.0.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from scikit-learn>=0.24->flaml[automl]) (3.2.0)\n", - "Requirement already satisfied: chroma-hnswlib==0.7.3 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (0.7.3)\n", - "Requirement already satisfied: fastapi>=0.95.2 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (0.108.0)\n", - "Requirement already satisfied: uvicorn>=0.18.3 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from uvicorn[standard]>=0.18.3->chromadb->pyautogen[retrievechat]>=0.2.3) (0.25.0)\n", - "Requirement already satisfied: posthog>=2.4.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (3.1.0)\n", - "Requirement already satisfied: pulsar-client>=3.1.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (3.3.0)\n", - "Requirement already satisfied: opentelemetry-api>=1.2.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (1.22.0)\n", - "Requirement already satisfied: opentelemetry-exporter-otlp-proto-grpc>=1.2.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (1.22.0)\n", - "Requirement already satisfied: opentelemetry-instrumentation-fastapi>=0.41b0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (0.43b0)\n", - "Requirement already satisfied: opentelemetry-sdk>=1.2.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (1.22.0)\n", - "Requirement already satisfied: pypika>=0.48.9 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (0.48.9)\n", - "Requirement already satisfied: overrides>=7.3.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (7.4.0)\n", - "Requirement already satisfied: importlib-resources in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (6.1.1)\n", - "Requirement already satisfied: bcrypt>=4.0.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (4.1.2)\n", - "Requirement already satisfied: typer>=0.9.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (0.9.0)\n", - "Requirement already satisfied: kubernetes>=28.1.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (28.1.0)\n", - "Requirement already satisfied: tenacity>=8.2.3 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (8.2.3)\n", - "Requirement already satisfied: PyYAML>=6.0.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (6.0.1)\n", - "Requirement already satisfied: mmh3>=4.0.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (4.0.1)\n", - "Requirement already satisfied: decorator in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (5.1.1)\n", - "Requirement already satisfied: jedi>=0.16 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (0.19.1)\n", - "Requirement already satisfied: matplotlib-inline in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (0.1.6)\n", - "Requirement already satisfied: prompt-toolkit<3.1.0,>=3.0.41 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (3.0.43)\n", - "Requirement already satisfied: pygments>=2.4.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (2.17.2)\n", - "Requirement already satisfied: stack-data in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (0.6.3)\n", - "Requirement already satisfied: traitlets>=5 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (5.14.1)\n", - "Requirement already satisfied: pexpect>4.3 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (4.9.0)\n", - "Requirement already satisfied: transformers<5.0.0,>=4.6.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from sentence-transformers->pyautogen[retrievechat]>=0.2.3) (4.33.3)\n", - "Requirement already satisfied: torch>=1.6.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from sentence-transformers->pyautogen[retrievechat]>=0.2.3) (2.1.2)\n", - "Requirement already satisfied: torchvision in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from sentence-transformers->pyautogen[retrievechat]>=0.2.3) (0.16.2)\n", - "Requirement already satisfied: nltk in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from sentence-transformers->pyautogen[retrievechat]>=0.2.3) (3.8.1)\n", - "Requirement already satisfied: sentencepiece in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from sentence-transformers->pyautogen[retrievechat]>=0.2.3) (0.1.99)\n", - "Requirement already satisfied: huggingface-hub>=0.4.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from sentence-transformers->pyautogen[retrievechat]>=0.2.3) (0.20.1)\n", - "Requirement already satisfied: regex>=2022.1.18 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from tiktoken->pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (2023.12.25)\n", - "Requirement already satisfied: starlette<0.33.0,>=0.29.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from fastapi>=0.95.2->chromadb->pyautogen[retrievechat]>=0.2.3) (0.32.0.post1)\n", - "Requirement already satisfied: hyperframe<7,>=6.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from h2<5,>=3->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (6.0.1)\n", - "Requirement already satisfied: hpack<5,>=4.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from h2<5,>=3->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (4.0.0)\n", - "Requirement already satisfied: filelock in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (3.13.1)\n", - "Requirement already satisfied: fsspec>=2023.5.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (2023.12.2)\n", - "Requirement already satisfied: packaging>=20.9 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (23.2)\n", - "Requirement already satisfied: parso<0.9.0,>=0.8.3 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from jedi>=0.16->ipython->pyautogen[retrievechat]>=0.2.3) (0.8.3)\n", - "Requirement already satisfied: six>=1.9.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.16.0)\n", - "Requirement already satisfied: google-auth>=1.0.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (2.25.2)\n", - "Requirement already satisfied: websocket-client!=0.40.0,!=0.41.*,!=0.42.*,>=0.32.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.7.0)\n", - "Requirement already satisfied: requests-oauthlib in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.3.1)\n", - "Requirement already satisfied: oauthlib>=3.2.2 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (3.2.2)\n", - "Requirement already satisfied: coloredlogs in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from onnxruntime<2.0,>=1.15->fastembed==0.1.1->qdrant_client[fastembed]) (15.0.1)\n", - "Requirement already satisfied: flatbuffers in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from onnxruntime<2.0,>=1.15->fastembed==0.1.1->qdrant_client[fastembed]) (23.5.26)\n", - "Requirement already satisfied: sympy in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from onnxruntime<2.0,>=1.15->fastembed==0.1.1->qdrant_client[fastembed]) (1.12)\n", - "Requirement already satisfied: deprecated>=1.2.6 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from opentelemetry-api>=1.2.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.2.14)\n", - "Requirement already satisfied: importlib-metadata<7.0,>=6.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from opentelemetry-api>=1.2.0->chromadb->pyautogen[retrievechat]>=0.2.3) (6.11.0)\n", - "Requirement already satisfied: backoff<3.0.0,>=1.10.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb->pyautogen[retrievechat]>=0.2.3) (2.2.1)\n", - "Requirement already satisfied: googleapis-common-protos~=1.52 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.62.0)\n", - "Requirement already satisfied: opentelemetry-exporter-otlp-proto-common==1.22.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.22.0)\n", - "Requirement already satisfied: opentelemetry-proto==1.22.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.22.0)\n", - "Requirement already satisfied: opentelemetry-instrumentation-asgi==0.43b0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb->pyautogen[retrievechat]>=0.2.3) (0.43b0)\n", - "Requirement already satisfied: opentelemetry-instrumentation==0.43b0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb->pyautogen[retrievechat]>=0.2.3) (0.43b0)\n", - "Requirement already satisfied: opentelemetry-semantic-conventions==0.43b0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb->pyautogen[retrievechat]>=0.2.3) (0.43b0)\n", - "Requirement already satisfied: opentelemetry-util-http==0.43b0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb->pyautogen[retrievechat]>=0.2.3) (0.43b0)\n", - "Requirement already satisfied: wrapt<2.0.0,>=1.0.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from opentelemetry-instrumentation==0.43b0->opentelemetry-instrumentation-fastapi>=0.41b0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.16.0)\n", - "Requirement already satisfied: asgiref~=3.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from opentelemetry-instrumentation-asgi==0.43b0->opentelemetry-instrumentation-fastapi>=0.41b0->chromadb->pyautogen[retrievechat]>=0.2.3) (3.7.2)\n", - "Requirement already satisfied: ptyprocess>=0.5 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pexpect>4.3->ipython->pyautogen[retrievechat]>=0.2.3) (0.7.0)\n", - "Requirement already satisfied: monotonic>=1.5 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from posthog>=2.4.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.6)\n", - "Requirement already satisfied: wcwidth in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from prompt-toolkit<3.1.0,>=3.0.41->ipython->pyautogen[retrievechat]>=0.2.3) (0.2.12)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from requests<3.0,>=2.31->fastembed==0.1.1->qdrant_client[fastembed]) (3.3.2)\n", - "Requirement already satisfied: networkx in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (3.2.1)\n", - "Requirement already satisfied: jinja2 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (3.1.2)\n", - "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (12.1.105)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (12.1.105)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (12.1.105)\n", - "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (8.9.2.26)\n", - "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (12.1.3.1)\n", - "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (11.0.2.54)\n", - "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (10.3.2.106)\n", - "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (11.4.5.107)\n", - "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (12.1.0.106)\n", - "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (2.18.1)\n", - "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (12.1.105)\n", - "Requirement already satisfied: triton==2.1.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (2.1.0)\n", - "Requirement already satisfied: nvidia-nvjitlink-cu12 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (12.3.101)\n", - "Requirement already satisfied: safetensors>=0.3.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from transformers<5.0.0,>=4.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (0.4.1)\n", - "Requirement already satisfied: click<9.0.0,>=7.1.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from typer>=0.9.0->chromadb->pyautogen[retrievechat]>=0.2.3) (8.1.7)\n", - "Requirement already satisfied: httptools>=0.5.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from uvicorn[standard]>=0.18.3->chromadb->pyautogen[retrievechat]>=0.2.3) (0.6.1)\n", - "Requirement already satisfied: uvloop!=0.15.0,!=0.15.1,>=0.14.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from uvicorn[standard]>=0.18.3->chromadb->pyautogen[retrievechat]>=0.2.3) (0.19.0)\n", - "Requirement already satisfied: watchfiles>=0.13 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from uvicorn[standard]>=0.18.3->chromadb->pyautogen[retrievechat]>=0.2.3) (0.21.0)\n", - "Requirement already satisfied: websockets>=10.4 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from uvicorn[standard]>=0.18.3->chromadb->pyautogen[retrievechat]>=0.2.3) (12.0)\n", - "Requirement already satisfied: executing>=1.2.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from stack-data->ipython->pyautogen[retrievechat]>=0.2.3) (2.0.1)\n", - "Requirement already satisfied: asttokens>=2.1.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from stack-data->ipython->pyautogen[retrievechat]>=0.2.3) (2.4.1)\n", - "Requirement already satisfied: pure-eval in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from stack-data->ipython->pyautogen[retrievechat]>=0.2.3) (0.2.2)\n", - "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from torchvision->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (10.2.0)\n", - "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from google-auth>=1.0.1->kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (5.3.2)\n", - "Requirement already satisfied: pyasn1-modules>=0.2.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from google-auth>=1.0.1->kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (0.3.0)\n", - "Requirement already satisfied: rsa<5,>=3.1.4 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from google-auth>=1.0.1->kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (4.9)\n", - "Requirement already satisfied: zipp>=0.5 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from importlib-metadata<7.0,>=6.0->opentelemetry-api>=1.2.0->chromadb->pyautogen[retrievechat]>=0.2.3) (3.17.0)\n", - "Requirement already satisfied: humanfriendly>=9.1 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from coloredlogs->onnxruntime<2.0,>=1.15->fastembed==0.1.1->qdrant_client[fastembed]) (10.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from jinja2->torch>=1.6.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (2.1.3)\n", - "Requirement already satisfied: mpmath>=0.19 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from sympy->onnxruntime<2.0,>=1.15->fastembed==0.1.1->qdrant_client[fastembed]) (1.3.0)\n", - "Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /workspaces/autogen/.venv-3.11/lib/python3.11/site-packages (from pyasn1-modules>=0.2.1->google-auth>=1.0.1->kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (0.5.1)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "Requirement already satisfied: pyautogen>=0.2.3 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pyautogen[retrievechat]>=0.2.3) (0.2.21)\n", + "Requirement already satisfied: flaml[automl] in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (2.1.2)\n", + "Requirement already satisfied: qdrant_client[fastembed] in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (1.6.4)\n", + "Requirement already satisfied: openai>=1.3 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (1.12.0)\n", + "Requirement already satisfied: diskcache in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (5.6.3)\n", + "Requirement already satisfied: termcolor in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (2.3.0)\n", + "Requirement already satisfied: numpy<2,>=1.17.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (1.26.4)\n", + "Requirement already satisfied: python-dotenv in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (1.0.0)\n", + "Requirement already satisfied: tiktoken in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (0.5.1)\n", + "Requirement already satisfied: pydantic!=2.6.0,<3,>=1.10 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (2.6.4)\n", + "Requirement already satisfied: docker in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (7.0.0)\n", + "Requirement already satisfied: lightgbm>=2.3.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from flaml[automl]) (4.1.0)\n", + "Requirement already satisfied: xgboost>=0.90 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from flaml[automl]) (2.0.1)\n", + "Requirement already satisfied: scipy>=1.4.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from flaml[automl]) (1.10.1)\n", + "Requirement already satisfied: pandas>=1.1.4 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from flaml[automl]) (2.2.0)\n", + "Requirement already satisfied: scikit-learn>=0.24 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from flaml[automl]) (1.3.2)\n", + "Requirement already satisfied: grpcio>=1.41.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from qdrant_client[fastembed]) (1.60.0)\n", + "Requirement already satisfied: grpcio-tools>=1.41.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from qdrant_client[fastembed]) (1.59.2)\n", + "Requirement already satisfied: httpx>=0.14.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from httpx[http2]>=0.14.0->qdrant_client[fastembed]) (0.25.1)\n", + "Requirement already satisfied: portalocker<3.0.0,>=2.7.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from qdrant_client[fastembed]) (2.8.2)\n", + "Requirement already satisfied: urllib3<2.0.0,>=1.26.14 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from qdrant_client[fastembed]) (1.26.18)\n", + "Requirement already satisfied: fastembed==0.1.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from qdrant_client[fastembed]) (0.1.1)\n", + "Requirement already satisfied: onnx<2.0,>=1.11 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from fastembed==0.1.1->qdrant_client[fastembed]) (1.15.0)\n", + "Requirement already satisfied: onnxruntime<2.0,>=1.15 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from fastembed==0.1.1->qdrant_client[fastembed]) (1.15.1)\n", + "Requirement already satisfied: requests<3.0,>=2.31 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from fastembed==0.1.1->qdrant_client[fastembed]) (2.31.0)\n", + "Requirement already satisfied: tokenizers<0.14,>=0.13 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from fastembed==0.1.1->qdrant_client[fastembed]) (0.13.3)\n", + "Requirement already satisfied: tqdm<5.0,>=4.65 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from fastembed==0.1.1->qdrant_client[fastembed]) (4.66.2)\n", + "Requirement already satisfied: chromadb in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pyautogen[retrievechat]>=0.2.3) (0.4.22)\n", + "Requirement already satisfied: sentence-transformers in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pyautogen[retrievechat]>=0.2.3) (2.3.1)\n", + "Requirement already satisfied: pypdf in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pyautogen[retrievechat]>=0.2.3) (4.0.1)\n", + "Requirement already satisfied: ipython in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pyautogen[retrievechat]>=0.2.3) (8.17.2)\n", + "Requirement already satisfied: protobuf<5.0dev,>=4.21.6 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from grpcio-tools>=1.41.0->qdrant_client[fastembed]) (4.23.4)\n", + "Requirement already satisfied: setuptools in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from grpcio-tools>=1.41.0->qdrant_client[fastembed]) (68.2.2)\n", + "Requirement already satisfied: anyio in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from httpx>=0.14.0->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (3.7.1)\n", + "Requirement already satisfied: certifi in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from httpx>=0.14.0->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (2024.2.2)\n", + "Requirement already satisfied: httpcore in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from httpx>=0.14.0->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (1.0.1)\n", + "Requirement already satisfied: idna in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from httpx>=0.14.0->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (3.6)\n", + "Requirement already satisfied: sniffio in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from httpx>=0.14.0->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (1.3.0)\n", + "Requirement already satisfied: h2<5,>=3 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from httpx[http2]>=0.14.0->qdrant_client[fastembed]) (4.1.0)\n", + "Requirement already satisfied: distro<2,>=1.7.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from openai>=1.3->pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (1.8.0)\n", + "Requirement already satisfied: typing-extensions<5,>=4.7 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from openai>=1.3->pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (4.9.0)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pandas>=1.1.4->flaml[automl]) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pandas>=1.1.4->flaml[automl]) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.7 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pandas>=1.1.4->flaml[automl]) (2024.1)\n", + "Requirement already satisfied: annotated-types>=0.4.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pydantic!=2.6.0,<3,>=1.10->pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (0.6.0)\n", + "Requirement already satisfied: pydantic-core==2.16.3 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pydantic!=2.6.0,<3,>=1.10->pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (2.16.3)\n", + "Requirement already satisfied: joblib>=1.1.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from scikit-learn>=0.24->flaml[automl]) (1.3.2)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from scikit-learn>=0.24->flaml[automl]) (3.2.0)\n", + "Requirement already satisfied: build>=1.0.3 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (1.0.3)\n", + "Requirement already satisfied: chroma-hnswlib==0.7.3 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (0.7.3)\n", + "Requirement already satisfied: fastapi>=0.95.2 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (0.104.1)\n", + "Requirement already satisfied: uvicorn>=0.18.3 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from uvicorn[standard]>=0.18.3->chromadb->pyautogen[retrievechat]>=0.2.3) (0.24.0)\n", + "Requirement already satisfied: posthog>=2.4.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (3.0.2)\n", + "Requirement already satisfied: pulsar-client>=3.1.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (3.3.0)\n", + "Requirement already satisfied: opentelemetry-api>=1.2.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (1.20.0)\n", + "Requirement already satisfied: opentelemetry-exporter-otlp-proto-grpc>=1.2.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (1.20.0)\n", + "Requirement already satisfied: opentelemetry-instrumentation-fastapi>=0.41b0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (0.41b0)\n", + "Requirement already satisfied: opentelemetry-sdk>=1.2.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (1.20.0)\n", + "Requirement already satisfied: pypika>=0.48.9 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (0.48.9)\n", + "Requirement already satisfied: overrides>=7.3.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (7.4.0)\n", + "Requirement already satisfied: importlib-resources in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (6.1.1)\n", + "Requirement already satisfied: bcrypt>=4.0.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (4.0.1)\n", + "Requirement already satisfied: typer>=0.9.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (0.9.0)\n", + "Requirement already satisfied: kubernetes>=28.1.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (28.1.0)\n", + "Requirement already satisfied: tenacity>=8.2.3 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (8.2.3)\n", + "Requirement already satisfied: PyYAML>=6.0.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (6.0.1)\n", + "Requirement already satisfied: mmh3>=4.0.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from chromadb->pyautogen[retrievechat]>=0.2.3) (4.0.1)\n", + "Requirement already satisfied: packaging>=14.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from docker->pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (23.2)\n", + "Requirement already satisfied: decorator in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (5.1.1)\n", + "Requirement already satisfied: jedi>=0.16 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (0.19.1)\n", + "Requirement already satisfied: matplotlib-inline in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (0.1.6)\n", + "Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (3.0.39)\n", + "Requirement already satisfied: pygments>=2.4.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (2.16.1)\n", + "Requirement already satisfied: stack-data in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (0.6.3)\n", + "Requirement already satisfied: traitlets>=5 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (5.14.2)\n", + "Requirement already satisfied: exceptiongroup in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (1.1.3)\n", + "Requirement already satisfied: pexpect>4.3 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from ipython->pyautogen[retrievechat]>=0.2.3) (4.8.0)\n", + "Requirement already satisfied: transformers<5.0.0,>=4.32.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from sentence-transformers->pyautogen[retrievechat]>=0.2.3) (4.33.3)\n", + "Requirement already satisfied: torch>=1.11.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from sentence-transformers->pyautogen[retrievechat]>=0.2.3) (2.2.0)\n", + "Requirement already satisfied: nltk in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from sentence-transformers->pyautogen[retrievechat]>=0.2.3) (3.8.1)\n", + "Requirement already satisfied: sentencepiece in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from sentence-transformers->pyautogen[retrievechat]>=0.2.3) (0.1.99)\n", + "Requirement already satisfied: huggingface-hub>=0.15.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from sentence-transformers->pyautogen[retrievechat]>=0.2.3) (0.20.3)\n", + "Requirement already satisfied: Pillow in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from sentence-transformers->pyautogen[retrievechat]>=0.2.3) (10.2.0)\n", + "Requirement already satisfied: regex>=2022.1.18 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from tiktoken->pyautogen>=0.2.3->pyautogen[retrievechat]>=0.2.3) (2023.12.25)\n", + "Requirement already satisfied: pyproject_hooks in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from build>=1.0.3->chromadb->pyautogen[retrievechat]>=0.2.3) (1.0.0)\n", + "Requirement already satisfied: tomli>=1.1.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from build>=1.0.3->chromadb->pyautogen[retrievechat]>=0.2.3) (2.0.1)\n", + "Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from fastapi>=0.95.2->chromadb->pyautogen[retrievechat]>=0.2.3) (0.27.0)\n", + "Requirement already satisfied: hyperframe<7,>=6.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from h2<5,>=3->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (6.0.1)\n", + "Requirement already satisfied: hpack<5,>=4.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from h2<5,>=3->httpx[http2]>=0.14.0->qdrant_client[fastembed]) (4.0.0)\n", + "Requirement already satisfied: filelock in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from huggingface-hub>=0.15.1->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (3.13.1)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from huggingface-hub>=0.15.1->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (2024.2.0)\n", + "Requirement already satisfied: parso<0.9.0,>=0.8.3 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from jedi>=0.16->ipython->pyautogen[retrievechat]>=0.2.3) (0.8.3)\n", + "Requirement already satisfied: six>=1.9.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.16.0)\n", + "Requirement already satisfied: google-auth>=1.0.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (2.23.4)\n", + "Requirement already satisfied: websocket-client!=0.40.0,!=0.41.*,!=0.42.*,>=0.32.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.6.4)\n", + "Requirement already satisfied: requests-oauthlib in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.3.1)\n", + "Requirement already satisfied: oauthlib>=3.2.2 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (3.2.2)\n", + "Requirement already satisfied: coloredlogs in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from onnxruntime<2.0,>=1.15->fastembed==0.1.1->qdrant_client[fastembed]) (15.0.1)\n", + "Requirement already satisfied: flatbuffers in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from onnxruntime<2.0,>=1.15->fastembed==0.1.1->qdrant_client[fastembed]) (23.5.26)\n", + "Requirement already satisfied: sympy in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from onnxruntime<2.0,>=1.15->fastembed==0.1.1->qdrant_client[fastembed]) (1.12)\n", + "Requirement already satisfied: deprecated>=1.2.6 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from opentelemetry-api>=1.2.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.2.14)\n", + "Requirement already satisfied: importlib-metadata<7.0,>=6.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from opentelemetry-api>=1.2.0->chromadb->pyautogen[retrievechat]>=0.2.3) (6.11.0)\n", + "Requirement already satisfied: backoff<3.0.0,>=1.10.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb->pyautogen[retrievechat]>=0.2.3) (2.2.1)\n", + "Requirement already satisfied: googleapis-common-protos~=1.52 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.61.0)\n", + "Requirement already satisfied: opentelemetry-exporter-otlp-proto-common==1.20.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.20.0)\n", + "Requirement already satisfied: opentelemetry-proto==1.20.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.20.0)\n", + "Requirement already satisfied: opentelemetry-instrumentation-asgi==0.41b0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb->pyautogen[retrievechat]>=0.2.3) (0.41b0)\n", + "Requirement already satisfied: opentelemetry-instrumentation==0.41b0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb->pyautogen[retrievechat]>=0.2.3) (0.41b0)\n", + "Requirement already satisfied: opentelemetry-semantic-conventions==0.41b0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb->pyautogen[retrievechat]>=0.2.3) (0.41b0)\n", + "Requirement already satisfied: opentelemetry-util-http==0.41b0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb->pyautogen[retrievechat]>=0.2.3) (0.41b0)\n", + "Requirement already satisfied: wrapt<2.0.0,>=1.0.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from opentelemetry-instrumentation==0.41b0->opentelemetry-instrumentation-fastapi>=0.41b0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.16.0)\n", + "Requirement already satisfied: asgiref~=3.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from opentelemetry-instrumentation-asgi==0.41b0->opentelemetry-instrumentation-fastapi>=0.41b0->chromadb->pyautogen[retrievechat]>=0.2.3) (3.7.2)\n", + "Requirement already satisfied: ptyprocess>=0.5 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pexpect>4.3->ipython->pyautogen[retrievechat]>=0.2.3) (0.7.0)\n", + "Requirement already satisfied: monotonic>=1.5 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from posthog>=2.4.0->chromadb->pyautogen[retrievechat]>=0.2.3) (1.6)\n", + "Requirement already satisfied: wcwidth in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython->pyautogen[retrievechat]>=0.2.3) (0.2.9)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from requests<3.0,>=2.31->fastembed==0.1.1->qdrant_client[fastembed]) (3.3.2)\n", + "Requirement already satisfied: networkx in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (3.2.1)\n", + "Requirement already satisfied: jinja2 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (3.1.3)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (2.19.3)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (12.1.105)\n", + "Requirement already satisfied: triton==2.2.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (2.2.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (12.3.52)\n", + "Requirement already satisfied: safetensors>=0.3.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from transformers<5.0.0,>=4.32.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (0.3.2)\n", + "Requirement already satisfied: click<9.0.0,>=7.1.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from typer>=0.9.0->chromadb->pyautogen[retrievechat]>=0.2.3) (8.1.7)\n", + "Requirement already satisfied: h11>=0.8 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from uvicorn>=0.18.3->uvicorn[standard]>=0.18.3->chromadb->pyautogen[retrievechat]>=0.2.3) (0.14.0)\n", + "Requirement already satisfied: httptools>=0.5.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from uvicorn[standard]>=0.18.3->chromadb->pyautogen[retrievechat]>=0.2.3) (0.6.1)\n", + "Requirement already satisfied: uvloop!=0.15.0,!=0.15.1,>=0.14.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from uvicorn[standard]>=0.18.3->chromadb->pyautogen[retrievechat]>=0.2.3) (0.19.0)\n", + "Requirement already satisfied: watchfiles>=0.13 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from uvicorn[standard]>=0.18.3->chromadb->pyautogen[retrievechat]>=0.2.3) (0.21.0)\n", + "Requirement already satisfied: websockets>=10.4 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from uvicorn[standard]>=0.18.3->chromadb->pyautogen[retrievechat]>=0.2.3) (11.0.3)\n", + "Requirement already satisfied: executing>=1.2.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from stack-data->ipython->pyautogen[retrievechat]>=0.2.3) (2.0.1)\n", + "Requirement already satisfied: asttokens>=2.1.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from stack-data->ipython->pyautogen[retrievechat]>=0.2.3) (2.4.1)\n", + "Requirement already satisfied: pure-eval in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from stack-data->ipython->pyautogen[retrievechat]>=0.2.3) (0.2.2)\n", + "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from google-auth>=1.0.1->kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (5.3.2)\n", + "Requirement already satisfied: pyasn1-modules>=0.2.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from google-auth>=1.0.1->kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (0.3.0)\n", + "Requirement already satisfied: rsa<5,>=3.1.4 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from google-auth>=1.0.1->kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (4.9)\n", + "Requirement already satisfied: zipp>=0.5 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from importlib-metadata<7.0,>=6.0->opentelemetry-api>=1.2.0->chromadb->pyautogen[retrievechat]>=0.2.3) (3.17.0)\n", + "Requirement already satisfied: humanfriendly>=9.1 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from coloredlogs->onnxruntime<2.0,>=1.15->fastembed==0.1.1->qdrant_client[fastembed]) (10.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from jinja2->torch>=1.11.0->sentence-transformers->pyautogen[retrievechat]>=0.2.3) (2.1.5)\n", + "Requirement already satisfied: mpmath>=0.19 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from sympy->onnxruntime<2.0,>=1.15->fastembed==0.1.1->qdrant_client[fastembed]) (1.3.0)\n", + "Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages (from pyasn1-modules>=0.2.1->google-auth>=1.0.1->kubernetes>=28.1.0->chromadb->pyautogen[retrievechat]>=0.2.3) (0.5.0)\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } @@ -203,14 +211,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "models to use: ['gpt-4-1106-preview', 'gpt-4-turbo-preview', 'gpt-4-0613', 'gpt-35-turbo-0613', 'gpt-35-turbo-1106']\n" + "models to use: ['gpt4-1106-preview', 'gpt-35-turbo', 'gpt-35-turbo-0613']\n" ] } ], @@ -225,20 +233,7 @@ "# a vector database instance\n", "from autogen.retrieve_utils import TEXT_FORMATS\n", "\n", - "config_list = autogen.config_list_from_json(\n", - " env_or_file=\"OAI_CONFIG_LIST\",\n", - " file_location=\".\",\n", - " filter_dict={\n", - " \"model\": {\n", - " \"gpt-4\",\n", - " \"gpt4\",\n", - " \"gpt-4-32k\",\n", - " \"gpt-4-32k-0314\",\n", - " \"gpt-35-turbo\",\n", - " \"gpt-3.5-turbo\",\n", - " }\n", - " },\n", - ")\n", + "config_list = autogen.config_list_from_json(\"OAI_CONFIG_LIST\")\n", "\n", "assert len(config_list) > 0\n", "print(\"models to use: \", [config_list[i][\"model\"] for i in range(len(config_list))])" @@ -258,7 +253,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -266,7 +261,7 @@ "output_type": "stream", "text": [ "Accepted file formats for `docs_path`:\n", - "['txt', 'json', 'csv', 'tsv', 'md', 'html', 'htm', 'rtf', 'rst', 'jsonl', 'log', 'xml', 'yaml', 'yml', 'pdf']\n" + "['yml', 'ppt', 'org', 'doc', 'epub', 'rst', 'log', 'docx', 'htm', 'html', 'tsv', 'csv', 'json', 'yaml', 'xlsx', 'pptx', 'rtf', 'msg', 'odt', 'pdf', 'jsonl', 'md', 'xml', 'txt']\n" ] } ], @@ -289,7 +284,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -332,8 +327,7 @@ " \"client\": QdrantClient(\":memory:\"),\n", " \"embedding_model\": \"BAAI/bge-small-en-v1.5\",\n", " },\n", - " # code_execution_config={\n", - " # \"use_docker\": False,}\n", + " code_execution_config=False,\n", ")" ] }, @@ -354,17 +348,42 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Trying to create collection.\n", - "\u001b[32mAdding doc_id 0 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id 2 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id 1 to context.\u001b[0m\n", + "Trying to create collection.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-07 18:30:12,489 - autogen.agentchat.contrib.qdrant_retrieve_user_proxy_agent - INFO - Found 3 chunks.\u001b[0m\n", + "Model gpt4-1106-preview not found. Using cl100k_base encoding.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32mAdding content of doc 0 to context.\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Model gpt4-1106-preview not found. Using cl100k_base encoding.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", @@ -385,8 +404,8 @@ "![Python Version](https://img.shields.io/badge/3.8%20%7C%203.9%20%7C%203.10-blue)\n", "[![Downloads](https://pepy.tech/badge/flaml)](https://pepy.tech/project/flaml)\n", "[![](https://img.shields.io/discord/1025786666260111483?logo=discord&style=flat)](https://discord.gg/Cppx2vSPVP)\n", - "\n", "\n", + "\n", "\n", "# A Fast Library for Automated Machine Learning & Tuning\n", "\n", @@ -405,15 +424,15 @@ "\n", ":fire: FLAML supports Code-First AutoML & Tuning – Private Preview in [Microsoft Fabric Data Science](https://learn.microsoft.com/en-us/fabric/data-science/).\n", "\n", - "\n", "## What is FLAML\n", + "\n", "FLAML is a lightweight Python library for efficient automation of machine\n", "learning and AI operations. It automates workflow based on large language models, machine learning models, etc.\n", "and optimizes their performance.\n", "\n", - "* FLAML enables building next-gen GPT-X applications based on multi-agent conversations with minimal effort. It simplifies the orchestration, automation and optimization of a complex GPT-X workflow. It maximizes the performance of GPT-X models and augments their weakness.\n", - "* For common machine learning tasks like classification and regression, it quickly finds quality models for user-provided data with low computational resources. It is easy to customize or extend. Users can find their desired customizability from a smooth range.\n", - "* It supports fast and economical automatic tuning (e.g., inference hyperparameters for foundation models, configurations in MLOps/LMOps workflows, pipelines, mathematical/statistical models, algorithms, computing experiments, software configurations), capable of handling large search space with heterogeneous evaluation cost and complex constraints/guidance/early stopping.\n", + "- FLAML enables building next-gen GPT-X applications based on multi-agent conversations with minimal effort. It simplifies the orchestration, automation and optimization of a complex GPT-X workflow. It maximizes the performance of GPT-X models and augments their weakness.\n", + "- For common machine learning tasks like classification and regression, it quickly finds quality models for user-provided data with low computational resources. It is easy to customize or extend. Users can find their desired customizability from a smooth range.\n", + "- It supports fast and economical automatic tuning (e.g., inference hyperparameters for foundation models, configurations in MLOps/LMOps workflows, pipelines, mathematical/statistical models, algorithms, computing experiments, software configurations), capable of handling large search space with heterogeneous evaluation cost and complex constraints/guidance/early stopping.\n", "\n", "FLAML is powered by a series of [research studies](https://microsoft.github.io/FLAML/docs/Research/) from Microsoft Research and collaborators such as Penn State University, Stevens Institute of Technology, University of Washington, and University of Waterloo.\n", "\n", @@ -428,6 +447,7 @@ "```\n", "\n", "Minimal dependencies are installed without extra options. You can install extra options based on the feature you need. For example, use the following to install the dependencies needed by the [`autogen`](https://microsoft.github.io/autogen/) package.\n", + "\n", "```bash\n", "pip install \"flaml[autogen]\"\n", "```\n", @@ -437,18 +457,24 @@ "\n", "## Quickstart\n", "\n", - "* (New) The [autogen](https://microsoft.github.io/autogen/) package enables the next-gen GPT-X applications with a generic multi-agent conversation framework.\n", - "It offers customizable and conversable agents which integrate LLMs, tools and human.\n", - "By automating chat among multiple capable agents, one can easily make them collectively perform tasks autonomously or with human feedback, including tasks that require using tools via code. For example,\n", + "- (New) The [autogen](https://microsoft.github.io/autogen/) package enables the next-gen GPT-X applications with a generic multi-agent conversation framework.\n", + " It offers customizable and conversable agents which integrate LLMs, tools and human.\n", + " By automating chat among multiple capable agents, one can easily make them collectively perform tasks autonomously or with human feedback, including tasks that require using tools via code. For example,\n", + "\n", "```python\n", "from flaml import autogen\n", + "\n", "assistant = autogen.AssistantAgent(\"assistant\")\n", "user_proxy = autogen.UserProxyAgent(\"user_proxy\")\n", - "user_proxy.initiate_chat(assistant, message=\"Show me the YTD gain of 10 largest technology companies as of today.\")\n", + "user_proxy.initiate_chat(\n", + " assistant,\n", + " message=\"Show me the YTD gain of 10 largest technology companies as of today.\",\n", + ")\n", "# This initiates an automated chat between the two agents to solve the task\n", "```\n", "\n", "Autogen also helps maximize the utility out of the expensive LLMs such as ChatGPT and GPT-4. It offers a drop-in replacement of `openai.Completion` or `openai.ChatCompletion` with powerful functionalites like tuning, caching, templating, filtering. For example, you can optimize generations by LLM with your own tuning data, success metrics and budgets.\n", + "\n", "```python\n", "# perform tuning\n", "config, analysis = autogen.Completion.tune(\n", @@ -463,30 +489,32 @@ "# perform inference for a test instance\n", "response = autogen.Completion.create(context=test_instance, **config)\n", "```\n", - "* With three lines of code, you can start using this economical and fast\n", - "AutoML engine as a [scikit-learn style estimator](https://microsoft.github.io/FLAML/docs/Use-Cases/Task-Oriented-AutoML).\n", + "\n", + "- With three lines of code, you can start using this economical and fast\n", + " AutoML engine as a [scikit-learn style estimator](https://microsoft.github.io/FLAML/docs/Use-Cases/Task-Oriented-AutoML).\n", "\n", "```python\n", "from flaml import AutoML\n", + "\n", "automl = AutoML()\n", "automl.fit(X_train, y_train, task=\"classification\")\n", "```\n", "\n", - "* You can restrict the learners and use FLAML as a fast hyperparameter tuning\n", - "tool for XGBoost, LightGBM, Random Forest etc. or a [customized learner](https://microsoft.github.io/FLAML/docs/Use-Cases/Task-Oriented-AutoML#estimator-and-search-space).\n", + "- You can restrict the learners and use FLAML as a fast hyperparameter tuning\n", + " tool for XGBoost, LightGBM, Random Forest etc. or a [customized learner](https://microsoft.github.io/FLAML/docs/Use-Cases/Task-Oriented-AutoML#estimator-and-search-space).\n", "\n", "```python\n", "automl.fit(X_train, y_train, task=\"classification\", estimator_list=[\"lgbm\"])\n", "```\n", "\n", - "* You can also run generic hyperparameter tuning for a [custom function](https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function).\n", + "- You can also run generic hyperparameter tuning for a [custom function](https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function).\n", "\n", "```python\n", "from flaml import tune\n", "tune.run(evaluation_function, config={…}, low_cost_partial_config={…}, time_budget_s=3600)\n", "```\n", "\n", - "* [Zero-shot AutoML](https://microsoft.github.io/FLAML/docs/Use-Cases/Zero-Shot-AutoML) allows using the existing training API from lightgbm, xgboost etc. while getting the benefit of AutoML in choosing high-performance hyperparameter configurations per task.\n", + "- [Zero-shot AutoML](https://microsoft.github.io/FLAML/docs/Use-Cases/Zero-Shot-AutoML) allows using the existing training API from lightgbm, xgboost etc. while getting the benefit of AutoML in choosing high-performance hyperparameter configurations per task.\n", "\n", "```python\n", "from flaml.default import LGBMRegressor\n", @@ -517,12 +545,98 @@ "Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us\n", "the rights to use your contribution. For details, visit .\n", "\n", - "If you are new to GitHub [here](https://help.github.com/categories/collaborating-with-issues-and-pull-requests/) is a detailed help source on getting involved with development on GitHub.\n", - "# Research\n", + "\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", + "\n", + "No, there is no function called `tune_automl` specifically mentioned in the context provided. However, FLAML does offer general hyperparameter tuning capabilities which could be related to this. In the context of FLAML, there is a generic function called `tune.run()` that can be used for hyperparameter tuning.\n", + "\n", + "Here's a short example of how to use FLAML's tune for a user-defined function based on the given context:\n", + "\n", + "```python\n", + "from flaml import tune\n", + "\n", + "def evaluation_function(config):\n", + " # evaluation logic that returns a metric score\n", + " ...\n", + "\n", + "# define the search space for hyperparameters\n", + "config_search_space = {\n", + " 'max_depth': tune.randint(lower=3, upper=10),\n", + " 'learning_rate': tune.loguniform(lower=1e-4, upper=1e-1),\n", + "}\n", + "\n", + "# run hyperparameter tuning\n", + "tune.run(\n", + " evaluation_function,\n", + " config=config_search_space,\n", + " low_cost_partial_config={'max_depth': 3},\n", + " time_budget_s=3600\n", + ")\n", + "```\n", + "\n", + "Please note that if you are referring to a different kind of function or use case, you might need to specify more details or check the official documentation or source code of the FLAML library.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", + "\n", + "\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", + "\n", + "UPDATE CONTEXT\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[32mUpdating context and resetting conversation.\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Model gpt4-1106-preview not found. Using cl100k_base encoding.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32mAdding content of doc 2 to context.\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Model gpt4-1106-preview not found. Using cl100k_base encoding.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32mAdding content of doc 1 to context.\u001b[0m\n", + "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", + "\n", + "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", + "context provided by the user.\n", + "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n", + "For code generation, you must obey the following rules:\n", + "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n", + "Rule 2. You must follow the formats below to write your code:\n", + "```language\n", + "# your code\n", + "```\n", + "\n", + "User's question is: Is there a function called tune_automl?\n", + "\n", + "Context is: # Research\n", "\n", "For technical details, please check our research publications.\n", "\n", - "* [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\n", + "- [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\n", "\n", "```bibtex\n", "@inproceedings{wang2021flaml,\n", @@ -533,7 +647,7 @@ "}\n", "```\n", "\n", - "* [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\n", + "- [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\n", "\n", "```bibtex\n", "@inproceedings{wu2021cfo,\n", @@ -544,7 +658,7 @@ "}\n", "```\n", "\n", - "* [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\n", + "- [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\n", "\n", "```bibtex\n", "@inproceedings{wang2021blendsearch,\n", @@ -555,7 +669,7 @@ "}\n", "```\n", "\n", - "* [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\n", + "- [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\n", "\n", "```bibtex\n", "@inproceedings{liuwang2021hpolm,\n", @@ -566,7 +680,7 @@ "}\n", "```\n", "\n", - "* [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\n", + "- [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\n", "\n", "```bibtex\n", "@inproceedings{wu2021chacha,\n", @@ -577,7 +691,7 @@ "}\n", "```\n", "\n", - "* [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\n", + "- [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\n", "\n", "```bibtex\n", "@inproceedings{wuwang2021fairautoml,\n", @@ -588,7 +702,7 @@ "}\n", "```\n", "\n", - "* [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\n", + "- [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\n", "\n", "```bibtex\n", "@inproceedings{kayaliwang2022default,\n", @@ -599,7 +713,7 @@ "}\n", "```\n", "\n", - "* [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\n", + "- [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\n", "\n", "```bibtex\n", "@inproceedings{zhang2023targeted,\n", @@ -611,7 +725,7 @@ "}\n", "```\n", "\n", - "* [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\n", + "- [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\n", "\n", "```bibtex\n", "@inproceedings{wang2023EcoOptiGen,\n", @@ -622,7 +736,7 @@ "}\n", "```\n", "\n", - "* [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\n", + "- [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\n", "\n", "```bibtex\n", "@inproceedings{wu2023empirical,\n", @@ -632,7 +746,7 @@ " booktitle={ArXiv preprint arXiv:2306.01337},\n", "}\n", "```\n", - "\n", + "If you are new to GitHub [here](https://help.github.com/categories/collaborating-with-issues-and-pull-requests/) is a detailed help source on getting involved with development on GitHub.\n", "\n", "When you submit a pull request, a CLA bot will automatically determine whether you need to provide\n", "a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions\n", @@ -644,26 +758,10 @@ "\n", "\n", "\n", - "\n", "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "Based on the context provided, which is about the FLAML library, there is no direct reference to a function specifically called `tune_automl`. However, FLAML does offer functionality for automated machine learning (AutoML) and hyperparameter tuning.\n", - "\n", - "The closest reference to an AutoML tuning operation in the given context is shown in the Quickstart section, which demonstrates how to use FLAML as a scikit-learn style estimator for machine learning tasks like classification and regression. It does talk about automated machine learning and tuning, but doesn't mention a function `tune_automl` by name.\n", - "\n", - "If you are looking for a way to perform tuning with FLAML, the context indicates you can use the `tune` module to run generic hyperparameter tuning for a custom function, as demonstrated in the Quickstart section:\n", - "\n", - "```python\n", - "from flaml import tune\n", - "tune.run(evaluation_function, config={…}, low_cost_partial_config={…}, time_budget_s=3600)\n", - "```\n", - "\n", - "This is not called `tune_automl` but rather just `tune.run`.\n", - "\n", - "If you need confirmation on whether a function called `tune_automl` specifically exists, the FLAML documentation or its API reference should be checked. If documentation is not enough to confirm and you require to look into the actual code or a structured list of functionalities provided by FLAML, that information isn't available in the given context.\n", - "\n", - "In that case, the instruction should be: `UPDATE CONTEXT`.\n", + "UPDATE CONTEXT\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[32mUpdating context and resetting conversation.\u001b[0m\n", @@ -678,10 +776,10 @@ { "data": { "text/plain": [ - "ChatResult(chat_id=None, chat_history=[{'content': 'TERMINATE', 'role': 'assistant'}], summary='', cost=({'total_cost': 0.12719999999999998, 'gpt-4': {'cost': 0.12719999999999998, 'prompt_tokens': 3634, 'completion_tokens': 303, 'total_tokens': 3937}}, {'total_cost': 0.12719999999999998, 'gpt-4': {'cost': 0.12719999999999998, 'prompt_tokens': 3634, 'completion_tokens': 303, 'total_tokens': 3937}}), human_input=[])" + "ChatResult(chat_id=None, chat_history=[{'content': 'TERMINATE', 'role': 'assistant'}], summary='', cost=({'total_cost': 0.19977, 'gpt-4': {'cost': 0.19977, 'prompt_tokens': 6153, 'completion_tokens': 253, 'total_tokens': 6406}}, {'total_cost': 0.19977, 'gpt-4': {'cost': 0.19977, 'prompt_tokens': 6153, 'completion_tokens': 253, 'total_tokens': 6406}}), human_input=[])" ] }, - "execution_count": 20, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -711,16 +809,34 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 12, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Model gpt4-1106-preview not found. Using cl100k_base encoding.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32mAdding content of doc 2 to context.\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Model gpt4-1106-preview not found. Using cl100k_base encoding.\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32mAdding doc_id 2 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id 0 to context.\u001b[0m\n", - "\u001b[32mAdding doc_id 1 to context.\u001b[0m\n", "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", "\n", "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", @@ -739,7 +855,7 @@ "\n", "For technical details, please check our research publications.\n", "\n", - "* [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\n", + "- [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\n", "\n", "```bibtex\n", "@inproceedings{wang2021flaml,\n", @@ -750,7 +866,7 @@ "}\n", "```\n", "\n", - "* [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\n", + "- [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\n", "\n", "```bibtex\n", "@inproceedings{wu2021cfo,\n", @@ -761,7 +877,7 @@ "}\n", "```\n", "\n", - "* [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\n", + "- [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\n", "\n", "```bibtex\n", "@inproceedings{wang2021blendsearch,\n", @@ -772,7 +888,7 @@ "}\n", "```\n", "\n", - "* [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\n", + "- [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\n", "\n", "```bibtex\n", "@inproceedings{liuwang2021hpolm,\n", @@ -783,7 +899,7 @@ "}\n", "```\n", "\n", - "* [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\n", + "- [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\n", "\n", "```bibtex\n", "@inproceedings{wu2021chacha,\n", @@ -794,7 +910,7 @@ "}\n", "```\n", "\n", - "* [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\n", + "- [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\n", "\n", "```bibtex\n", "@inproceedings{wuwang2021fairautoml,\n", @@ -805,7 +921,7 @@ "}\n", "```\n", "\n", - "* [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\n", + "- [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\n", "\n", "```bibtex\n", "@inproceedings{kayaliwang2022default,\n", @@ -816,7 +932,7 @@ "}\n", "```\n", "\n", - "* [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\n", + "- [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\n", "\n", "```bibtex\n", "@inproceedings{zhang2023targeted,\n", @@ -828,7 +944,7 @@ "}\n", "```\n", "\n", - "* [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\n", + "- [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\n", "\n", "```bibtex\n", "@inproceedings{wang2023EcoOptiGen,\n", @@ -839,7 +955,7 @@ "}\n", "```\n", "\n", - "* [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\n", + "- [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\n", "\n", "```bibtex\n", "@inproceedings{wu2023empirical,\n", @@ -850,161 +966,18 @@ "}\n", "```\n", "\n", - "[![PyPI version](https://badge.fury.io/py/FLAML.svg)](https://badge.fury.io/py/FLAML)\n", - "![Conda version](https://img.shields.io/conda/vn/conda-forge/flaml)\n", - "[![Build](https://github.com/microsoft/FLAML/actions/workflows/python-package.yml/badge.svg)](https://github.com/microsoft/FLAML/actions/workflows/python-package.yml)\n", - "![Python Version](https://img.shields.io/badge/3.8%20%7C%203.9%20%7C%203.10-blue)\n", - "[![Downloads](https://pepy.tech/badge/flaml)](https://pepy.tech/project/flaml)\n", - "[![](https://img.shields.io/discord/1025786666260111483?logo=discord&style=flat)](https://discord.gg/Cppx2vSPVP)\n", - "\n", "\n", "\n", - "# A Fast Library for Automated Machine Learning & Tuning\n", - "\n", - "

\n", - " \n", - "
\n", - "

\n", - "\n", - ":fire: Heads-up: We have migrated [AutoGen](https://microsoft.github.io/autogen/) into a dedicated [github repository](https://github.com/microsoft/autogen). Alongside this move, we have also launched a dedicated [Discord](https://discord.gg/pAbnFJrkgZ) server and a [website](https://microsoft.github.io/autogen/) for comprehensive documentation.\n", - "\n", - ":fire: The automated multi-agent chat framework in [AutoGen](https://microsoft.github.io/autogen/) is in preview from v2.0.0.\n", - "\n", - ":fire: FLAML is highlighted in OpenAI's [cookbook](https://github.com/openai/openai-cookbook#related-resources-from-around-the-web).\n", - "\n", - ":fire: [autogen](https://microsoft.github.io/autogen/) is released with support for ChatGPT and GPT-4, based on [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673).\n", - "\n", - ":fire: FLAML supports Code-First AutoML & Tuning – Private Preview in [Microsoft Fabric Data Science](https://learn.microsoft.com/en-us/fabric/data-science/).\n", - "\n", - "\n", - "## What is FLAML\n", - "FLAML is a lightweight Python library for efficient automation of machine\n", - "learning and AI operations. It automates workflow based on large language models, machine learning models, etc.\n", - "and optimizes their performance.\n", - "\n", - "* FLAML enables building next-gen GPT-X applications based on multi-agent conversations with minimal effort. It simplifies the orchestration, automation and optimization of a complex GPT-X workflow. It maximizes the performance of GPT-X models and augments their weakness.\n", - "* For common machine learning tasks like classification and regression, it quickly finds quality models for user-provided data with low computational resources. It is easy to customize or extend. Users can find their desired customizability from a smooth range.\n", - "* It supports fast and economical automatic tuning (e.g., inference hyperparameters for foundation models, configurations in MLOps/LMOps workflows, pipelines, mathematical/statistical models, algorithms, computing experiments, software configurations), capable of handling large search space with heterogeneous evaluation cost and complex constraints/guidance/early stopping.\n", - "\n", - "FLAML is powered by a series of [research studies](https://microsoft.github.io/FLAML/docs/Research/) from Microsoft Research and collaborators such as Penn State University, Stevens Institute of Technology, University of Washington, and University of Waterloo.\n", - "\n", - "FLAML has a .NET implementation in [ML.NET](http://dot.net/ml), an open-source, cross-platform machine learning framework for .NET.\n", - "\n", - "## Installation\n", - "\n", - "FLAML requires **Python version >= 3.8**. It can be installed from pip:\n", - "\n", - "```bash\n", - "pip install flaml\n", - "```\n", - "\n", - "Minimal dependencies are installed without extra options. You can install extra options based on the feature you need. For example, use the following to install the dependencies needed by the [`autogen`](https://microsoft.github.io/autogen/) package.\n", - "```bash\n", - "pip install \"flaml[autogen]\"\n", - "```\n", - "\n", - "Find more options in [Installation](https://microsoft.github.io/FLAML/docs/Installation).\n", - "Each of the [`notebook examples`](https://github.com/microsoft/FLAML/tree/main/notebook) may require a specific option to be installed.\n", - "\n", - "## Quickstart\n", - "\n", - "* (New) The [autogen](https://microsoft.github.io/autogen/) package enables the next-gen GPT-X applications with a generic multi-agent conversation framework.\n", - "It offers customizable and conversable agents which integrate LLMs, tools and human.\n", - "By automating chat among multiple capable agents, one can easily make them collectively perform tasks autonomously or with human feedback, including tasks that require using tools via code. For example,\n", - "```python\n", - "from flaml import autogen\n", - "assistant = autogen.AssistantAgent(\"assistant\")\n", - "user_proxy = autogen.UserProxyAgent(\"user_proxy\")\n", - "user_proxy.initiate_chat(assistant, message=\"Show me the YTD gain of 10 largest technology companies as of today.\")\n", - "# This initiates an automated chat between the two agents to solve the task\n", - "```\n", - "\n", - "Autogen also helps maximize the utility out of the expensive LLMs such as ChatGPT and GPT-4. It offers a drop-in replacement of `openai.Completion` or `openai.ChatCompletion` with powerful functionalites like tuning, caching, templating, filtering. For example, you can optimize generations by LLM with your own tuning data, success metrics and budgets.\n", - "```python\n", - "# perform tuning\n", - "config, analysis = autogen.Completion.tune(\n", - " data=tune_data,\n", - " metric=\"success\",\n", - " mode=\"max\",\n", - " eval_func=eval_func,\n", - " inference_budget=0.05,\n", - " optimization_budget=3,\n", - " num_samples=-1,\n", - ")\n", - "# perform inference for a test instance\n", - "response = autogen.Completion.create(context=test_instance, **config)\n", - "```\n", - "* With three lines of code, you can start using this economical and fast\n", - "AutoML engine as a [scikit-learn style estimator](https://microsoft.github.io/FLAML/docs/Use-Cases/Task-Oriented-AutoML).\n", - "\n", - "```python\n", - "from flaml import AutoML\n", - "automl = AutoML()\n", - "automl.fit(X_train, y_train, task=\"classification\")\n", - "```\n", - "\n", - "* You can restrict the learners and use FLAML as a fast hyperparameter tuning\n", - "tool for XGBoost, LightGBM, Random Forest etc. or a [customized learner](https://microsoft.github.io/FLAML/docs/Use-Cases/Task-Oriented-AutoML#estimator-and-search-space).\n", - "\n", - "```python\n", - "automl.fit(X_train, y_train, task=\"classification\", estimator_list=[\"lgbm\"])\n", - "```\n", - "\n", - "* You can also run generic hyperparameter tuning for a [custom function](https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function).\n", - "\n", - "```python\n", - "from flaml import tune\n", - "tune.run(evaluation_function, config={…}, low_cost_partial_config={…}, time_budget_s=3600)\n", - "```\n", - "\n", - "* [Zero-shot AutoML](https://microsoft.github.io/FLAML/docs/Use-Cases/Zero-Shot-AutoML) allows using the existing training API from lightgbm, xgboost etc. while getting the benefit of AutoML in choosing high-performance hyperparameter configurations per task.\n", - "\n", - "```python\n", - "from flaml.default import LGBMRegressor\n", - "\n", - "# Use LGBMRegressor in the same way as you use lightgbm.LGBMRegressor.\n", - "estimator = LGBMRegressor()\n", - "# The hyperparameters are automatically set according to the training data.\n", - "estimator.fit(X_train, y_train)\n", - "```\n", - "\n", - "## Documentation\n", - "\n", - "You can find a detailed documentation about FLAML [here](https://microsoft.github.io/FLAML/).\n", - "\n", - "In addition, you can find:\n", - "\n", - "- [Research](https://microsoft.github.io/FLAML/docs/Research) and [blogposts](https://microsoft.github.io/FLAML/blog) around FLAML.\n", - "\n", - "- [Discord](https://discord.gg/Cppx2vSPVP).\n", - "\n", - "- [Contributing guide](https://microsoft.github.io/FLAML/docs/Contribute).\n", - "\n", - "- ML.NET documentation and tutorials for [Model Builder](https://learn.microsoft.com/dotnet/machine-learning/tutorials/predict-prices-with-model-builder), [ML.NET CLI](https://learn.microsoft.com/dotnet/machine-learning/tutorials/sentiment-analysis-cli), and [AutoML API](https://learn.microsoft.com/dotnet/machine-learning/how-to-guides/how-to-use-the-automl-api).\n", - "\n", - "## Contributing\n", - "\n", - "This project welcomes contributions and suggestions. Most contributions require you to agree to a\n", - "Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us\n", - "the rights to use your contribution. For details, visit .\n", - "\n", - "If you are new to GitHub [here](https://help.github.com/categories/collaborating-with-issues-and-pull-requests/) is a detailed help source on getting involved with development on GitHub.\n", - "\n", - "When you submit a pull request, a CLA bot will automatically determine whether you need to provide\n", - "a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions\n", - "provided by the bot. You will only need to do this once across all repos using our CLA.\n", - "\n", - "This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).\n", - "For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or\n", - "contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.\n", - "\n", - "\n", - "\n", - "\n", - "--------------------------------------------------------------------------------\n", + "--------------------------------------------------------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "The author of FLAML is Chi Wang, along with other collaborators including Qingyun Wu, Markus Weimer, Erkang Zhu, Silu Huang, Amin Saied, Susan Xueqing Liu, John Langford, Paul Mineiro, Marco Rossi, Moe Kayali, Shaokun Zhang, Feiran Jia, Yiran Wu, Hangyu Li, Yue Wang, Yin Tat Lee, Richard Peng, and Ahmed H. Awadallah, as indicated in the provided references for FLAML's research publications.\n", + "The authors of FLAML are Chi Wang, Qingyun Wu, Markus Weimer, and Erkang Zhu.\n", "\n", "--------------------------------------------------------------------------------\n" ] @@ -1012,10 +985,10 @@ { "data": { "text/plain": [ - "ChatResult(chat_id=None, chat_history=[{'content': 'You\\'re a retrieve augmented coding assistant. You answer user\\'s questions based on your own knowledge and the\\ncontext provided by the user.\\nIf you can\\'t answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\\nFor code generation, you must obey the following rules:\\nRule 1. You MUST NOT install any packages because all the packages needed are already installed.\\nRule 2. You must follow the formats below to write your code:\\n```language\\n# your code\\n```\\n\\nUser\\'s question is: Who is the author of FLAML?\\n\\nContext is: # Research\\n\\nFor technical details, please check our research publications.\\n\\n* [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\\n\\n```bibtex\\n@inproceedings{wang2021flaml,\\n title={FLAML: A Fast and Lightweight AutoML Library},\\n author={Chi Wang and Qingyun Wu and Markus Weimer and Erkang Zhu},\\n year={2021},\\n booktitle={MLSys},\\n}\\n```\\n\\n* [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\\n\\n```bibtex\\n@inproceedings{wu2021cfo,\\n title={Frugal Optimization for Cost-related Hyperparameters},\\n author={Qingyun Wu and Chi Wang and Silu Huang},\\n year={2021},\\n booktitle={AAAI},\\n}\\n```\\n\\n* [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\\n\\n```bibtex\\n@inproceedings{wang2021blendsearch,\\n title={Economical Hyperparameter Optimization With Blended Search Strategy},\\n author={Chi Wang and Qingyun Wu and Silu Huang and Amin Saied},\\n year={2021},\\n booktitle={ICLR},\\n}\\n```\\n\\n* [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\\n\\n```bibtex\\n@inproceedings{liuwang2021hpolm,\\n title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},\\n author={Susan Xueqing Liu and Chi Wang},\\n year={2021},\\n booktitle={ACL},\\n}\\n```\\n\\n* [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\\n\\n```bibtex\\n@inproceedings{wu2021chacha,\\n title={ChaCha for Online AutoML},\\n author={Qingyun Wu and Chi Wang and John Langford and Paul Mineiro and Marco Rossi},\\n year={2021},\\n booktitle={ICML},\\n}\\n```\\n\\n* [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\\n\\n```bibtex\\n@inproceedings{wuwang2021fairautoml,\\n title={Fair AutoML},\\n author={Qingyun Wu and Chi Wang},\\n year={2021},\\n booktitle={ArXiv preprint arXiv:2111.06495},\\n}\\n```\\n\\n* [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\\n\\n```bibtex\\n@inproceedings{kayaliwang2022default,\\n title={Mining Robust Default Configurations for Resource-constrained AutoML},\\n author={Moe Kayali and Chi Wang},\\n year={2022},\\n booktitle={ArXiv preprint arXiv:2202.09927},\\n}\\n```\\n\\n* [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\\n\\n```bibtex\\n@inproceedings{zhang2023targeted,\\n title={Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives},\\n author={Shaokun Zhang and Feiran Jia and Chi Wang and Qingyun Wu},\\n booktitle={International Conference on Learning Representations},\\n year={2023},\\n url={https://openreview.net/forum?id=0Ij9_q567Ma},\\n}\\n```\\n\\n* [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\\n\\n```bibtex\\n@inproceedings{wang2023EcoOptiGen,\\n title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},\\n author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},\\n year={2023},\\n booktitle={ArXiv preprint arXiv:2303.04673},\\n}\\n```\\n\\n* [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\\n\\n```bibtex\\n@inproceedings{wu2023empirical,\\n title={An Empirical Study on Challenging Math Problem Solving with GPT-4},\\n author={Yiran Wu and Feiran Jia and Shaokun Zhang and Hangyu Li and Erkang Zhu and Yue Wang and Yin Tat Lee and Richard Peng and Qingyun Wu and Chi Wang},\\n year={2023},\\n booktitle={ArXiv preprint arXiv:2306.01337},\\n}\\n```\\n\\n[![PyPI version](https://badge.fury.io/py/FLAML.svg)](https://badge.fury.io/py/FLAML)\\n![Conda version](https://img.shields.io/conda/vn/conda-forge/flaml)\\n[![Build](https://github.com/microsoft/FLAML/actions/workflows/python-package.yml/badge.svg)](https://github.com/microsoft/FLAML/actions/workflows/python-package.yml)\\n![Python Version](https://img.shields.io/badge/3.8%20%7C%203.9%20%7C%203.10-blue)\\n[![Downloads](https://pepy.tech/badge/flaml)](https://pepy.tech/project/flaml)\\n[![](https://img.shields.io/discord/1025786666260111483?logo=discord&style=flat)](https://discord.gg/Cppx2vSPVP)\\n\\n\\n\\n# A Fast Library for Automated Machine Learning & Tuning\\n\\n

\\n \\n
\\n

\\n\\n:fire: Heads-up: We have migrated [AutoGen](https://microsoft.github.io/autogen/) into a dedicated [github repository](https://github.com/microsoft/autogen). Alongside this move, we have also launched a dedicated [Discord](https://discord.gg/pAbnFJrkgZ) server and a [website](https://microsoft.github.io/autogen/) for comprehensive documentation.\\n\\n:fire: The automated multi-agent chat framework in [AutoGen](https://microsoft.github.io/autogen/) is in preview from v2.0.0.\\n\\n:fire: FLAML is highlighted in OpenAI\\'s [cookbook](https://github.com/openai/openai-cookbook#related-resources-from-around-the-web).\\n\\n:fire: [autogen](https://microsoft.github.io/autogen/) is released with support for ChatGPT and GPT-4, based on [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673).\\n\\n:fire: FLAML supports Code-First AutoML & Tuning – Private Preview in [Microsoft Fabric Data Science](https://learn.microsoft.com/en-us/fabric/data-science/).\\n\\n\\n## What is FLAML\\nFLAML is a lightweight Python library for efficient automation of machine\\nlearning and AI operations. It automates workflow based on large language models, machine learning models, etc.\\nand optimizes their performance.\\n\\n* FLAML enables building next-gen GPT-X applications based on multi-agent conversations with minimal effort. It simplifies the orchestration, automation and optimization of a complex GPT-X workflow. It maximizes the performance of GPT-X models and augments their weakness.\\n* For common machine learning tasks like classification and regression, it quickly finds quality models for user-provided data with low computational resources. It is easy to customize or extend. Users can find their desired customizability from a smooth range.\\n* It supports fast and economical automatic tuning (e.g., inference hyperparameters for foundation models, configurations in MLOps/LMOps workflows, pipelines, mathematical/statistical models, algorithms, computing experiments, software configurations), capable of handling large search space with heterogeneous evaluation cost and complex constraints/guidance/early stopping.\\n\\nFLAML is powered by a series of [research studies](https://microsoft.github.io/FLAML/docs/Research/) from Microsoft Research and collaborators such as Penn State University, Stevens Institute of Technology, University of Washington, and University of Waterloo.\\n\\nFLAML has a .NET implementation in [ML.NET](http://dot.net/ml), an open-source, cross-platform machine learning framework for .NET.\\n\\n## Installation\\n\\nFLAML requires **Python version >= 3.8**. It can be installed from pip:\\n\\n```bash\\npip install flaml\\n```\\n\\nMinimal dependencies are installed without extra options. You can install extra options based on the feature you need. For example, use the following to install the dependencies needed by the [`autogen`](https://microsoft.github.io/autogen/) package.\\n```bash\\npip install \"flaml[autogen]\"\\n```\\n\\nFind more options in [Installation](https://microsoft.github.io/FLAML/docs/Installation).\\nEach of the [`notebook examples`](https://github.com/microsoft/FLAML/tree/main/notebook) may require a specific option to be installed.\\n\\n## Quickstart\\n\\n* (New) The [autogen](https://microsoft.github.io/autogen/) package enables the next-gen GPT-X applications with a generic multi-agent conversation framework.\\nIt offers customizable and conversable agents which integrate LLMs, tools and human.\\nBy automating chat among multiple capable agents, one can easily make them collectively perform tasks autonomously or with human feedback, including tasks that require using tools via code. For example,\\n```python\\nfrom flaml import autogen\\nassistant = autogen.AssistantAgent(\"assistant\")\\nuser_proxy = autogen.UserProxyAgent(\"user_proxy\")\\nuser_proxy.initiate_chat(assistant, message=\"Show me the YTD gain of 10 largest technology companies as of today.\")\\n# This initiates an automated chat between the two agents to solve the task\\n```\\n\\nAutogen also helps maximize the utility out of the expensive LLMs such as ChatGPT and GPT-4. It offers a drop-in replacement of `openai.Completion` or `openai.ChatCompletion` with powerful functionalites like tuning, caching, templating, filtering. For example, you can optimize generations by LLM with your own tuning data, success metrics and budgets.\\n```python\\n# perform tuning\\nconfig, analysis = autogen.Completion.tune(\\n data=tune_data,\\n metric=\"success\",\\n mode=\"max\",\\n eval_func=eval_func,\\n inference_budget=0.05,\\n optimization_budget=3,\\n num_samples=-1,\\n)\\n# perform inference for a test instance\\nresponse = autogen.Completion.create(context=test_instance, **config)\\n```\\n* With three lines of code, you can start using this economical and fast\\nAutoML engine as a [scikit-learn style estimator](https://microsoft.github.io/FLAML/docs/Use-Cases/Task-Oriented-AutoML).\\n\\n```python\\nfrom flaml import AutoML\\nautoml = AutoML()\\nautoml.fit(X_train, y_train, task=\"classification\")\\n```\\n\\n* You can restrict the learners and use FLAML as a fast hyperparameter tuning\\ntool for XGBoost, LightGBM, Random Forest etc. or a [customized learner](https://microsoft.github.io/FLAML/docs/Use-Cases/Task-Oriented-AutoML#estimator-and-search-space).\\n\\n```python\\nautoml.fit(X_train, y_train, task=\"classification\", estimator_list=[\"lgbm\"])\\n```\\n\\n* You can also run generic hyperparameter tuning for a [custom function](https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function).\\n\\n```python\\nfrom flaml import tune\\ntune.run(evaluation_function, config={…}, low_cost_partial_config={…}, time_budget_s=3600)\\n```\\n\\n* [Zero-shot AutoML](https://microsoft.github.io/FLAML/docs/Use-Cases/Zero-Shot-AutoML) allows using the existing training API from lightgbm, xgboost etc. while getting the benefit of AutoML in choosing high-performance hyperparameter configurations per task.\\n\\n```python\\nfrom flaml.default import LGBMRegressor\\n\\n# Use LGBMRegressor in the same way as you use lightgbm.LGBMRegressor.\\nestimator = LGBMRegressor()\\n# The hyperparameters are automatically set according to the training data.\\nestimator.fit(X_train, y_train)\\n```\\n\\n## Documentation\\n\\nYou can find a detailed documentation about FLAML [here](https://microsoft.github.io/FLAML/).\\n\\nIn addition, you can find:\\n\\n- [Research](https://microsoft.github.io/FLAML/docs/Research) and [blogposts](https://microsoft.github.io/FLAML/blog) around FLAML.\\n\\n- [Discord](https://discord.gg/Cppx2vSPVP).\\n\\n- [Contributing guide](https://microsoft.github.io/FLAML/docs/Contribute).\\n\\n- ML.NET documentation and tutorials for [Model Builder](https://learn.microsoft.com/dotnet/machine-learning/tutorials/predict-prices-with-model-builder), [ML.NET CLI](https://learn.microsoft.com/dotnet/machine-learning/tutorials/sentiment-analysis-cli), and [AutoML API](https://learn.microsoft.com/dotnet/machine-learning/how-to-guides/how-to-use-the-automl-api).\\n\\n## Contributing\\n\\nThis project welcomes contributions and suggestions. Most contributions require you to agree to a\\nContributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us\\nthe rights to use your contribution. For details, visit .\\n\\nIf you are new to GitHub [here](https://help.github.com/categories/collaborating-with-issues-and-pull-requests/) is a detailed help source on getting involved with development on GitHub.\\n\\nWhen you submit a pull request, a CLA bot will automatically determine whether you need to provide\\na CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions\\nprovided by the bot. You will only need to do this once across all repos using our CLA.\\n\\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).\\nFor more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or\\ncontact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.\\n\\n\\n', 'role': 'assistant'}, {'content': \"The author of FLAML is Chi Wang, along with other collaborators including Qingyun Wu, Markus Weimer, Erkang Zhu, Silu Huang, Amin Saied, Susan Xueqing Liu, John Langford, Paul Mineiro, Marco Rossi, Moe Kayali, Shaokun Zhang, Feiran Jia, Yiran Wu, Hangyu Li, Yue Wang, Yin Tat Lee, Richard Peng, and Ahmed H. Awadallah, as indicated in the provided references for FLAML's research publications.\", 'role': 'user'}], summary=\"The author of FLAML is Chi Wang, along with other collaborators including Qingyun Wu, Markus Weimer, Erkang Zhu, Silu Huang, Amin Saied, Susan Xueqing Liu, John Langford, Paul Mineiro, Marco Rossi, Moe Kayali, Shaokun Zhang, Feiran Jia, Yiran Wu, Hangyu Li, Yue Wang, Yin Tat Lee, Richard Peng, and Ahmed H. Awadallah, as indicated in the provided references for FLAML's research publications.\", cost=({'total_cost': 0.11538, 'gpt-4': {'cost': 0.11538, 'prompt_tokens': 3632, 'completion_tokens': 107, 'total_tokens': 3739}}, {'total_cost': 0}), human_input=[])" + "ChatResult(chat_id=None, chat_history=[{'content': \"You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\\ncontext provided by the user.\\nIf you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\\nFor code generation, you must obey the following rules:\\nRule 1. You MUST NOT install any packages because all the packages needed are already installed.\\nRule 2. You must follow the formats below to write your code:\\n```language\\n# your code\\n```\\n\\nUser's question is: Who is the author of FLAML?\\n\\nContext is: # Research\\n\\nFor technical details, please check our research publications.\\n\\n- [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\\n\\n```bibtex\\n@inproceedings{wang2021flaml,\\n title={FLAML: A Fast and Lightweight AutoML Library},\\n author={Chi Wang and Qingyun Wu and Markus Weimer and Erkang Zhu},\\n year={2021},\\n booktitle={MLSys},\\n}\\n```\\n\\n- [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\\n\\n```bibtex\\n@inproceedings{wu2021cfo,\\n title={Frugal Optimization for Cost-related Hyperparameters},\\n author={Qingyun Wu and Chi Wang and Silu Huang},\\n year={2021},\\n booktitle={AAAI},\\n}\\n```\\n\\n- [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\\n\\n```bibtex\\n@inproceedings{wang2021blendsearch,\\n title={Economical Hyperparameter Optimization With Blended Search Strategy},\\n author={Chi Wang and Qingyun Wu and Silu Huang and Amin Saied},\\n year={2021},\\n booktitle={ICLR},\\n}\\n```\\n\\n- [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\\n\\n```bibtex\\n@inproceedings{liuwang2021hpolm,\\n title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},\\n author={Susan Xueqing Liu and Chi Wang},\\n year={2021},\\n booktitle={ACL},\\n}\\n```\\n\\n- [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\\n\\n```bibtex\\n@inproceedings{wu2021chacha,\\n title={ChaCha for Online AutoML},\\n author={Qingyun Wu and Chi Wang and John Langford and Paul Mineiro and Marco Rossi},\\n year={2021},\\n booktitle={ICML},\\n}\\n```\\n\\n- [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\\n\\n```bibtex\\n@inproceedings{wuwang2021fairautoml,\\n title={Fair AutoML},\\n author={Qingyun Wu and Chi Wang},\\n year={2021},\\n booktitle={ArXiv preprint arXiv:2111.06495},\\n}\\n```\\n\\n- [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\\n\\n```bibtex\\n@inproceedings{kayaliwang2022default,\\n title={Mining Robust Default Configurations for Resource-constrained AutoML},\\n author={Moe Kayali and Chi Wang},\\n year={2022},\\n booktitle={ArXiv preprint arXiv:2202.09927},\\n}\\n```\\n\\n- [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\\n\\n```bibtex\\n@inproceedings{zhang2023targeted,\\n title={Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives},\\n author={Shaokun Zhang and Feiran Jia and Chi Wang and Qingyun Wu},\\n booktitle={International Conference on Learning Representations},\\n year={2023},\\n url={https://openreview.net/forum?id=0Ij9_q567Ma},\\n}\\n```\\n\\n- [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\\n\\n```bibtex\\n@inproceedings{wang2023EcoOptiGen,\\n title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},\\n author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},\\n year={2023},\\n booktitle={ArXiv preprint arXiv:2303.04673},\\n}\\n```\\n\\n- [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\\n\\n```bibtex\\n@inproceedings{wu2023empirical,\\n title={An Empirical Study on Challenging Math Problem Solving with GPT-4},\\n author={Yiran Wu and Feiran Jia and Shaokun Zhang and Hangyu Li and Erkang Zhu and Yue Wang and Yin Tat Lee and Richard Peng and Qingyun Wu and Chi Wang},\\n year={2023},\\n booktitle={ArXiv preprint arXiv:2306.01337},\\n}\\n```\\n\\n\", 'role': 'assistant'}, {'content': 'The authors of FLAML are Chi Wang, Qingyun Wu, Markus Weimer, and Erkang Zhu.', 'role': 'user'}], summary='The authors of FLAML are Chi Wang, Qingyun Wu, Markus Weimer, and Erkang Zhu.', cost=({'total_cost': 0.04596, 'gpt-4': {'cost': 0.04596, 'prompt_tokens': 1486, 'completion_tokens': 23, 'total_tokens': 1509}}, {'total_cost': 0.04596, 'gpt-4': {'cost': 0.04596, 'prompt_tokens': 1486, 'completion_tokens': 23, 'total_tokens': 1509}}), human_input=[])" ] }, - "execution_count": 18, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -1031,8 +1004,10 @@ ], "metadata": { "front_matter": { - "tags": ["rag"], - "description": "This notebook demonstrates the usage of QdrantRetrieveUserProxyAgent for RAG." + "description": "This notebook demonstrates the usage of QdrantRetrieveUserProxyAgent for RAG.", + "tags": [ + "rag" + ] }, "kernelspec": { "display_name": "Python 3 (ipykernel)",