mirror of https://github.com/microsoft/autogen.git
Switched to AzureOpenAI for api_type=="azure" (#1232)
* Switched to AzureOpenAI for api_type=="azure" * Setting AzureOpenAI to empty object if no `openai` * extra_ and openai_ kwargs * test_client, support for Azure and "gpt-35-turbo-instruct" * instruct/azure model in test_client_stream * generalize aoai support (#1) * generalize aoai support * Null check, fixing tests * cleanup test --------- Co-authored-by: Maxim Saplin <smaxmail@gmail.com> * Returning back model names for instruct * process model in create * None check --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent
39182ccb6b
commit
00dbcb247e
|
@ -11,7 +11,7 @@ from pydantic import BaseModel
|
|||
|
||||
from autogen.oai import completion
|
||||
|
||||
from autogen.oai.openai_utils import get_key, OAI_PRICE1K
|
||||
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, get_key, OAI_PRICE1K
|
||||
from autogen.token_count_utils import count_token
|
||||
from autogen._pydantic import model_dump
|
||||
|
||||
|
@ -21,9 +21,10 @@ try:
|
|||
except ImportError:
|
||||
ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
|
||||
OpenAI = object
|
||||
AzureOpenAI = object
|
||||
else:
|
||||
# raises exception if openai>=1 is installed and something is wrong with imports
|
||||
from openai import OpenAI, APIError, __version__ as OPENAIVERSION
|
||||
from openai import OpenAI, AzureOpenAI, APIError, __version__ as OPENAIVERSION
|
||||
from openai.resources import Completions
|
||||
from openai.types.chat import ChatCompletion
|
||||
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined]
|
||||
|
@ -52,8 +53,18 @@ class OpenAIWrapper:
|
|||
"""A wrapper class for openai client."""
|
||||
|
||||
cache_path_root: str = ".cache"
|
||||
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version", "tags"}
|
||||
extra_kwargs = {
|
||||
"cache_seed",
|
||||
"filter_func",
|
||||
"allow_format_str_template",
|
||||
"context",
|
||||
"api_version",
|
||||
"api_type",
|
||||
"tags",
|
||||
}
|
||||
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
|
||||
aopenai_kwargs = set(inspect.getfullargspec(AzureOpenAI.__init__).kwonlyargs)
|
||||
openai_kwargs = openai_kwargs | aopenai_kwargs
|
||||
total_usage_summary: Optional[Dict[str, Any]] = None
|
||||
actual_usage_summary: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
@ -105,46 +116,10 @@ class OpenAIWrapper:
|
|||
self._clients = [self._client(extra_kwargs, openai_config)]
|
||||
self._config_list = [extra_kwargs]
|
||||
|
||||
def _process_for_azure(
|
||||
self, config: Dict[str, Any], extra_kwargs: Dict[str, Any], segment: str = "default"
|
||||
) -> None:
|
||||
# deal with api_version
|
||||
query_segment = f"{segment}_query"
|
||||
headers_segment = f"{segment}_headers"
|
||||
api_version = extra_kwargs.get("api_version")
|
||||
if api_version is not None and query_segment not in config:
|
||||
config[query_segment] = {"api-version": api_version}
|
||||
if segment == "default":
|
||||
# remove the api_version from extra_kwargs
|
||||
extra_kwargs.pop("api_version")
|
||||
if segment == "extra":
|
||||
return
|
||||
# deal with api_type
|
||||
api_type = extra_kwargs.get("api_type")
|
||||
if api_type is not None and api_type.startswith("azure") and headers_segment not in config:
|
||||
api_key = config.get("api_key", os.environ.get("AZURE_OPENAI_API_KEY"))
|
||||
config[headers_segment] = {"api-key": api_key}
|
||||
# remove the api_type from extra_kwargs
|
||||
extra_kwargs.pop("api_type")
|
||||
# deal with model
|
||||
model = extra_kwargs.get("model")
|
||||
if model is None:
|
||||
return
|
||||
if "gpt-3.5" in model:
|
||||
# hack for azure gpt-3.5
|
||||
extra_kwargs["model"] = model = model.replace("gpt-3.5", "gpt-35")
|
||||
base_url = config.get("base_url")
|
||||
if base_url is None:
|
||||
raise ValueError("to use azure openai api, base_url must be specified.")
|
||||
suffix = f"/openai/deployments/{model}"
|
||||
if not base_url.endswith(suffix):
|
||||
config["base_url"] += suffix[1:] if base_url.endswith("/") else suffix
|
||||
|
||||
def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""Separate the config into openai_config and extra_kwargs."""
|
||||
openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs}
|
||||
extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs}
|
||||
self._process_for_azure(openai_config, extra_kwargs)
|
||||
return openai_config, extra_kwargs
|
||||
|
||||
def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
|
@ -156,10 +131,22 @@ class OpenAIWrapper:
|
|||
def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAI:
|
||||
"""Create a client with the given config to override openai_config,
|
||||
after removing extra kwargs.
|
||||
|
||||
For Azure models/deployment names there's a convenience modification of model removing dots in
|
||||
the it's value (Azure deploment names can't have dots). I.e. if you have Azure deployment name
|
||||
"gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot
|
||||
from the name and create a client that connects to "gpt-35-turbo" Azure deployment.
|
||||
"""
|
||||
openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
|
||||
self._process_for_azure(openai_config, config)
|
||||
client = OpenAI(**openai_config)
|
||||
api_type = config.get("api_type")
|
||||
if api_type is not None and api_type.startswith("azure"):
|
||||
openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model"))
|
||||
if openai_config["azure_deployment"] is not None:
|
||||
openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "")
|
||||
openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None))
|
||||
client = AzureOpenAI(**openai_config)
|
||||
else:
|
||||
client = OpenAI(**openai_config)
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
|
@ -242,8 +229,9 @@ class OpenAIWrapper:
|
|||
full_config = {**config, **self._config_list[i]}
|
||||
# separate the config into create_config and extra_kwargs
|
||||
create_config, extra_kwargs = self._separate_create_config(full_config)
|
||||
# process for azure
|
||||
self._process_for_azure(create_config, extra_kwargs, "extra")
|
||||
api_type = extra_kwargs.get("api_type")
|
||||
if api_type and api_type.startswith("azure") and "model" in create_config:
|
||||
create_config["model"] = create_config["model"].replace(".", "")
|
||||
# construct the create params
|
||||
params = self._construct_create_params(create_config, extra_kwargs)
|
||||
# get the cache_seed, filter_func and context
|
||||
|
|
|
@ -31,10 +31,15 @@ def test_aoai_chat_completion():
|
|||
filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
|
||||
)
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
# for config in config_list:
|
||||
# print(config)
|
||||
# client = OpenAIWrapper(**config)
|
||||
# response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
|
||||
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
|
||||
print(response)
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
# test dialect
|
||||
config = config_list[0]
|
||||
config["azure_deployment"] = config["model"]
|
||||
config["azure_endpoint"] = config.pop("base_url")
|
||||
client = OpenAIWrapper(**config)
|
||||
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
|
||||
print(response)
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
@ -93,21 +98,23 @@ def test_chat_completion():
|
|||
def test_completion():
|
||||
config_list = config_list_openai_aoai(KEY_LOC)
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct")
|
||||
model = "gpt-3.5-turbo-instruct"
|
||||
response = client.create(prompt="1+1=", model=model)
|
||||
print(response)
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
@pytest.mark.parametrize(
|
||||
"cache_seed, model",
|
||||
"cache_seed",
|
||||
[
|
||||
(None, "gpt-3.5-turbo-instruct"),
|
||||
(42, "gpt-3.5-turbo-instruct"),
|
||||
None,
|
||||
42,
|
||||
],
|
||||
)
|
||||
def test_cost(cache_seed, model):
|
||||
def test_cost(cache_seed):
|
||||
config_list = config_list_openai_aoai(KEY_LOC)
|
||||
model = "gpt-3.5-turbo-instruct"
|
||||
client = OpenAIWrapper(config_list=config_list, cache_seed=cache_seed)
|
||||
response = client.create(prompt="1+3=", model=model)
|
||||
print(response.cost)
|
||||
|
@ -117,7 +124,8 @@ def test_cost(cache_seed, model):
|
|||
def test_usage_summary():
|
||||
config_list = config_list_openai_aoai(KEY_LOC)
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(prompt="1+3=", model="gpt-3.5-turbo-instruct", cache_seed=None)
|
||||
model = "gpt-3.5-turbo-instruct"
|
||||
response = client.create(prompt="1+3=", model=model, cache_seed=None)
|
||||
|
||||
# usage should be recorded
|
||||
assert client.actual_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"
|
||||
|
@ -138,15 +146,15 @@ def test_usage_summary():
|
|||
assert client.total_usage_summary is None, "total_usage_summary should be None"
|
||||
|
||||
# actual usage and all usage should be different
|
||||
response = client.create(prompt="1+3=", model="gpt-3.5-turbo-instruct", cache_seed=42)
|
||||
response = client.create(prompt="1+3=", model=model, cache_seed=42)
|
||||
assert client.total_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"
|
||||
assert client.actual_usage_summary is None, "No actual cost should be recorded"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_aoai_chat_completion()
|
||||
test_oai_tool_calling_extraction()
|
||||
test_chat_completion()
|
||||
# test_aoai_chat_completion()
|
||||
# test_oai_tool_calling_extraction()
|
||||
# test_chat_completion()
|
||||
test_completion()
|
||||
# test_cost()
|
||||
test_usage_summary()
|
||||
# # test_cost()
|
||||
# test_usage_summary()
|
||||
|
|
|
@ -286,7 +286,9 @@ def test_chat_tools_stream() -> None:
|
|||
def test_completion_stream() -> None:
|
||||
config_list = config_list_openai_aoai(KEY_LOC)
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
|
||||
# Azure can't have dot in model/deployment name
|
||||
model = "gpt-35-turbo-instruct" if config_list[0].get("api_type") == "azure" else "gpt-3.5-turbo-instruct"
|
||||
response = client.create(prompt="1+1=", model=model, stream=True)
|
||||
print(response)
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
|
|
Loading…
Reference in New Issue