!60928 PromptKVCache & DecoderKVCache

Merge pull request !60928 from ling/sr
This commit is contained in:
i-robot 2023-12-12 11:46:02 +00:00 committed by Gitee
commit 2c6ff963f0
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
20 changed files with 1171 additions and 0 deletions

View File

@ -0,0 +1,144 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/decoder_k_v_cache_info.h"
#include "frontend/parallel/dynamic_creator.h"
namespace mindspore {
namespace parallel {
// DecoderKVCache has seven inputs:
// cache: (batch_size, num_head, max_seq_len, hidden_size)
// update: (batch_size, num_head, update_seq_len, hidden_size)
// valid_seq_len: (batch_size)
// batch_index: (1)
// seq_len_axis: (1)
// new_max_seq_len: (1)
// cur_max_seq_len: (1)
// ------------------------------
// output: (batch_size, num_head, max_seq_len, hidden_size)
// split strategy:
// batch_size can be split.
// max_seq_len and update_seq_len cannot be split.
// num_head can be split.
// hidden_size can be split.
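// For illustration only (not part of the operator definition): with 8 devices, a legal strategy could be
// ((2, 4, 1, 1), (2, 4, 1, 1), (2,), (1,), (1,), (1,), (1,)) -- batch_size split by 2 and num_head by 4,
// while seq_len and the scalar-like inputs stay unsharded, which satisfies the checks below.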
Status DecoderKVCacheInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED;
}
auto input_strategys = strategy->GetInputDim();
auto strategy_cache = input_strategys.at(0); // (4, 4, 1, 4)
auto strategy_update = input_strategys.at(1); // (4, 4, 1, 4)
auto strategy_valid_seq_len = input_strategys.at(2); // (4)
auto strategy_batch_index = input_strategys.at(3); // (1)
auto strategy_seq_len_axis = input_strategys.at(4); // (1)
auto strategy_new_max_seq_len = input_strategys.at(5); // (1)
auto strategy_cur_max_seq_len = input_strategys.at(6); // (1)
if (strategy_new_max_seq_len.at(0) != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The new_max_seq_len can't be sharded, but got"
<< " new_max_seq_len's strategy: " << strategy_new_max_seq_len;
return FAILED;
}
if (strategy_cur_max_seq_len.at(0) != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The cur_max_seq_len can't be sharded, but got"
<< " cur_max_seq_len's strategy: " << strategy_cur_max_seq_len;
return FAILED;
}
if (strategy_batch_index.at(0) != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The batch_index can't be sharded, but got"
<< " batch_index's strategy: " << strategy_batch_index;
return FAILED;
}
if (strategy_seq_len_axis.at(0) != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len_axis can't be sharded, but got"
<< " seq_len_axis's strategy: " << strategy_seq_len_axis;
return FAILED;
}
if (strategy_cache.at(2) != 1 || strategy_update.at(2) != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len can't be sharded, but got"
<< " cache's seq_len strategy: " << strategy_cache.at(2)
<< "; update's seq_len strategy: " << strategy_update.at(2);
return FAILED;
}
// batch_size must use the same shard strategy.
if (strategy_cache.at(0) != strategy_update.at(0) || strategy_cache.at(0) != strategy_valid_seq_len.at(0)) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The batch_size must be sharded with the same strategy, but got"
<< " strategy_cache's strategy: " << strategy_cache
<< ", strategy_update's strategy: " << strategy_update
<< ", strategy_valid_seq_len's strategy: " << strategy_valid_seq_len;
return FAILED;
}
// num_head must use the same shard strategy.
if (strategy_cache.at(1) != strategy_update.at(1)) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The num_head must be sharded with the same strategy, but got"
<< " strategy_cache's strategy: " << strategy_cache
<< ", strategy_update's strategy: " << strategy_update;
return FAILED;
}
// hidden_size must use the same shard strategy.
if (strategy_cache.at(3) != strategy_update.at(3)) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The hidden_size must be sharded with the same strategy, but got"
<< " strategy_cache's strategy: " << strategy_cache
<< ", strategy_update's strategy: " << strategy_update;
return FAILED;
}
return SUCCESS;
}
Status DecoderKVCacheInfo::InferDevMatrixShape() {
auto input_strategys = strategy()->GetInputDim();
auto cache = input_strategys.at(0); // batch_size num_head max_seq_len hidden_size
auto update = input_strategys.at(1); // batch_size num_head update_seq_len hidden_size
// dev matrix layout: update_seq_len(4), batch_size(3), num_head(2), max_seq_len(1), hidden_size(0)
dev_matrix_shape_ = {update.at(2), cache.at(0), cache.at(1), cache.at(2), cache.at(3)};
return SUCCESS;
}
Status DecoderKVCacheInfo::InferTensorMap() {
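// Each tensor map value picks the dev matrix dimension that shards the corresponding tensor
// dimension (see the dev matrix layout above); -1 means the dimension is not sharded.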
Shape cache_tensor_map{3, 2, 1, 0};
Shape update_tensor_map{3, 2, 4, 0};
Shape valid_seq_len_tensor_map{3};
Shape batch_index_tensor_map{-1};
Shape seq_len_axis_tensor_map{-1};
Shape new_max_seq_len_tensor_map{-1};
Shape cur_max_seq_len_tensor_map{-1};
inputs_tensor_map_.emplace_back(cache_tensor_map);
inputs_tensor_map_.emplace_back(update_tensor_map);
inputs_tensor_map_.emplace_back(valid_seq_len_tensor_map);
inputs_tensor_map_.emplace_back(batch_index_tensor_map);
inputs_tensor_map_.emplace_back(seq_len_axis_tensor_map);
inputs_tensor_map_.emplace_back(new_max_seq_len_tensor_map);
inputs_tensor_map_.emplace_back(cur_max_seq_len_tensor_map);
Shape out_tensor_map{3, 2, 1, 0};
outputs_tensor_map_.emplace_back(out_tensor_map);
return SUCCESS;
}
REGISTER(DecoderKVCacheInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,52 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_DECODER_K_V_CACHE_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_DECODER_K_V_CACHE_INFO_H_
#include <memory>
#include <string>
#include <vector>
#include "utils/hash_map.h"
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
class DecoderKVCacheInfo : public OperatorInfo {
public:
DecoderKVCacheInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationInfoCost>()) {}
~DecoderKVCacheInfo() override = default;
Status CheckStrategy(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override { return {}; }
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
protected:
Status GetAttrs() override { return SUCCESS; }
Status InferForwardCommunication() { return SUCCESS; }
Status InferTensorMap() override;
Status InferDevMatrixShape() override;
};
using DecoderKVCacheInfoPtr = std::shared_ptr<DecoderKVCacheInfo>;
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_DECODER_K_V_CACHE_INFO_H_

View File

@ -0,0 +1,145 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/prompt_k_v_cache_info.h"
#include "frontend/parallel/dynamic_creator.h"
namespace mindspore {
namespace parallel {
// PromptKVCache has seven inputs:
// cache: (batch_size, num_head, max_seq_len, hidden_size)
// update: (batch_size, num_head, update_seq_len, hidden_size)
// valid_seq_len: (batch_size)
// batch_index: (1)
// seq_len_axis: (1)
// new_max_seq_len: (1)
// cur_max_seq_len: (1)
// ------------------------------
// output: (batch_size, num_head, max_seq_len, hidden_size)
// split strategy:
// batch_size can be split.
// max_seq_len and update_seq_len cannot be split.
// num_head can be split.
// hidden_size can be split.
Status PromptKVCacheInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED;
}
auto input_strategys = strategy->GetInputDim();
auto strategy_cache = input_strategys.at(0); // (4, 4, 1, 4)
auto strategy_update = input_strategys.at(1); // (4, 4, 1, 4)
auto strategy_valid_seq_len = input_strategys.at(2); // (4)
auto strategy_batch_index = input_strategys.at(3); // (1)
auto strategy_seq_len_axis = input_strategys.at(4); // (1)
auto strategy_new_max_seq_len = input_strategys.at(5); // (1)
auto strategy_cur_max_seq_len = input_strategys.at(6); // (1)
if (strategy_new_max_seq_len.at(0) != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The new_max_seq_len can't be sharded, but got"
<< " new_max_seq_len's strategy: " << strategy_new_max_seq_len;
return FAILED;
}
if (strategy_cur_max_seq_len.at(0) != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The cur_max_seq_len can't be sharded, but got"
<< " cur_max_seq_len's strategy: " << strategy_cur_max_seq_len;
return FAILED;
}
if (strategy_batch_index.at(0) != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The batch_index can't be sharded, but got"
<< " batch_index's strategy: " << strategy_batch_index;
return FAILED;
}
if (strategy_seq_len_axis.at(0) != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len_axis can't be sharded, but got"
<< " seq_len_axis's strategy: " << strategy_seq_len_axis;
return FAILED;
}
if (strategy_cache.at(2) != 1 || strategy_update.at(2) != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len can't be sharded, but got"
<< " cache's seq_len strategy: " << strategy_cache.at(2)
<< "; update's seq_len strategy: " << strategy_update.at(2);
return FAILED;
}
// batch_size must use the same shard strategy.
if (strategy_cache.at(0) != strategy_update.at(0) || strategy_cache.at(0) != strategy_valid_seq_len.at(0)) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The batch_size must be sharded with the same strategy, but got"
<< " strategy_cache's strategy: " << strategy_cache
<< ", strategy_update's strategy: " << strategy_update
<< ", strategy_valid_seq_len's strategy: " << strategy_valid_seq_len;
return FAILED;
}
// num_head must use the same shard strategy.
if (strategy_cache.at(1) != strategy_update.at(1)) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The num_head must be sharded with the same strategy, but got"
<< " strategy_cache's strategy: " << strategy_cache
<< ", strategy_update's strategy: " << strategy_update;
return FAILED;
}
// hidden_size must use the same shard strategy.
if (strategy_cache.at(3) != strategy_update.at(3)) {
MS_LOG(ERROR) << name_ << ": Invalid strategy: The hidden_size must be sharded with the same strategy, but got"
<< " strategy_cache's strategy: " << strategy_cache
<< ", strategy_update's strategy: " << strategy_update;
return FAILED;
}
return SUCCESS;
}
Status PromptKVCacheInfo::InferDevMatrixShape() {
auto input_strategys = strategy()->GetInputDim();
auto cache = input_strategys.at(0); // batch_size num_head max_seq_len hidden_size
auto update = input_strategys.at(1); // batch_size num_head update_seq_len hidden_size
// dev matrix layout: update_seq_len(4), batch_size(3), num_head(2), max_seq_len(1), hidden_size(0)
dev_matrix_shape_ = {update.at(2), cache.at(0), cache.at(1), cache.at(2), cache.at(3)};
return SUCCESS;
}
Status PromptKVCacheInfo::InferTensorMap() {
Shape cache_tensor_map{3, 2, 1, 0};
Shape update_tensor_map{3, 2, 4, 0};
Shape valid_seq_len_tensor_map{3};
Shape batch_index_tensor_map{-1};
Shape seq_len_axis_tensor_map{-1};
Shape new_max_seq_len_tensor_map{-1};
Shape cur_max_seq_len_tensor_map{-1};
inputs_tensor_map_.emplace_back(cache_tensor_map);
inputs_tensor_map_.emplace_back(update_tensor_map);
inputs_tensor_map_.emplace_back(valid_seq_len_tensor_map);
inputs_tensor_map_.emplace_back(batch_index_tensor_map);
inputs_tensor_map_.emplace_back(seq_len_axis_tensor_map);
inputs_tensor_map_.emplace_back(new_max_seq_len_tensor_map);
inputs_tensor_map_.emplace_back(cur_max_seq_len_tensor_map);
Shape out_tensor_map{3, 2, 1, 0};
outputs_tensor_map_.emplace_back(out_tensor_map);
return SUCCESS;
}
REGISTER(PromptKVCacheInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,53 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PROMPT_K_V_CACHE_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PROMPT_K_V_CACHE_INFO_H_
#include <memory>
#include <string>
#include <vector>
#include "utils/hash_map.h"
#include "ir/value.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
/*
* parallel class for PromptKVCache Primitive
*/
class PromptKVCacheInfo : public OperatorInfo {
public:
PromptKVCacheInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationInfoCost>()) {}
~PromptKVCacheInfo() override = default;
Status CheckStrategy(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override { return {}; }
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
protected:
Status GetAttrs() override { return SUCCESS; }
Status InferForwardCommunication() { return SUCCESS; }
Status InferTensorMap() override;
Status InferDevMatrixShape() override;
};
using PromptKVCacheInfoPtr = std::shared_ptr<PromptKVCacheInfo>;
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PROMPT_K_V_CACHE_INFO_H_

View File

@ -31,6 +31,28 @@ REG_OP(KVCacheMgr)
.OUTPUT(past, TensorType({DT_FLOAT16}))
.OP_END_FACTORY_REG(KVCacheMgr)
REG_OP(DecoderKvCache)
.INPUT(cache, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_FLOAT16, DT_FLOAT, DT_INT32}))
.INPUT(update, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_FLOAT16, DT_FLOAT, DT_INT32}))
.INPUT(valid_seq_len, TensorType({DT_INT64}))
.INPUT(batch_index, TensorType({DT_INT64}))
.INPUT(seq_len_axis, TensorType({DT_INT64}))
.INPUT(new_max_seq_len, TensorType({DT_INT64}))
.INPUT(cur_max_seq_len, TensorType({DT_INT64}))
.OUTPUT(out, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_FLOAT16, DT_FLOAT, DT_INT32}))
.OP_END_FACTORY_REG(DecoderKvCache)
REG_OP(PromptKvCache)
.INPUT(cache, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_FLOAT16, DT_FLOAT, DT_INT32}))
.INPUT(update, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_FLOAT16, DT_FLOAT, DT_INT32}))
.INPUT(valid_seq_len, TensorType({DT_INT64}))
.INPUT(batch_index, TensorType({DT_INT64}))
.INPUT(seq_len_axis, TensorType({DT_INT64}))
.INPUT(new_max_seq_len, TensorType({DT_INT64}))
.INPUT(cur_max_seq_len, TensorType({DT_INT64}))
.OUTPUT(out, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_FLOAT16, DT_FLOAT, DT_INT32}))
.OP_END_FACTORY_REG(PromptKvCache)
REG_CUST_OP(NoRepeatNGram)
.INPUT(state_seq, TensorType({DT_INT32}))
.INPUT(log_probs, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))

View File

@ -26,6 +26,24 @@ ATTR_MAP(KVCacheMgr) = EMPTY_ATTR_MAP;
OUTPUT_MAP(KVCacheMgr) = {{0, OUTPUT_DESC(past)}};
REG_ADPT_DESC(KVCacheMgr, "KVCacheMgr", ADPT_DESC(KVCacheMgr))
// DecoderKVCache
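// Note: INPUT_MAP indices are the 1-based positions of the primitive's inputs, bound to the GE operator's input descriptors.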
INPUT_MAP(DecoderKvCache) = {{1, INPUT_DESC(cache)}, {2, INPUT_DESC(update)},
{3, INPUT_DESC(valid_seq_len)}, {4, INPUT_DESC(batch_index)},
{5, INPUT_DESC(seq_len_axis)}, {6, INPUT_DESC(new_max_seq_len)},
{7, INPUT_DESC(cur_max_seq_len)}};
ATTR_MAP(DecoderKvCache) = EMPTY_ATTR_MAP;
OUTPUT_MAP(DecoderKvCache) = {{0, OUTPUT_DESC(out)}};
REG_ADPT_DESC(DecoderKvCache, "DecoderKVCache", ADPT_DESC(DecoderKvCache))
// PromptKVCache
INPUT_MAP(PromptKvCache) = {{1, INPUT_DESC(cache)}, {2, INPUT_DESC(update)},
{3, INPUT_DESC(valid_seq_len)}, {4, INPUT_DESC(batch_index)},
{5, INPUT_DESC(seq_len_axis)}, {6, INPUT_DESC(new_max_seq_len)},
{7, INPUT_DESC(cur_max_seq_len)}};
ATTR_MAP(PromptKvCache) = EMPTY_ATTR_MAP;
OUTPUT_MAP(PromptKvCache) = {{0, OUTPUT_DESC(out)}};
REG_ADPT_DESC(PromptKvCache, "PromptKVCache", ADPT_DESC(PromptKvCache))
// FlashAttention
INPUT_MAP(FlashAttention) = {
{1, INPUT_DESC(q)}, {2, INPUT_DESC(k)}, {3, INPUT_DESC(v)}, {4, INPUT_DESC(attention_mask)}};

View File

@ -27,10 +27,16 @@
DECLARE_OP_ADAPTER(KVCacheMgr)
DECLARE_OP_USE_OUTPUT(KVCacheMgr)
DECLARE_OP_ADAPTER(DecoderKvCache)
DECLARE_OP_USE_OUTPUT(DecoderKvCache)
DECLARE_OP_ADAPTER(FlashAttention)
DECLARE_OP_USE_OUTPUT(FlashAttention)
DECLARE_OP_ADAPTER(FFN)
DECLARE_OP_USE_OUTPUT(FFN)
DECLARE_OP_ADAPTER(PromptKvCache)
DECLARE_OP_USE_OUTPUT(PromptKvCache)
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_TRANSFORM_FUSION_OPS_DECLARE_H_

View File

@ -0,0 +1,52 @@
decoder_k_v_cache:
description: |
The DecoderKVCache operator is used to update the KVCache of a transformer network at the decoding stage.
Args:
cache (Tensor): The cache tensor with data type of int8, uint8, int16, uint16, float16, float32 or int32.
When seq_len_axis is 2, the cache tensor has shape
:math:`(batch\_size, num\_head, max\_seq\_length, hidden\_size)`.
When seq_len_axis is 1, the cache tensor has shape
:math:`(batch\_size, max\_seq\_length, num\_head, hidden\_size)`.
update (Tensor): The tensor used to update the cache tensor, with the same data type as the cache tensor.
When seq_len_axis is 2, the update tensor has shape
:math:`(batch\_size, num\_head, update\_seq\_length, hidden\_size)`.
When seq_len_axis is 1, the update tensor has shape
:math:`(batch\_size, update\_seq\_length, num\_head, hidden\_size)`.
valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64.
Valid_seq_len tensor of shape :math:`(batch\_size)`.
batch_index (Tensor): The batch_index tensor with data type of int64.
Batch_index tensor of shape :math:`(1)`. Indicates which batch of the cache tensor is going to be updated.
seq_len_axis (Tensor): Indicates which axis is the seq_len axis, set to '1' or '2'. Default: "2".
new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
New_max_seq_len tensor of shape :math:`(1)`.
Indicates that the user wants to change the shape of the cache tensor from
:math:`(batch\_size, num\_head, max\_seq\_length, hidden\_size)` to
:math:`(batch\_size * max\_seq\_length / new\_max\_seq\_length, num\_head, new\_max\_seq\_length, hidden\_size)`
to update the cache tensor. This does not actually change the shape of the `cache` tensor. Not available for now.
cur_max_seq_len (Tensor): The cur_max_seq_len tensor with data type of int64.
Cur_max_seq_len tensor of shape :math:`(1)`. Keeps the current seq_len of the cache tensor. Not available for now.
Outputs:
A tensor with the same data type and shape as the `cache` tensor.
Supported Platforms:
``Ascend``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.ops.operations import _inner_ops
>>> b = 4
>>> h = 40
>>> max_s = 1024
>>> s = 1
>>> d = 128
>>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
>>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
>>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
>>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
>>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
>>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
>>> decoder_kv_cache = _inner_ops.DecoderKVCache()
>>> output = decoder_kv_cache(cache, update, valid_seq_len, batch_index, Tensor(2), new_max_seq_len, cur_max_seq_len)
>>> print(cache)

View File

@ -0,0 +1,22 @@
#operator decoder_k_v_cache
decoder_k_v_cache:
args:
cache:
dtype: tensor
update:
dtype: tensor
valid_seq_len:
dtype: tensor
batch_index:
dtype: tensor
seq_len_axis:
dtype: tensor
new_max_seq_len:
dtype: tensor
cur_max_seq_len:
dtype: tensor
labels:
side_effect_mem: True
returns:
out:
dtype: tensor

View File

@ -0,0 +1,55 @@
prompt_k_v_cache:
description: |
The PromptKVCache operator is used to prefill the KVCache of a transformer network.
Args:
cache (Tensor): The cache tensor with data type of int8, uint8, int16, uint16, float16, float32 or int32.
When seq_len_axis is 2, the cache tensor has shape
:math:`(batch\_size, num\_head, max\_seq\_length, hidden\_size)`.
When seq_len_axis is 1, the cache tensor has shape
:math:`(batch\_size, max\_seq\_length, num\_head, hidden\_size)`.
update (Tensor): The tensor used to update the cache tensor, with the same data type as the cache tensor.
When seq_len_axis is 2, the update tensor has shape
:math:`(batch\_size, num\_head, update\_seq\_length, hidden\_size)`.
When seq_len_axis is 1, the update tensor has shape
:math:`(batch\_size, update\_seq\_length, num\_head, hidden\_size)`.
valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64.
Valid_seq_len tensor of shape :math:`(batch\_size)`.
batch_index (Tensor): The batch_index tensor with data type of int64.
Batch_index tensor of shape :math:`(1)`. Indicates which batch of the cache tensor is going to be updated.
seq_len_axis (Tensor): Indicates which axis is the seq_len axis, set to '1' or '2'. Default: "2".
new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
New_max_seq_len tensor of shape :math:`(1)`.
Indicates that the user wants to change the shape of the cache tensor from
:math:`(batch\_size, num\_head, max\_seq\_length, hidden\_size)` to
:math:`(batch\_size * max\_seq\_length / new\_max\_seq\_length, num\_head, new\_max\_seq\_length, hidden\_size)`
to update the cache tensor. This does not actually change the shape of the `cache` tensor. Not available for now.
cur_max_seq_len (Tensor): The cur_max_seq_len tensor with data type of int64.
Cur_max_seq_len tensor of shape :math:`(1)`. Keeps the current seq_len of the cache tensor. Not available for now.
align_mode (int64): Indicates how the update tensor is aligned in the cache along the seq_len axis, 0 is 'right', 1 is 'left'. Default: 0.
Outputs:
A tensor with the same data type and shape as the `cache` tensor.
Supported Platforms:
``Ascend``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.ops.operations import _inner_ops
>>> b = 4
>>> h = 40
>>> max_s = 1024
>>> s = 256
>>> d = 128
>>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
>>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
>>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
>>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
>>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
>>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
>>> prompt_kv_cache = _inner_ops.PromptKVCache(0)
>>> output = prompt_kv_cache(cache, update, valid_seq_len, batch_index, Tensor(2), new_max_seq_len, cur_max_seq_len)
>>> print(cache)

View File

@ -0,0 +1,27 @@
#operator prompt_k_v_cache
prompt_k_v_cache:
args:
cache:
dtype: tensor
update:
dtype: tensor
valid_seq_len:
dtype: tensor
batch_index:
dtype: tensor
seq_len_axis:
dtype: tensor
new_max_seq_len:
dtype: tensor
cur_max_seq_len:
dtype: tensor
align_mode:
dtype: int
default: 0
prim_init: True
arg_handler: k_v_cache_align_mode_to_enum
labels:
side_effect_mem: True
returns:
out:
dtype: tensor

View File

@ -0,0 +1,31 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/ops_func_impl/decoder_k_v_cache.h"
namespace mindspore {
namespace ops {
BaseShapePtr DecoderKVCacheFuncImpl::InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const {
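// The op updates the cache in place, so the output shape is simply that of the first input (cache).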
return input_args[0]->GetShape()->Clone();
}
TypePtr DecoderKVCacheFuncImpl::InferType(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const {
return input_args[0]->GetType()->Clone();
}
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_DECODER_K_V_CACHE_H_
#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_DECODER_K_V_CACHE_H_
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "mindapi/base/types.h"
#include "ops/base_operator.h"
#include "ops/ops_func_impl/op_func_impl.h"
namespace mindspore {
namespace ops {
class MIND_API DecoderKVCacheFuncImpl : public OpFuncImpl {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_DECODER_K_V_CACHE_H_

View File

@ -0,0 +1,31 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/ops_func_impl/prompt_k_v_cache.h"
namespace mindspore {
namespace ops {
BaseShapePtr PromptKVCacheFuncImpl::InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const {
return input_args[0]->GetShape()->Clone();
}
TypePtr PromptKVCacheFuncImpl::InferType(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const {
return input_args[0]->GetType()->Clone();
}
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_PROMPT_K_V_CACHE_H_
#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_PROMPT_K_V_CACHE_H_
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "mindapi/base/types.h"
#include "ops/base_operator.h"
#include "ops/ops_func_impl/op_func_impl.h"
namespace mindspore {
namespace ops {
class MIND_API PromptKVCacheFuncImpl : public OpFuncImpl {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_PROMPT_K_V_CACHE_H_

View File

@ -0,0 +1,154 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test DecoderKVCache plugin custom ops.
"""
import os
import numpy as np
import mindspore_lite as mslite
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, context
from mindspore.train.serialization import export
from mindspore.ops.operations._inner_ops import DecoderKVCache
class DecoderKVCacheNet(nn.Cell):
"""
DecoderKVCacheNet.
"""
def __init__(self):
super().__init__()
self.add = ops.Add()
self.sub = ops.Sub()
self.decoder_k_v_cache = DecoderKVCache()
self.seq_len_axis = [2, 0, 0, 0]
def construct(self, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
out = self.decoder_k_v_cache(cache, update, valid_seq_len, batch_index, self.seq_len_axis,
new_max_seq_len, cur_max_seq_len)
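# DecoderKVCache updates `cache` in place (the op is registered with side_effect_mem),
# so the add/sub on `cache` below reads back the updated values as the network output.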
add_out = self.add(cache, 1)
sub_out = self.sub(add_out, 1)
return sub_out
def np_inference(cache, update, valid_seq_len):
"""
Reference implementation: for each batch, write the single-step update at the position given by valid_seq_len.
"""
ans = cache.copy()
for b in range(cache.shape[0]):
bs_idx = valid_seq_len[b]
ans[b, :, bs_idx, :] = update[b, :, 0, :]
return ans
def create_numpy_inputs():
"""
create inputs
"""
cache = np.random.rand(4, 2, 4, 32).astype(np.float16)
update = np.random.rand(4, 2, 1, 32).astype(np.float16)
valid_seq_len = np.array([0, 1, 2, 3]).astype(np.int64)
batch_index = np.array([1]).astype(np.int64)
seq_len_axis = np.array([2]).astype(np.int64)
new_max_seq_len = np.array([4]).astype(np.int64)
cur_max_seq_len = np.array([1]).astype(np.int64)
return (cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
def create_ms_inputs():
"""
create inputs
"""
cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_numpy_inputs()
ms_cache = Tensor(cache)
ms_update = Tensor(update)
ms_valid_seq_len = Tensor(valid_seq_len)
ms_batch_index = Tensor(batch_index)
ms_seq_len_axis = Tensor(seq_len_axis)
ms_new_max_seq_len = Tensor(new_max_seq_len)
ms_cur_max_seq_len = Tensor(cur_max_seq_len)
return (ms_cache, ms_update, ms_valid_seq_len, ms_batch_index, ms_seq_len_axis,
ms_new_max_seq_len, ms_cur_max_seq_len)
def create_np_inputs(cache, update, valid_seq_len):
"""
create_np_inputs
"""
return (cache, update, valid_seq_len)
def create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
"""
create_lite_inputs
"""
cache = mslite.Tensor(cache)
update = mslite.Tensor(update)
valid_seq_len = mslite.Tensor(valid_seq_len)
batch_index = mslite.Tensor(batch_index)
seq_len_axis = mslite.Tensor(seq_len_axis)
new_max_seq_len = mslite.Tensor(new_max_seq_len)
cur_max_seq_len = mslite.Tensor(cur_max_seq_len)
return (cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
def inference_decoder_k_v_cache(mindir_model):
"""
inference model
"""
cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_numpy_inputs()
lite_ctx1 = mslite.Context()
lite_ctx1.target = ["ascend"]
lite_ctx1.ascend.device_id = 1
lite_ctx1.ascend.provider = "ge"
model = mslite.Model()
model.build_from_file(mindir_model, mslite.ModelType.MINDIR, lite_ctx1, "", {})
input_lists = list(create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis,
new_max_seq_len, cur_max_seq_len))
mslite_output = model.predict(input_lists)
np_cache, np_update, np_valid_seq_len = create_np_inputs(cache, update, valid_seq_len)
expect_output = np_inference(np_cache, np_update, np_valid_seq_len)
assert np.allclose(mslite_output[0].get_data_to_numpy(), expect_output, 0.001, 0.001)
def export_decoder_k_v_cache_model():
"""
export model
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_ms_inputs()
net = DecoderKVCacheNet()
file_name = "decoder_k_v_cache_primitive"
export(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len,
file_name=file_name, file_format='MINDIR')
model_name = file_name + ".mindir"
assert os.path.exists(model_name)
return model_name
if __name__ == '__main__':
model_path = export_decoder_k_v_cache_model()
print("decoder_k_v_cache st : export success path: ", model_path)
inference_decoder_k_v_cache(model_path)
print("decoder_k_v_cache st : inference success.")

View File

@ -0,0 +1,158 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test PromptKVCache plugin custom ops.
"""
import os
import numpy as np
import mindspore_lite as mslite
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, context, Parameter
from mindspore.train.serialization import export
import mindspore.common.dtype as mstype
from mindspore.ops.operations._inner_ops import PromptKVCache
class PromptKVCacheNet(nn.Cell):
"""
PromptKVCacheNet.
"""
def __init__(self, padding_mode):
super().__init__()
self.sub = ops.Sub()
self.add = ops.Add()
self.concat_dim2 = ops.Concat(axis=2)
self.prompt_k_v_cache = PromptKVCache(padding_mode)
self.pad_update_zero_tensor = Parameter(ops.zeros((1, 2, 3, 32), mstype.float16))
def construct(self, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
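# Pad the single-step update with zeros along the seq_len axis so its shape matches the cache's max_seq_len of 4.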
update_pad = self.concat_dim2((update, self.pad_update_zero_tensor))
out = self.prompt_k_v_cache(cache, update_pad, valid_seq_len, batch_index, seq_len_axis,
new_max_seq_len, cur_max_seq_len)
add_out = self.add(cache, 1)
sub_out = self.sub(add_out, 1)
return sub_out
def np_inference(cache, update, batch_index):
"""
Reference implementation: zero the selected batch rows of the cache, then copy the update into the first update_seq_len positions.
"""
zeros_ans = np.zeros(cache.shape, cache.dtype)
us = update.shape[2]
cache[batch_index, :] = zeros_ans[batch_index, :]
cache[batch_index, :, 0:us, :] = update
return cache
def create_numpy_inputs():
"""
create inputs
"""
cache = np.zeros([4, 2, 4, 32]).astype(np.float16)
cache = cache + 9
update = np.ones([1, 2, 1, 32]).astype(np.float16)
valid_seq_len = np.array([0, 1, 2, 3]).astype(np.int64)
batch_index = np.array([1]).astype(np.int64)
seq_len_axis = np.array([2]).astype(np.int64)
new_max_seq_len = np.array([4]).astype(np.int64)
cur_max_seq_len = np.array([1]).astype(np.int64)
return (cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
def create_ms_inputs():
"""
create inputs
"""
cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_numpy_inputs()
ms_cache = Tensor(cache)
ms_update = Tensor(update)
ms_valid_seq_len = Tensor(valid_seq_len)
ms_batch_index = Tensor(batch_index)
ms_seq_len_axis = Tensor(seq_len_axis)
ms_new_max_seq_len = Tensor(new_max_seq_len)
ms_cur_max_seq_len = Tensor(cur_max_seq_len)
return (ms_cache, ms_update, ms_valid_seq_len, ms_batch_index, ms_seq_len_axis,
ms_new_max_seq_len, ms_cur_max_seq_len)
def create_np_inputs(cache, update, batch_index):
"""
create_np_inputs
"""
return (cache, update, batch_index)
def create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
"""
create_lite_inputs
"""
cache = mslite.Tensor(cache)
update = mslite.Tensor(update)
valid_seq_len = mslite.Tensor(valid_seq_len)
batch_index = mslite.Tensor(batch_index)
seq_len_axis = mslite.Tensor(seq_len_axis)
new_max_seq_len = mslite.Tensor(new_max_seq_len)
cur_max_seq_len = mslite.Tensor(cur_max_seq_len)
return (cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
def inference_prompt_k_v_cache(mindir_model):
"""
inference model
"""
cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_numpy_inputs()
lite_ctx1 = mslite.Context()
lite_ctx1.target = ["ascend"]
lite_ctx1.ascend.device_id = 1
lite_ctx1.ascend.provider = "ge"
input_lists = list(create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis,
new_max_seq_len, cur_max_seq_len))
model = mslite.Model()
model.build_from_file(mindir_model, mslite.ModelType.MINDIR, lite_ctx1, "", {})
mslite_output = model.predict(input_lists)
np_cache, np_update, batch_index = create_np_inputs(cache, update, batch_index)
expect_output = np_inference(np_cache, np_update, batch_index)
assert np.allclose(mslite_output[0].get_data_to_numpy(), expect_output, 0.001, 0.001)
def export_prompt_k_v_cache_model():
"""
export model
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_ms_inputs()
net = PromptKVCacheNet("right")
file_name = "prompt_k_v_cache_primitive"
export(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len,
file_name=file_name, file_format='MINDIR')
model_name = file_name + ".mindir"
assert os.path.exists(model_name)
return model_name
if __name__ == '__main__':
model_path = export_prompt_k_v_cache_model()
print("prompt_k_v_cache st : export success path: ", model_path)
inference_prompt_k_v_cache(model_path)
print("prompt_k_v_cache st : inference success.")

View File

@ -235,3 +235,7 @@ coordinate_transformation_mode:
ALIGN_CORNERS: 0
HALF_PIXEL: 1
# enum KVCacheAlignMode
k_v_cache_align_mode:
RIGHT: 0
LEFT: 1

View File

@ -0,0 +1,58 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore.nn import Cell
from mindspore import context, Tensor, Parameter
from mindspore.ops.auto_generate import DecoderKVCache
from parallel.utils.utils import ParallelValidator, compile_net
def setup_function():
context.set_auto_parallel_context(dataset_strategy="full_batch")
class DecoderKVCacheNet(Cell):
def __init__(self, strategy):
super(DecoderKVCacheNet, self).__init__()
self.decoder_k_v_cache = DecoderKVCache().shard(strategy)
def construct(self, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
return self.decoder_k_v_cache(cache, update, valid_seq_len, batch_index, seq_len_axis,
new_max_seq_len, cur_max_seq_len)
def test_decoder_k_v_cache_net():
"""
Feature: test decoder_k_v_cache parallel
Description: semi auto parallel with a user-defined shard strategy
Expectation: sliced parameter shapes are as expected.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy = ((4, 1, 1, 1,), (4, 1, 1, 1,), (4,), (1,), (1,), (1,), (1,))
net = DecoderKVCacheNet(strategy)
cache = Parameter(Tensor(np.ones([4, 40, 1024, 128]), dtype=ms.float16), "cache")
update = Parameter(Tensor(np.ones([4, 40, 1, 128]), dtype=ms.float16), "update")
valid_seq_len = Parameter(Tensor(np.ones([4]), dtype=ms.int64), "valid_seq_len")
batch_index = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "batch_index")
seq_len_axis = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "seq_len_axis")
new_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "new_max_seq_len")
cur_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "cur_max_seq_len")
net.set_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
phase = compile_net(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
validator = ParallelValidator(net, phase)
assert validator.check_parameter_shape('cache', [1, 40, 1024, 128])
assert validator.check_parameter_shape('update', [1, 40, 1, 128])
assert validator.check_parameter_shape('valid_seq_len', [1])

View File

@ -0,0 +1,59 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore.nn import Cell
from mindspore import context, Tensor, Parameter
from mindspore.ops.auto_generate import PromptKVCache
from parallel.utils.utils import ParallelValidator, compile_net
def setup_function():
context.set_auto_parallel_context(dataset_strategy="full_batch")
class PromptKVCacheNet(Cell):
def __init__(self, padding_mode, strategy):
super(PromptKVCacheNet, self).__init__()
self.prompt_k_v_cache = PromptKVCache(padding_mode).shard(strategy)
def construct(self, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
return self.prompt_k_v_cache(cache, update, valid_seq_len, batch_index, seq_len_axis,
new_max_seq_len, cur_max_seq_len)
def test_prompt_k_v_cache_net():
"""
Feature: test prompt_k_v_cache parallel
Description: semi auto parallel with a user-defined shard strategy
Expectation: sliced parameter shapes are as expected.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy = ((4, 1, 1, 1,), (4, 1, 1, 1,), (4,), (1,), (1,), (1,), (1,))
padding_mode = "right"
net = PromptKVCacheNet(padding_mode, strategy)
cache = Parameter(Tensor(np.ones([4, 40, 1024, 128]), dtype=ms.float16), "cache")
update = Parameter(Tensor(np.ones([4, 40, 1, 128]), dtype=ms.float16), "update")
valid_seq_len = Parameter(Tensor(np.ones([4]), dtype=ms.int64), "valid_seq_len")
batch_index = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "batch_index")
seq_len_axis = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "seq_len_axis")
new_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "new_max_seq_len")
cur_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "cur_max_seq_len")
net.set_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
phase = compile_net(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
validator = ParallelValidator(net, phase)
assert validator.check_parameter_shape('cache', [1, 40, 1024, 128])
assert validator.check_parameter_shape('update', [1, 40, 1, 128])
assert validator.check_parameter_shape('valid_seq_len', [1])