From 947cdab2a6894bc9c96bc023e715f5f4733fa954 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Wed, 13 Dec 2023 11:59:44 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E6=B2=A1=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E4=B9=9F=E5=8F=AF=E4=BB=A5=E9=97=AE=E7=AD=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/chat_message_serializers.py | 23 +++++++++++++++++-- .../serializers/chat_serializers.py | 12 ++++++---- pyproject.toml | 1 - 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 9a38b92..7d8c629 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -122,6 +122,7 @@ class ChatMessageSerializer(serializers.Serializer): vector.delete_by_paragraph_id(_value.get('paragraph_id')) title, content = (None, None) if paragraph is None else (paragraph.title, paragraph.content) + _id = str(uuid.uuid1()) embedding_id, dataset_id, document_id, paragraph_id, source_type, source_id = (_value.get( 'id'), _value.get( @@ -130,6 +131,26 @@ class ChatMessageSerializer(serializers.Serializer): 'paragraph_id'), _value.get( 'source_type'), _value.get( 'source_id')) if _value is not None else (None, None, None, None, None, None) + + if chat_model is None: + def event_block_content(c: str): + yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None, + 'content': c if c is not None else '抱歉,根据已知信息无法回答这个问题,请重新描述您的问题或提供更多信息~'}) + "\n\n" + chat_info.append_chat_message( + ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id, + paragraph_id, + source_type, + source_id, c, 0, + 0)) + # 重新设置缓存 + chat_cache.set(chat_id, + chat_info, timeout=60 * 30) + + r = StreamingHttpResponse(streaming_content=event_block_content(content), + content_type='text/event-stream;charset=utf-8') + + r['Cache-Control'] = 'no-cache' + return r # 获取上下文 history_message = chat_info.get_context_message() @@ -138,8 +159,6 @@ class ChatMessageSerializer(serializers.Serializer): # 对话 result_data = chat_model.stream(chat_message) - _id = str(uuid.uuid1()) - def event_content(response): all_text = '' try: diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 7fd260a..43d53fe 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -93,14 +93,16 @@ class ChatSerializers(serializers.Serializer): self.is_valid(raise_exception=True) application_id = self.data.get('application_id') application = QuerySet(Application).get(id=application_id) - model = application.model + model = QuerySet(Model).filter(id=application.model_id).first() dataset_id_list = [str(row.dataset_id) for row in QuerySet(ApplicationDatasetMapping).filter( application_id=application_id)] - chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, - json.loads( - decrypt(model.credential)), - streaming=True) + chat_model = None + if model is not None: + chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, + json.loads( + decrypt(model.credential)), + streaming=True) chat_id = str(uuid.uuid1()) chat_cache.set(chat_id, diff --git a/pyproject.toml b/pyproject.toml index 5447e83..7107e0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,6 @@ diskcache = "^5.6.3" pillow = "9.5.0" filetype = "^1.2.0" chardet = "^5.2.0" -torch = "^2.1.0" sentence-transformers = "^2.2.2" blinker = "^1.6.3" openai = "^0.28.1"