feat: 升级千帆模型版本,开放用户自定义模型

This commit is contained in:
shaohuzhang1 2024-03-25 15:08:40 +08:00
parent bcc9cb1132
commit e80257a978
3 changed files with 55 additions and 19 deletions

View File

@ -10,7 +10,7 @@ import os
from typing import Dict
from langchain.schema import HumanMessage
from langchain_community.chat_models import AzureChatOpenAI
from langchain_community.chat_models.azure_openai import AzureChatOpenAI
from common import froms
from common.exception.app_exception import AppApiException
@ -29,9 +29,6 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
if model_name not in model_dict:
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型名称不支持')
for key in ['api_base', 'api_key', 'deployment_name']:
if key not in model_credential:
if raise_exception:
@ -40,7 +37,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
return False
try:
model = AzureModelProvider().get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='valid')])
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
@ -61,8 +58,48 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
deployment_name = froms.TextInputField("部署名", required=True)
class DefaultAzureLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = AzureModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['api_base', 'api_key', 'deployment_name', 'api_version']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = AzureModelProvider().get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_version = froms.TextInputField("api_version", required=True)
api_base = froms.TextInputField('API 域名', required=True)
api_key = froms.PasswordInputField("API Key", required=True)
deployment_name = froms.TextInputField("部署名", required=True)
azure_llm_model_credential = AzureLLMModelCredential()
base_azure_llm_model_credential = DefaultAzureLLMModelCredential()
model_dict = {
'gpt-3.5-turbo-0613': ModelInfo('gpt-3.5-turbo-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
api_version='2023-07-01-preview'),
@ -84,18 +121,18 @@ class AzureModelProvider(IModelProvider):
model_info: ModelInfo = model_dict.get(model_name)
azure_chat_open_ai = AzureChatOpenAI(
openai_api_base=model_credential.get('api_base'),
openai_api_version=model_info.api_version,
openai_api_version=model_credential.get(
'api_version') if 'api_version' in model_credential else model_info.api_version,
deployment_name=model_credential.get('deployment_name'),
openai_api_key=model_credential.get('api_key'),
openai_api_type="azure",
tiktoken_model_name=model_name
openai_api_type="azure"
)
return azure_chat_open_ai
def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
raise AppApiException(500, f'不支持的模型:{model_name}')
return base_azure_llm_model_credential
def get_model_provide_info(self):
return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content(

View File

@ -9,8 +9,9 @@
import os
from typing import Dict
from langchain_community.chat_models import QianfanChatEndpoint
from langchain.schema import HumanMessage
from langchain_community.chat_models import QianfanChatEndpoint
from qianfan import ChatCompletion
from common import froms
from common.exception.app_exception import AppApiException
@ -27,10 +28,9 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
model_type_list = WenxinModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
if model_name not in model_dict:
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型名称不支持')
model_info = [model.lower() for model in ChatCompletion.models()]
if not model_info.__contains__(model_name.lower()):
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持')
for key in ['api_key', 'secret_key']:
if key not in model_credential:
if raise_exception:
@ -39,10 +39,9 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
return False
try:
WenxinModelProvider().get_model(model_type, model_name, model_credential).invoke(
[HumanMessage(content='valid')])
[HumanMessage(content='你好')])
except Exception as e:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, "校验失败,请检查 api_key secret_key 是否正确")
raise e
return True
def encryption_dict(self, model_info: Dict[str, object]):
@ -121,7 +120,7 @@ class WenxinModelProvider(IModelProvider):
def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
raise AppApiException(500, f'不支持的模型:{model_name}')
return win_xin_llm_model_credential
def get_model_provide_info(self):
return ModelProvideInfo(provider='model_wenxin_provider', name='千帆大模型', icon=get_file_content(

View File

@ -23,7 +23,7 @@ sentence-transformers = "^2.2.2"
blinker = "^1.6.3"
openai = "^1.13.3"
tiktoken = "^0.5.1"
qianfan = "^0.1.1"
qianfan = "^0.3.6.1"
pycryptodome = "^3.19.0"
beautifulsoup4 = "^4.12.2"
html2text = "^2024.2.26"