!60928 PromptKVCache & DecoderKVCache
Merge pull request !60928 from ling/sr
This commit is contained in: commit 2c6ff963f0
@@ -0,0 +1,144 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/decoder_k_v_cache_info.h"
#include "frontend/parallel/dynamic_creator.h"

namespace mindspore {
namespace parallel {
// DecoderKVCache has seven inputs:
// cache: (batch_size, num_head, max_seq_len, hidden_size)
// update: (batch_size, num_head, update_seq_len, hidden_size)
// valid_seq_len: (batch_size)
// batch_index: (1)
// seq_len_axis: (1)
// new_max_seq_len: (1)
// cur_max_seq_len: (1)
// ------------------------------
// output: (batch_size, num_head, max_seq_len, hidden_size)

// Split strategy:
// batch_size can be split.
// max_seq_len and update_seq_len cannot be split.
// num_head can be split.
// hidden_size can be split.

Status DecoderKVCacheInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    return FAILED;
  }
  auto input_strategys = strategy->GetInputDim();

  auto strategy_cache = input_strategys.at(0);            // e.g. (4, 4, 1, 4)
  auto strategy_update = input_strategys.at(1);           // e.g. (4, 4, 1, 4)
  auto strategy_valid_seq_len = input_strategys.at(2);    // e.g. (4)
  auto strategy_batch_index = input_strategys.at(3);      // (1)
  auto strategy_seq_len_axis = input_strategys.at(4);     // (1)
  auto strategy_new_max_seq_len = input_strategys.at(5);  // (1)
  auto strategy_cur_max_seq_len = input_strategys.at(6);  // (1)

  if (strategy_new_max_seq_len.at(0) != 1) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The new_max_seq_len can't be sharded, but got"
                  << " new_max_seq_len's strategy: " << strategy_new_max_seq_len;
    return FAILED;
  }

  if (strategy_cur_max_seq_len.at(0) != 1) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The cur_max_seq_len can't be sharded, but got"
                  << " cur_max_seq_len's strategy: " << strategy_cur_max_seq_len;
    return FAILED;
  }

  if (strategy_batch_index.at(0) != 1) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The batch_index can't be sharded, but got"
                  << " batch_index's strategy: " << strategy_batch_index;
    return FAILED;
  }

  if (strategy_seq_len_axis.at(0) != 1) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len_axis can't be sharded, but got"
                  << " seq_len_axis's strategy: " << strategy_seq_len_axis;
    return FAILED;
  }

  if (strategy_cache.at(2) != 1 || strategy_update.at(2) != 1) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len can't be sharded, but got"
                  << " cache's seq_len strategy: " << strategy_cache.at(2)
                  << "; update's seq_len strategy: " << strategy_update.at(2);
    return FAILED;
  }

  // batch_size must use the same strategy.
  if (strategy_cache.at(0) != strategy_update.at(0) || strategy_cache.at(0) != strategy_valid_seq_len.at(0)) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The batch_size must be sharded with the same strategy, but got"
                  << " cache's strategy: " << strategy_cache
                  << ", update's strategy: " << strategy_update
                  << ", valid_seq_len's strategy: " << strategy_valid_seq_len;
    return FAILED;
  }

  // num_head must use the same strategy.
  if (strategy_cache.at(1) != strategy_update.at(1)) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The num_head must be sharded with the same strategy, but got"
                  << " cache's strategy: " << strategy_cache
                  << ", update's strategy: " << strategy_update;
    return FAILED;
  }

  // hidden_size must use the same strategy.
  if (strategy_cache.at(3) != strategy_update.at(3)) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The hidden_size must be sharded with the same strategy, but got"
                  << " cache's strategy: " << strategy_cache
                  << ", update's strategy: " << strategy_update;
    return FAILED;
  }
  return SUCCESS;
}

Status DecoderKVCacheInfo::InferDevMatrixShape() {
  auto input_strategys = strategy()->GetInputDim();
  auto cache = input_strategys.at(0);   // batch_size num_head max_seq_len hidden_size
  auto update = input_strategys.at(1);  // batch_size num_head update_seq_len hidden_size

  // update_seq_len batch_size num_head max_seq_len hidden_size
  //       4            3         2          1           0
  dev_matrix_shape_ = {update.at(2), cache.at(0), cache.at(1), cache.at(2), cache.at(3)};
  return SUCCESS;
}

Status DecoderKVCacheInfo::InferTensorMap() {
  Shape cache_tensor_map{3, 2, 1, 0};
  Shape update_tensor_map{3, 2, 4, 0};
  Shape valid_seq_len_tensor_map{3};
  Shape batch_index_tensor_map{-1};
  Shape seq_len_axis_tensor_map{-1};
  Shape new_max_seq_len_tensor_map{-1};
  Shape cur_max_seq_len_tensor_map{-1};
  inputs_tensor_map_.emplace_back(cache_tensor_map);
  inputs_tensor_map_.emplace_back(update_tensor_map);
  inputs_tensor_map_.emplace_back(valid_seq_len_tensor_map);
  inputs_tensor_map_.emplace_back(batch_index_tensor_map);
  inputs_tensor_map_.emplace_back(seq_len_axis_tensor_map);
  inputs_tensor_map_.emplace_back(new_max_seq_len_tensor_map);
  inputs_tensor_map_.emplace_back(cur_max_seq_len_tensor_map);

  Shape out_tensor_map{3, 2, 1, 0};
  outputs_tensor_map_.emplace_back(out_tensor_map);
  return SUCCESS;
}
REGISTER(DecoderKVCacheInfo);
}  // namespace parallel
}  // namespace mindspore
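For orientation, a minimal usage sketch (my addition, not part of the diff) of a shard strategy that satisfies CheckStrategy above: batch_size is split 4 ways while seq_len and the scalar-like inputs stay unsplit. It is the same strategy exercised by the parallel tests at the end of this commit.

from mindspore.ops.auto_generate import DecoderKVCache

# One strategy tuple per input: cache, update, valid_seq_len, batch_index,
# seq_len_axis, new_max_seq_len, cur_max_seq_len.
strategy = ((4, 1, 1, 1), (4, 1, 1, 1), (4,), (1,), (1,), (1,), (1,))
decoder_k_v_cache = DecoderKVCache().shard(strategy)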
@@ -0,0 +1,52 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_DECODER_K_V_CACHE_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_DECODER_K_V_CACHE_INFO_H_

#include <memory>
#include <string>
#include <vector>

#include "utils/hash_map.h"
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
class DecoderKVCacheInfo : public OperatorInfo {
 public:
  DecoderKVCacheInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                     const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationInfoCost>()) {}
  ~DecoderKVCacheInfo() override = default;
  Status CheckStrategy(const StrategyPtr &strategy) override;
  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override { return {}; }
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }

 protected:
  Status GetAttrs() override { return SUCCESS; }
  Status InferForwardCommunication() { return SUCCESS; }
  Status InferTensorMap() override;
  Status InferDevMatrixShape() override;
};
using DecoderKVCacheInfoPtr = std::shared_ptr<DecoderKVCacheInfo>;
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_DECODER_K_V_CACHE_INFO_H_
@@ -0,0 +1,145 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/prompt_k_v_cache_info.h"
#include "frontend/parallel/dynamic_creator.h"

namespace mindspore {
namespace parallel {
// PromptKVCache has seven inputs:
// cache: (batch_size, num_head, max_seq_len, hidden_size)
// update: (batch_size, num_head, update_seq_len, hidden_size)
// valid_seq_len: (batch_size)
// batch_index: (1)
// seq_len_axis: (1)
// new_max_seq_len: (1)
// cur_max_seq_len: (1)
// ------------------------------
// output: (batch_size, num_head, max_seq_len, hidden_size)

// Split strategy:
// batch_size can be split.
// max_seq_len and update_seq_len cannot be split.
// num_head can be split.
// hidden_size can be split.

Status PromptKVCacheInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    return FAILED;
  }

  auto input_strategys = strategy->GetInputDim();

  auto strategy_cache = input_strategys.at(0);            // e.g. (4, 4, 1, 4)
  auto strategy_update = input_strategys.at(1);           // e.g. (4, 4, 1, 4)
  auto strategy_valid_seq_len = input_strategys.at(2);    // e.g. (4)
  auto strategy_batch_index = input_strategys.at(3);      // (1)
  auto strategy_seq_len_axis = input_strategys.at(4);     // (1)
  auto strategy_new_max_seq_len = input_strategys.at(5);  // (1)
  auto strategy_cur_max_seq_len = input_strategys.at(6);  // (1)

  if (strategy_new_max_seq_len.at(0) != 1) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The new_max_seq_len can't be sharded, but got"
                  << " new_max_seq_len's strategy: " << strategy_new_max_seq_len;
    return FAILED;
  }

  if (strategy_cur_max_seq_len.at(0) != 1) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The cur_max_seq_len can't be sharded, but got"
                  << " cur_max_seq_len's strategy: " << strategy_cur_max_seq_len;
    return FAILED;
  }

  if (strategy_batch_index.at(0) != 1) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The batch_index can't be sharded, but got"
                  << " batch_index's strategy: " << strategy_batch_index;
    return FAILED;
  }

  if (strategy_seq_len_axis.at(0) != 1) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len_axis can't be sharded, but got"
                  << " seq_len_axis's strategy: " << strategy_seq_len_axis;
    return FAILED;
  }

  if (strategy_cache.at(2) != 1 || strategy_update.at(2) != 1) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len can't be sharded, but got"
                  << " cache's seq_len strategy: " << strategy_cache.at(2)
                  << "; update's seq_len strategy: " << strategy_update.at(2);
    return FAILED;
  }

  // batch_size must use the same strategy.
  if (strategy_cache.at(0) != strategy_update.at(0) || strategy_cache.at(0) != strategy_valid_seq_len.at(0)) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The batch_size must be sharded with the same strategy, but got"
                  << " cache's strategy: " << strategy_cache
                  << ", update's strategy: " << strategy_update
                  << ", valid_seq_len's strategy: " << strategy_valid_seq_len;
    return FAILED;
  }

  // num_head must use the same strategy.
  if (strategy_cache.at(1) != strategy_update.at(1)) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The num_head must be sharded with the same strategy, but got"
                  << " cache's strategy: " << strategy_cache
                  << ", update's strategy: " << strategy_update;
    return FAILED;
  }

  // hidden_size must use the same strategy.
  if (strategy_cache.at(3) != strategy_update.at(3)) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The hidden_size must be sharded with the same strategy, but got"
                  << " cache's strategy: " << strategy_cache
                  << ", update's strategy: " << strategy_update;
    return FAILED;
  }
  return SUCCESS;
}

Status PromptKVCacheInfo::InferDevMatrixShape() {
  auto input_strategys = strategy()->GetInputDim();
  auto cache = input_strategys.at(0);   // batch_size num_head max_seq_len hidden_size
  auto update = input_strategys.at(1);  // batch_size num_head update_seq_len hidden_size

  // update_seq_len batch_size num_head max_seq_len hidden_size
  //       4            3         2          1           0
  dev_matrix_shape_ = {update.at(2), cache.at(0), cache.at(1), cache.at(2), cache.at(3)};
  return SUCCESS;
}

Status PromptKVCacheInfo::InferTensorMap() {
  Shape cache_tensor_map{3, 2, 1, 0};
  Shape update_tensor_map{3, 2, 4, 0};
  Shape valid_seq_len_tensor_map{3};
  Shape batch_index_tensor_map{-1};
  Shape seq_len_axis_tensor_map{-1};
  Shape new_max_seq_len_tensor_map{-1};
  Shape cur_max_seq_len_tensor_map{-1};
  inputs_tensor_map_.emplace_back(cache_tensor_map);
  inputs_tensor_map_.emplace_back(update_tensor_map);
  inputs_tensor_map_.emplace_back(valid_seq_len_tensor_map);
  inputs_tensor_map_.emplace_back(batch_index_tensor_map);
  inputs_tensor_map_.emplace_back(seq_len_axis_tensor_map);
  inputs_tensor_map_.emplace_back(new_max_seq_len_tensor_map);
  inputs_tensor_map_.emplace_back(cur_max_seq_len_tensor_map);

  Shape out_tensor_map{3, 2, 1, 0};
  outputs_tensor_map_.emplace_back(out_tensor_map);
  return SUCCESS;
}
REGISTER(PromptKVCacheInfo);
}  // namespace parallel
}  // namespace mindspore
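To see how InferDevMatrixShape and InferTensorMap fit together, here is a worked sketch (my own illustration; the index arithmetic is an assumption about how tensor-map entries address the device matrix from its last axis) using the strategy from the parallel tests below:

# Strategies for cache and update: (batch, num_head, seq_len, hidden).
cache_strategy = (4, 1, 1, 1)
update_strategy = (4, 1, 1, 1)

# dev_matrix_shape_ = {update_seq_len, batch_size, num_head, max_seq_len, hidden_size}
dev_matrix = [update_strategy[2], *cache_strategy]  # -> [1, 4, 1, 1, 1]

# cache's tensor map {3, 2, 1, 0} ties cache dim 0 to device-matrix axis 3
# (counted from the right), whose value is 4 -- so only batch_size is sliced.
cache_map = [3, 2, 1, 0]
cache_shape = [4, 40, 1024, 128]
sliced = [dim // dev_matrix[len(dev_matrix) - 1 - m] for dim, m in zip(cache_shape, cache_map)]
print(sliced)  # [1, 40, 1024, 128], the shape the parallel test asserts for 'cache'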
@@ -0,0 +1,53 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PROMPT_K_V_CACHE_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PROMPT_K_V_CACHE_INFO_H_

#include <memory>
#include <string>
#include <vector>
#include "utils/hash_map.h"
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
/*
 * parallel class for PromptKVCache Primitive
 */
class PromptKVCacheInfo : public OperatorInfo {
 public:
  PromptKVCacheInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                    const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationInfoCost>()) {}
  ~PromptKVCacheInfo() override = default;
  Status CheckStrategy(const StrategyPtr &strategy) override;
  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override { return {}; }
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }

 protected:
  Status GetAttrs() override { return SUCCESS; }
  Status InferForwardCommunication() { return SUCCESS; }
  Status InferTensorMap() override;
  Status InferDevMatrixShape() override;
};
using PromptKVCacheInfoPtr = std::shared_ptr<PromptKVCacheInfo>;
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PROMPT_K_V_CACHE_INFO_H_
@@ -31,6 +31,28 @@ REG_OP(KVCacheMgr)
    .OUTPUT(past, TensorType({DT_FLOAT16}))
    .OP_END_FACTORY_REG(KVCacheMgr)

REG_OP(DecoderKvCache)
    .INPUT(cache, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(update, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(valid_seq_len, TensorType({DT_INT64}))
    .INPUT(batch_index, TensorType({DT_INT64}))
    .INPUT(seq_len_axis, TensorType({DT_INT64}))
    .INPUT(new_max_seq_len, TensorType({DT_INT64}))
    .INPUT(cur_max_seq_len, TensorType({DT_INT64}))
    .OUTPUT(out, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OP_END_FACTORY_REG(DecoderKvCache)

REG_OP(PromptKvCache)
    .INPUT(cache, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(update, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .INPUT(valid_seq_len, TensorType({DT_INT64}))
    .INPUT(batch_index, TensorType({DT_INT64}))
    .INPUT(seq_len_axis, TensorType({DT_INT64}))
    .INPUT(new_max_seq_len, TensorType({DT_INT64}))
    .INPUT(cur_max_seq_len, TensorType({DT_INT64}))
    .OUTPUT(out, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OP_END_FACTORY_REG(PromptKvCache)

REG_CUST_OP(NoRepeatNGram)
    .INPUT(state_seq, TensorType({DT_INT32}))
    .INPUT(log_probs, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
@@ -26,6 +26,24 @@ ATTR_MAP(KVCacheMgr) = EMPTY_ATTR_MAP;
OUTPUT_MAP(KVCacheMgr) = {{0, OUTPUT_DESC(past)}};
REG_ADPT_DESC(KVCacheMgr, "KVCacheMgr", ADPT_DESC(KVCacheMgr))

// DecoderKVCache
INPUT_MAP(DecoderKvCache) = {{1, INPUT_DESC(cache)},         {2, INPUT_DESC(update)},
                             {3, INPUT_DESC(valid_seq_len)}, {4, INPUT_DESC(batch_index)},
                             {5, INPUT_DESC(seq_len_axis)},  {6, INPUT_DESC(new_max_seq_len)},
                             {7, INPUT_DESC(cur_max_seq_len)}};
ATTR_MAP(DecoderKvCache) = EMPTY_ATTR_MAP;
OUTPUT_MAP(DecoderKvCache) = {{0, OUTPUT_DESC(out)}};
REG_ADPT_DESC(DecoderKvCache, "DecoderKVCache", ADPT_DESC(DecoderKvCache))

// PromptKVCache
INPUT_MAP(PromptKvCache) = {{1, INPUT_DESC(cache)},         {2, INPUT_DESC(update)},
                            {3, INPUT_DESC(valid_seq_len)}, {4, INPUT_DESC(batch_index)},
                            {5, INPUT_DESC(seq_len_axis)},  {6, INPUT_DESC(new_max_seq_len)},
                            {7, INPUT_DESC(cur_max_seq_len)}};
ATTR_MAP(PromptKvCache) = EMPTY_ATTR_MAP;
OUTPUT_MAP(PromptKvCache) = {{0, OUTPUT_DESC(out)}};
REG_ADPT_DESC(PromptKvCache, "PromptKVCache", ADPT_DESC(PromptKvCache))

// FlashAttention
INPUT_MAP(FlashAttention) = {
  {1, INPUT_DESC(q)}, {2, INPUT_DESC(k)}, {3, INPUT_DESC(v)}, {4, INPUT_DESC(attention_mask)}};
@@ -27,10 +27,16 @@
DECLARE_OP_ADAPTER(KVCacheMgr)
DECLARE_OP_USE_OUTPUT(KVCacheMgr)

DECLARE_OP_ADAPTER(DecoderKvCache)
DECLARE_OP_USE_OUTPUT(DecoderKvCache)

DECLARE_OP_ADAPTER(FlashAttention)
DECLARE_OP_USE_OUTPUT(FlashAttention)

DECLARE_OP_ADAPTER(FFN)
DECLARE_OP_USE_OUTPUT(FFN)

DECLARE_OP_ADAPTER(PromptKvCache)
DECLARE_OP_USE_OUTPUT(PromptKvCache)

#endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_TRANSFORM_FUSION_OPS_DECLARE_H_
@@ -0,0 +1,52 @@
decoder_k_v_cache:
    description: |
        The DecoderKVCache is used to decode the KVCache of a transformer network.

        Args:
            cache (Tensor): The cache tensor with data type of int8, uint8, int16, uint16, float16, float32 or int32.
                When seq_len_axis is 2, the cache tensor has shape
                :math:`(batch\_size, num\_head, max\_seq\_length, hidden\_size)`.
                When seq_len_axis is 1, the cache tensor has shape
                :math:`(batch\_size, max\_seq\_length, num\_head, hidden\_size)`.
            update (Tensor): The tensor used to update the cache tensor. Same data type as the cache tensor.
                When seq_len_axis is 2, the update tensor has shape
                :math:`(batch\_size, num\_head, update\_seq\_length, hidden\_size)`.
                When seq_len_axis is 1, the update tensor has shape
                :math:`(batch\_size, update\_seq\_length, num\_head, hidden\_size)`.
            valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64, of shape :math:`(batch\_size)`.
            batch_index (Tensor): The batch_index tensor with data type of int64, of shape :math:`(1)`.
                Indicates which batch of the cache tensor is going to be updated.
            seq_len_axis (Tensor): Indicates which axis is seq_len, set to 1 or 2. Default: 2.
            new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64, of shape :math:`(1)`.
                Indicates that the user wants to reinterpret the shape of the cache tensor from
                :math:`(batch\_size, num\_head, max\_seq\_length, hidden\_size)` to
                :math:`(batch\_size * max\_seq\_length / new\_max\_seq\_length, num\_head, new\_max\_seq\_length, hidden\_size)`
                when updating the cache tensor. This does not actually change the shape of the `cache` tensor.
                Not available for now.
            cur_max_seq_len (Tensor): The cur_max_seq_len tensor with data type of int64, of shape :math:`(1)`.
                Keeps the current seq_len of the cache tensor. Not available for now.

        Outputs:
            A tensor with the same data type and shape as the `cache` tensor.

        Supported Platforms:
            ``Ascend``

        Examples:
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> from mindspore.ops.operations import _inner_ops
            >>> b = 4
            >>> h = 40
            >>> max_s = 1024
            >>> s = 1
            >>> d = 128
            >>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
            >>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
            >>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
            >>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
            >>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
            >>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
            >>> decoder_kv_cache = _inner_ops.DecoderKVCache()
            >>> output = decoder_kv_cache(cache, update, valid_seq_len, batch_index, Tensor(2),
            ...                           new_max_seq_len, cur_max_seq_len)
            >>> print(cache)
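The example above uses random data; as a concrete reference for what the op computes, here is a plain-NumPy sketch of the decode step (my addition, mirroring the np_inference helper in the ST test later in this commit): each batch row receives its single new token at position valid_seq_len[b].

import numpy as np

def decoder_kv_cache_reference(cache, update, valid_seq_len):
    # cache: (b, h, max_s, d); update: (b, h, 1, d); valid_seq_len: (b,)
    out = cache.copy()
    for b in range(cache.shape[0]):
        out[b, :, valid_seq_len[b], :] = update[b, :, 0, :]
    return out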
@@ -0,0 +1,22 @@
# operator decoder_k_v_cache
decoder_k_v_cache:
  args:
    cache:
      dtype: tensor
    update:
      dtype: tensor
    valid_seq_len:
      dtype: tensor
    batch_index:
      dtype: tensor
    seq_len_axis:
      dtype: tensor
    new_max_seq_len:
      dtype: tensor
    cur_max_seq_len:
      dtype: tensor
  labels:
    side_effect_mem: True
  returns:
    out:
      dtype: tensor
@@ -0,0 +1,55 @@
prompt_k_v_cache:
    description: |
        The PromptKVCache is used to prefill the KVCache of a transformer network.

        Args:
            cache (Tensor): The cache tensor with data type of int8, uint8, int16, uint16, float16, float32 or int32.
                When seq_len_axis is 2, the cache tensor has shape
                :math:`(batch\_size, num\_head, max\_seq\_length, hidden\_size)`.
                When seq_len_axis is 1, the cache tensor has shape
                :math:`(batch\_size, max\_seq\_length, num\_head, hidden\_size)`.
            update (Tensor): The tensor used to update the cache tensor. Same data type as the cache tensor.
                When seq_len_axis is 2, the update tensor has shape
                :math:`(batch\_size, num\_head, update\_seq\_length, hidden\_size)`.
                When seq_len_axis is 1, the update tensor has shape
                :math:`(batch\_size, update\_seq\_length, num\_head, hidden\_size)`.
            valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64, of shape :math:`(batch\_size)`.
            batch_index (Tensor): The batch_index tensor with data type of int64, of shape :math:`(1)`.
                Indicates which batch of the cache tensor is going to be updated.
            seq_len_axis (Tensor): Indicates which axis is seq_len, set to 1 or 2. Default: 2.
            new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64, of shape :math:`(1)`.
                Indicates that the user wants to reinterpret the shape of the cache tensor from
                :math:`(batch\_size, num\_head, max\_seq\_length, hidden\_size)` to
                :math:`(batch\_size * max\_seq\_length / new\_max\_seq\_length, num\_head, new\_max\_seq\_length, hidden\_size)`
                when updating the cache tensor. This does not actually change the shape of the `cache` tensor.
                Not available for now.
            cur_max_seq_len (Tensor): The cur_max_seq_len tensor with data type of int64, of shape :math:`(1)`.
                Keeps the current seq_len of the cache tensor. Not available for now.
            align_mode (int64): Indicates the padding alignment of the update tensor, 0 is 'right', 1 is 'left'.
                Default: 0.

        Outputs:
            A tensor with the same data type and shape as the `cache` tensor.

        Supported Platforms:
            ``Ascend``

        Examples:
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> from mindspore.ops.operations import _inner_ops
            >>> b = 4
            >>> h = 40
            >>> max_s = 1024
            >>> s = 256
            >>> d = 128
            >>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
            >>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
            >>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
            >>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
            >>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
            >>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
            >>> prompt_kv_cache = _inner_ops.PromptKVCache(0)
            >>> output = prompt_kv_cache(cache, update, valid_seq_len, batch_index, Tensor(2),
            ...                          new_max_seq_len, cur_max_seq_len)
            >>> print(cache)
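As with the decoder doc above, a plain-NumPy sketch of the prefill semantics may help (my addition, mirroring the np_inference helper in the ST test later in this commit, which uses right padding): the cache rows selected by batch_index are cleared and the update is written starting at position 0.

import numpy as np

def prompt_kv_cache_reference(cache, update, batch_index):
    # cache: (b, h, max_s, d); update: (n, h, us, d); batch_index: (n,)
    us = update.shape[2]
    cache[batch_index, :] = 0                 # reset the selected batch rows
    cache[batch_index, :, 0:us, :] = update   # prefill from position 0
    return cache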
@@ -0,0 +1,27 @@
# operator prompt_k_v_cache
prompt_k_v_cache:
  args:
    cache:
      dtype: tensor
    update:
      dtype: tensor
    valid_seq_len:
      dtype: tensor
    batch_index:
      dtype: tensor
    seq_len_axis:
      dtype: tensor
    new_max_seq_len:
      dtype: tensor
    cur_max_seq_len:
      dtype: tensor
    align_mode:
      dtype: int
      default: 0
      prim_init: True
      arg_handler: k_v_cache_align_mode_to_enum
  labels:
    side_effect_mem: True
  returns:
    out:
      dtype: tensor
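The arg_handler above, k_v_cache_align_mode_to_enum, converts the string form of align_mode to the k_v_cache_align_mode enum added near the end of this commit (RIGHT: 0, LEFT: 1), so both spellings should construct the same primitive — a small sketch (my assumption, based on the ST test passing "right" and the doc example passing 0):

from mindspore.ops.operations._inner_ops import PromptKVCache

op_from_string = PromptKVCache("right")  # mapped to 0 by the arg handler
op_from_enum = PromptKVCache(0)          # numeric form, as in the doc example above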
@@ -0,0 +1,31 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/ops_func_impl/decoder_k_v_cache.h"

namespace mindspore {
namespace ops {
BaseShapePtr DecoderKVCacheFuncImpl::InferShape(const PrimitivePtr &primitive,
                                                const std::vector<AbstractBasePtr> &input_args) const {
  return input_args[0]->GetShape()->Clone();
}

TypePtr DecoderKVCacheFuncImpl::InferType(const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args) const {
  return input_args[0]->GetType()->Clone();
}
}  // namespace ops
}  // namespace mindspore
@@ -0,0 +1,40 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_DECODER_K_V_CACHE_H_
#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_DECODER_K_V_CACHE_H_

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "mindapi/base/types.h"
#include "ops/base_operator.h"
#include "ops/ops_func_impl/op_func_impl.h"

namespace mindspore {
namespace ops {
class MIND_API DecoderKVCacheFuncImpl : public OpFuncImpl {
 public:
  BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
};
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_DECODER_K_V_CACHE_H_
@@ -0,0 +1,31 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/ops_func_impl/prompt_k_v_cache.h"

namespace mindspore {
namespace ops {
BaseShapePtr PromptKVCacheFuncImpl::InferShape(const PrimitivePtr &primitive,
                                               const std::vector<AbstractBasePtr> &input_args) const {
  return input_args[0]->GetShape()->Clone();
}

TypePtr PromptKVCacheFuncImpl::InferType(const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args) const {
  return input_args[0]->GetType()->Clone();
}
}  // namespace ops
}  // namespace mindspore
@@ -0,0 +1,40 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_PROMPT_K_V_CACHE_H_
#define MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_PROMPT_K_V_CACHE_H_

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "mindapi/base/types.h"
#include "ops/base_operator.h"
#include "ops/ops_func_impl/op_func_impl.h"

namespace mindspore {
namespace ops {
class MIND_API PromptKVCacheFuncImpl : public OpFuncImpl {
 public:
  BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
};
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_OPS_FUNC_IMPL_PROMPT_K_V_CACHE_H_
@@ -0,0 +1,154 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test DecoderKVCache plugin custom ops.
"""
import os
import numpy as np
import mindspore_lite as mslite
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, context
from mindspore.train.serialization import export
from mindspore.ops.operations._inner_ops import DecoderKVCache


class DecoderKVCacheNet(nn.Cell):
    """
    DecoderKVCacheNet.
    """
    def __init__(self):
        super().__init__()
        self.add = ops.Add()
        self.sub = ops.Sub()
        self.decoder_k_v_cache = DecoderKVCache()
        self.seq_len_axis = [2, 0, 0, 0]

    def construct(self, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
        # The op updates `cache` in place (side_effect_mem), so the graph
        # returns the cache after a no-op add/sub round trip.
        out = self.decoder_k_v_cache(cache, update, valid_seq_len, batch_index, self.seq_len_axis,
                                     new_max_seq_len, cur_max_seq_len)
        add_out = self.add(cache, 1)
        sub_out = self.sub(add_out, 1)
        return sub_out


def np_inference(cache, update, valid_seq_len):
    """
    np_inference: write each batch's new token into the cache at valid_seq_len[b].
    """
    ans = cache.copy()
    for b in range(cache.shape[0]):
        bs_idx = valid_seq_len[b]
        ans[b, :, bs_idx, :] = update[b, :, 0, :]
    return ans


def create_numpy_inputs():
    """
    create inputs
    """
    cache = np.random.rand(4, 2, 4, 32).astype(np.float16)
    update = np.random.rand(4, 2, 1, 32).astype(np.float16)
    valid_seq_len = np.array([0, 1, 2, 3]).astype(np.int64)
    batch_index = np.array([1]).astype(np.int64)
    seq_len_axis = np.array([2]).astype(np.int64)
    new_max_seq_len = np.array([4]).astype(np.int64)
    cur_max_seq_len = np.array([1]).astype(np.int64)
    return (cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)


def create_ms_inputs():
    """
    create inputs
    """
    cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_numpy_inputs()
    ms_cache = Tensor(cache)
    ms_update = Tensor(update)
    ms_valid_seq_len = Tensor(valid_seq_len)
    ms_batch_index = Tensor(batch_index)
    ms_seq_len_axis = Tensor(seq_len_axis)
    ms_new_max_seq_len = Tensor(new_max_seq_len)
    ms_cur_max_seq_len = Tensor(cur_max_seq_len)
    return (ms_cache, ms_update, ms_valid_seq_len, ms_batch_index, ms_seq_len_axis,
            ms_new_max_seq_len, ms_cur_max_seq_len)


def create_np_inputs(cache, update, valid_seq_len):
    """
    create_np_inputs
    """
    return (cache, update, valid_seq_len)


def create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
    """
    create_lite_inputs
    """
    cache = mslite.Tensor(cache)
    update = mslite.Tensor(update)
    valid_seq_len = mslite.Tensor(valid_seq_len)
    batch_index = mslite.Tensor(batch_index)
    seq_len_axis = mslite.Tensor(seq_len_axis)
    new_max_seq_len = mslite.Tensor(new_max_seq_len)
    cur_max_seq_len = mslite.Tensor(cur_max_seq_len)
    return (cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)


def inference_decoder_k_v_cache(mindir_model):
    """
    inference model
    """
    cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_numpy_inputs()

    lite_ctx1 = mslite.Context()
    lite_ctx1.target = ["ascend"]
    lite_ctx1.ascend.device_id = 1
    lite_ctx1.ascend.provider = "ge"

    model = mslite.Model()
    model.build_from_file(mindir_model, mslite.ModelType.MINDIR, lite_ctx1, "", {})

    input_lists = list(create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis,
                                          new_max_seq_len, cur_max_seq_len))
    mslite_output = model.predict(input_lists)

    np_cache, np_update, np_valid_seq_len = create_np_inputs(cache, update, valid_seq_len)
    expect_output = np_inference(np_cache, np_update, np_valid_seq_len)
    assert np.allclose(mslite_output[0].get_data_to_numpy(), expect_output, 0.001, 0.001)


def export_decoder_k_v_cache_model():
    """
    export model
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_ms_inputs()

    net = DecoderKVCacheNet()
    file_name = "decoder_k_v_cache_primitive"

    export(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len,
           file_name=file_name, file_format='MINDIR')
    model_name = file_name + ".mindir"
    assert os.path.exists(model_name)
    return model_name


if __name__ == '__main__':
    model_path = export_decoder_k_v_cache_model()
    print("decoder_k_v_cache st : export success path: ", model_path)

    inference_decoder_k_v_cache(model_path)
    print("decoder_k_v_cache st : inference success.")
@@ -0,0 +1,158 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test PromptKVCache plugin custom ops.
"""
import os
import numpy as np
import mindspore_lite as mslite
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, context, Parameter
from mindspore.train.serialization import export
import mindspore.common.dtype as mstype
from mindspore.ops.operations._inner_ops import PromptKVCache


class PromptKVCacheNet(nn.Cell):
    """
    PromptKVCacheNet.
    """
    def __init__(self, padding_mode):
        super().__init__()
        self.sub = ops.Sub()
        self.add = ops.Add()
        self.concat_dim2 = ops.Concat(axis=2)
        self.prompt_k_v_cache = PromptKVCache(padding_mode)
        self.pad_update_zero_tensor = Parameter(ops.zeros((1, 2, 3, 32), mstype.float16))

    def construct(self, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
        # Right-pad the update along seq_len, then let the op update `cache`
        # in place (side_effect_mem); the add/sub round trip is a no-op.
        update_pad = self.concat_dim2((update, self.pad_update_zero_tensor))
        out = self.prompt_k_v_cache(cache, update_pad, valid_seq_len, batch_index, seq_len_axis,
                                    new_max_seq_len, cur_max_seq_len)
        add_out = self.add(cache, 1)
        sub_out = self.sub(add_out, 1)
        return sub_out


def np_inference(cache, update, batch_index):
    """
    np_inference: clear the selected batch rows, then prefill the update
    starting at position 0 (right padding).
    """
    zeros_ans = np.zeros(cache.shape, cache.dtype)
    us = update.shape[2]
    cache[batch_index, :] = zeros_ans[batch_index, :]
    cache[batch_index, :, 0:us, :] = update
    return cache


def create_numpy_inputs():
    """
    create inputs
    """
    cache = np.zeros([4, 2, 4, 32]).astype(np.float16)
    cache = cache + 9
    update = np.ones([1, 2, 1, 32]).astype(np.float16)
    valid_seq_len = np.array([0, 1, 2, 3]).astype(np.int64)
    batch_index = np.array([1]).astype(np.int64)
    seq_len_axis = np.array([2]).astype(np.int64)
    new_max_seq_len = np.array([4]).astype(np.int64)
    cur_max_seq_len = np.array([1]).astype(np.int64)
    return (cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)


def create_ms_inputs():
    """
    create inputs
    """
    cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_numpy_inputs()
    ms_cache = Tensor(cache)
    ms_update = Tensor(update)
    ms_valid_seq_len = Tensor(valid_seq_len)
    ms_batch_index = Tensor(batch_index)
    ms_seq_len_axis = Tensor(seq_len_axis)
    ms_new_max_seq_len = Tensor(new_max_seq_len)
    ms_cur_max_seq_len = Tensor(cur_max_seq_len)
    return (ms_cache, ms_update, ms_valid_seq_len, ms_batch_index, ms_seq_len_axis,
            ms_new_max_seq_len, ms_cur_max_seq_len)


def create_np_inputs(cache, update, batch_index):
    """
    create_np_inputs
    """
    return (cache, update, batch_index)


def create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
    """
    create_lite_inputs
    """
    cache = mslite.Tensor(cache)
    update = mslite.Tensor(update)
    valid_seq_len = mslite.Tensor(valid_seq_len)
    batch_index = mslite.Tensor(batch_index)
    seq_len_axis = mslite.Tensor(seq_len_axis)
    new_max_seq_len = mslite.Tensor(new_max_seq_len)
    cur_max_seq_len = mslite.Tensor(cur_max_seq_len)
    return (cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)


def inference_prompt_k_v_cache(mindir_model):
    """
    inference model
    """
    cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_numpy_inputs()

    lite_ctx1 = mslite.Context()
    lite_ctx1.target = ["ascend"]
    lite_ctx1.ascend.device_id = 1
    lite_ctx1.ascend.provider = "ge"

    input_lists = list(create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis,
                                          new_max_seq_len, cur_max_seq_len))
    model = mslite.Model()
    model.build_from_file(mindir_model, mslite.ModelType.MINDIR, lite_ctx1, "", {})

    mslite_output = model.predict(input_lists)

    np_cache, np_update, batch_index = create_np_inputs(cache, update, batch_index)
    expect_output = np_inference(np_cache, np_update, batch_index)
    assert np.allclose(mslite_output[0].get_data_to_numpy(), expect_output, 0.001, 0.001)


def export_prompt_k_v_cache_model():
    """
    export model
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_ms_inputs()

    net = PromptKVCacheNet("right")
    file_name = "prompt_k_v_cache_primitive"

    export(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len,
           file_name=file_name, file_format='MINDIR')
    model_name = file_name + ".mindir"
    assert os.path.exists(model_name)
    return model_name


if __name__ == '__main__':
    model_path = export_prompt_k_v_cache_model()
    print("prompt_k_v_cache st : export success path: ", model_path)

    inference_prompt_k_v_cache(model_path)
    print("prompt_k_v_cache st : inference success.")
@@ -235,3 +235,7 @@ coordinate_transformation_mode:
  ALIGN_CORNERS: 0
  HALF_PIXEL: 1

# enum KVCacheAlignMode
k_v_cache_align_mode:
  RIGHT: 0
  LEFT: 1
@@ -0,0 +1,58 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import mindspore as ms
from mindspore.nn import Cell
from mindspore import context, Tensor, Parameter
from mindspore.ops.auto_generate import DecoderKVCache
from parallel.utils.utils import ParallelValidator, compile_net


def setup_function():
    context.set_auto_parallel_context(dataset_strategy="full_batch")


class DecoderKVCacheNet(Cell):
    def __init__(self, strategy):
        super(DecoderKVCacheNet, self).__init__()
        self.decoder_k_v_cache = DecoderKVCache().shard(strategy)

    def construct(self, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
        return self.decoder_k_v_cache(cache, update, valid_seq_len, batch_index, seq_len_axis,
                                      new_max_seq_len, cur_max_seq_len)


def test_decoder_k_v_cache_net():
    """
    Feature: test decoder_k_v_cache auto parallel
    Description: auto parallel
    Expectation: shape is as expected.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy = ((4, 1, 1, 1), (4, 1, 1, 1), (4,), (1,), (1,), (1,), (1,))
    net = DecoderKVCacheNet(strategy)
    cache = Parameter(Tensor(np.ones([4, 40, 1024, 128]), dtype=ms.float16), "cache")
    update = Parameter(Tensor(np.ones([4, 40, 1, 128]), dtype=ms.float16), "update")
    valid_seq_len = Parameter(Tensor(np.ones([4]), dtype=ms.int64), "valid_seq_len")
    batch_index = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "batch_index")
    seq_len_axis = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "seq_len_axis")
    new_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "new_max_seq_len")
    cur_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "cur_max_seq_len")
    net.set_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)

    phase = compile_net(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('cache', [1, 40, 1024, 128])
    assert validator.check_parameter_shape('update', [1, 40, 1, 128])
    assert validator.check_parameter_shape('valid_seq_len', [1])
@@ -0,0 +1,59 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import mindspore as ms
from mindspore.nn import Cell
from mindspore import context, Tensor, Parameter
from mindspore.ops.auto_generate import PromptKVCache
from parallel.utils.utils import ParallelValidator, compile_net


def setup_function():
    context.set_auto_parallel_context(dataset_strategy="full_batch")


class PromptKVCacheNet(Cell):
    def __init__(self, padding_mode, strategy):
        super(PromptKVCacheNet, self).__init__()
        self.prompt_k_v_cache = PromptKVCache(padding_mode).shard(strategy)

    def construct(self, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
        return self.prompt_k_v_cache(cache, update, valid_seq_len, batch_index, seq_len_axis,
                                     new_max_seq_len, cur_max_seq_len)


def test_prompt_k_v_cache_net():
    """
    Feature: test prompt_k_v_cache auto parallel
    Description: auto parallel
    Expectation: shape is as expected.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy = ((4, 1, 1, 1), (4, 1, 1, 1), (4,), (1,), (1,), (1,), (1,))
    padding_mode = "right"
    net = PromptKVCacheNet(padding_mode, strategy)
    cache = Parameter(Tensor(np.ones([4, 40, 1024, 128]), dtype=ms.float16), "cache")
    update = Parameter(Tensor(np.ones([4, 40, 1, 128]), dtype=ms.float16), "update")
    valid_seq_len = Parameter(Tensor(np.ones([4]), dtype=ms.int64), "valid_seq_len")
    batch_index = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "batch_index")
    seq_len_axis = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "seq_len_axis")
    new_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "new_max_seq_len")
    cur_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "cur_max_seq_len")
    net.set_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)

    phase = compile_net(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('cache', [1, 40, 1024, 128])
    assert validator.check_parameter_shape('update', [1, 40, 1, 128])
    assert validator.check_parameter_shape('valid_seq_len', [1])