feat: Upgrade the Qianfan model version and allow user-defined models
parent bcc9cb1132
commit e80257a978
@@ -10,7 +10,7 @@ import os
 from typing import Dict
 
 from langchain.schema import HumanMessage
-from langchain_community.chat_models import AzureChatOpenAI
+from langchain_community.chat_models.azure_openai import AzureChatOpenAI
 
 from common import froms
 from common.exception.app_exception import AppApiException
@@ -29,9 +29,6 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
         if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
             raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
 
-        if model_name not in model_dict:
-            raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型名称不支持')
-
         for key in ['api_base', 'api_key', 'deployment_name']:
             if key not in model_credential:
                 if raise_exception:
@@ -40,7 +37,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
                     return False
         try:
             model = AzureModelProvider().get_model(model_type, model_name, model_credential)
-            model.invoke([HumanMessage(content='valid')])
+            model.invoke([HumanMessage(content='你好')])
         except Exception as e:
             if isinstance(e, AppApiException):
                 raise e
@@ -61,8 +58,48 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
     deployment_name = froms.TextInputField("部署名", required=True)
 
 
+class DefaultAzureLLMModelCredential(BaseForm, BaseModelCredential):
+
+    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
+        model_type_list = AzureModelProvider().get_model_type_list()
+        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
+            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+        for key in ['api_base', 'api_key', 'deployment_name', 'api_version']:
+            if key not in model_credential:
+                if raise_exception:
+                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+                else:
+                    return False
+        try:
+            model = AzureModelProvider().get_model(model_type, model_name, model_credential)
+            model.invoke([HumanMessage(content='你好')])
+        except Exception as e:
+            if isinstance(e, AppApiException):
+                raise e
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确')
+            else:
+                return False
+
+        return True
+
+    def encryption_dict(self, model: Dict[str, object]):
+        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
+
+    api_version = froms.TextInputField("api_version", required=True)
+
+    api_base = froms.TextInputField('API 域名', required=True)
+
+    api_key = froms.PasswordInputField("API Key", required=True)
+
+    deployment_name = froms.TextInputField("部署名", required=True)
+
+
 azure_llm_model_credential = AzureLLMModelCredential()
+
+base_azure_llm_model_credential = DefaultAzureLLMModelCredential()
 
 model_dict = {
     'gpt-3.5-turbo-0613': ModelInfo('gpt-3.5-turbo-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
                                     api_version='2023-07-01-preview'),
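Note on the hunk above: DefaultAzureLLMModelCredential is the credential form for Azure deployments that are not pre-registered in model_dict. Compared with AzureLLMModelCredential it drops the model-name whitelist and instead requires the user to supply api_version alongside api_base, api_key and deployment_name. A minimal sketch of the payload such a form collects and the presence check is_valid applies; every value below is a hypothetical placeholder, not a real endpoint or key.

# Hypothetical user-defined Azure credential; all values are placeholders.
credential = {
    'api_base': 'https://example-resource.openai.azure.com',
    'api_key': '<azure-api-key>',
    'deployment_name': 'my-gpt-4-deployment',
    'api_version': '2023-07-01-preview',
}

# DefaultAzureLLMModelCredential.is_valid requires all four keys before it
# probes the deployment with a single test message.
required = ('api_base', 'api_key', 'deployment_name', 'api_version')
missing = [key for key in required if key not in credential]
assert not missing, f'missing required fields: {missing}'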
@@ -84,18 +121,18 @@ class AzureModelProvider(IModelProvider):
         model_info: ModelInfo = model_dict.get(model_name)
         azure_chat_open_ai = AzureChatOpenAI(
             openai_api_base=model_credential.get('api_base'),
-            openai_api_version=model_info.api_version,
+            openai_api_version=model_credential.get(
+                'api_version') if 'api_version' in model_credential else model_info.api_version,
             deployment_name=model_credential.get('deployment_name'),
             openai_api_key=model_credential.get('api_key'),
-            openai_api_type="azure",
-            tiktoken_model_name=model_name
+            openai_api_type="azure"
         )
         return azure_chat_open_ai
 
     def get_model_credential(self, model_type, model_name):
         if model_name in model_dict:
             return model_dict.get(model_name).model_credential
-        raise AppApiException(500, f'不支持的模型:{model_name}')
+        return base_azure_llm_model_credential
 
     def get_model_provide_info(self):
         return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content(
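Taken together, the Azure changes mean an unrecognised model name no longer fails: get_model_credential falls back to the user-defined form, and get_model prefers an api_version supplied in the credential over the one recorded in model_dict. A rough usage sketch of that fallback path, assuming AzureModelProvider from the file above is in scope and that 'LLM' is a valid model-type value; the model name and credential values are placeholders.

# Sketch of the fallback path for a model name that is not in model_dict.
# Assumes AzureModelProvider (defined above) is importable; the model type,
# model name and credential values are illustrative placeholders.
provider = AzureModelProvider()

# Unknown name -> the user-defined credential form instead of AppApiException(500, ...).
form = provider.get_model_credential('LLM', 'my-custom-deployment')

credential = {
    'api_base': 'https://example-resource.openai.azure.com',
    'api_key': '<azure-api-key>',
    'deployment_name': 'my-custom-deployment',
    'api_version': '2023-07-01-preview',
}
if form.is_valid('LLM', 'my-custom-deployment', credential):
    # Because 'api_version' is present in the credential, get_model passes it
    # to AzureChatOpenAI instead of reading model_info.api_version.
    model = provider.get_model('LLM', 'my-custom-deployment', credential)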
@@ -9,8 +9,9 @@
 import os
 from typing import Dict
 
-from langchain_community.chat_models import QianfanChatEndpoint
 from langchain.schema import HumanMessage
+from langchain_community.chat_models import QianfanChatEndpoint
+from qianfan import ChatCompletion
 
 from common import froms
 from common.exception.app_exception import AppApiException
@@ -27,10 +28,9 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
         model_type_list = WenxinModelProvider().get_model_type_list()
         if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
             raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
-
-        if model_name not in model_dict:
-            raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型名称不支持')
-
+        model_info = [model.lower() for model in ChatCompletion.models()]
+        if not model_info.__contains__(model_name.lower()):
+            raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持')
         for key in ['api_key', 'secret_key']:
             if key not in model_credential:
                 if raise_exception:
@@ -39,10 +39,9 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
                     return False
         try:
             WenxinModelProvider().get_model(model_type, model_name, model_credential).invoke(
-                [HumanMessage(content='valid')])
+                [HumanMessage(content='你好')])
         except Exception as e:
-            if raise_exception:
-                raise AppApiException(ValidCode.valid_error.value, "校验失败,请检查 api_key secret_key 是否正确")
+            raise e
         return True
 
     def encryption_dict(self, model_info: Dict[str, object]):
@@ -121,7 +120,7 @@ class WenxinModelProvider(IModelProvider):
     def get_model_credential(self, model_type, model_name):
         if model_name in model_dict:
             return model_dict.get(model_name).model_credential
-        raise AppApiException(500, f'不支持的模型:{model_name}')
+        return win_xin_llm_model_credential
 
     def get_model_provide_info(self):
         return ModelProvideInfo(provider='model_wenxin_provider', name='千帆大模型', icon=get_file_content(
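Note on the Wenxin hunks above: the hard-coded model_dict name check is replaced by a lookup against whatever the installed qianfan SDK reports via ChatCompletion.models(), and get_model_credential now falls back to the shared credential form for names outside model_dict. The same availability check in isolation; the model name below is only an example.

from qianfan import ChatCompletion

# Accept any model name the installed qianfan release reports, compared
# case-insensitively, mirroring the check added to is_valid above.
supported = [name.lower() for name in ChatCompletion.models()]

requested = 'ERNIE-Bot-4'  # example name, used here only for illustration
if requested.lower() not in supported:
    raise ValueError(f'{requested} is not available in this qianfan version')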
@@ -23,7 +23,7 @@ sentence-transformers = "^2.2.2"
 blinker = "^1.6.3"
 openai = "^1.13.3"
 tiktoken = "^0.5.1"
-qianfan = "^0.1.1"
+qianfan = "^0.3.6.1"
 pycryptodome = "^3.19.0"
 beautifulsoup4 = "^4.12.2"
 html2text = "^2024.2.26"
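The qianfan dependency moves from ^0.1.1 to ^0.3.6.1, presumably so the ChatCompletion.models() call used above is available. A quick standard-library way to confirm which release an environment actually resolved:

from importlib.metadata import version

# Prints the installed qianfan release; with the caret constraint above,
# poetry should resolve something in the 0.3.x line (>= 0.3.6.1, < 0.4.0).
print(version('qianfan'))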