add PromptKVCache and DecoderKVCache ascend op
inplace update for ge fix bug of prompt kv cache fix buf of kernel sync problem prompt kvcache add multi batch index prompt kvcache and decoder kvcache support multi data type decoder kvcache jump -1 seq_len prompt kvcache jump -1 batch index fix bug of decoder_kv_cache not support dig batch size prompt_kv_cache support valid_seq_len prompt_kv_cache and decoder_kv_cache support BSD format make prompt_kv_cache and decoder_kv_cache support BSD format
This commit is contained in:
parent
2c6ff963f0
commit
58a648eaba
|
@ -98,6 +98,10 @@
|
|||
"mindspore/mindspore/lite/src/extendrt/cxx_api/model/model_impl.cc" "whitespace/parens"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/gemm_mask_avx512/" "runtime/int"
|
||||
"mindspore/mindspore/lite/python/src/lite_infer_pybind.cc" "runtime/references"
|
||||
"mindspore/mindspore/lite/tools/kernel_builder/ascend/ascendc/op_host/decoder_kv_cache.cpp" "runtime/references"
|
||||
"mindspore/mindspore/lite/tools/kernel_builder/ascend/ascendc/op_host/platform_ascendc.h" "runtime/references"
|
||||
"mindspore/mindspore/lite/tools/kernel_builder/ascend/ascendc/op_host/prompt_kv_cache.cpp" "runtime/references"
|
||||
|
||||
# ascend samples
|
||||
"mindspore/mindspore/lite/tools/kernel_builder/ascend/aicpu/sample/" "build/include_subdir"
|
||||
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "build/include_subdir"
|
||||
|
|
|
@ -20,28 +20,107 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
// DecoderKVCache seven inputs
|
||||
// cache: (batch_size, num_head, max_seq_len, hidden_size)
|
||||
// update: (batch_size, num_head, update_seq_len, hidden_size)
|
||||
// ===== Case 1:
|
||||
// cache and update shape rank is 4.
|
||||
// cache: (batch, num_head, max_seq_len, hidden_size)
|
||||
// update: (batch, num_head, update_seq_len, hidden_size)
|
||||
// valid_seq_len: (batch)
|
||||
// batch_index: (1)
|
||||
// seq_len_axis: (1)
|
||||
// new_max_seq_len: (1)
|
||||
// cur_max_seq_len: (1)
|
||||
// ------------------------------
|
||||
// output: (batch_size, num_head, max_seq_len, hidden_size)
|
||||
// output: (batch, num_head, max_seq_len, hidden_size)
|
||||
|
||||
// split strategy
|
||||
// batch_size is able to split.
|
||||
// batch is able to split.
|
||||
// max_seq_len, update_seq_len is not able to split.
|
||||
// num_head is able to split.
|
||||
// hidden_size is able to split.
|
||||
|
||||
// ===== Case 2:
|
||||
// cache and update shape rank is 3.
|
||||
// cache: (batch, max_seq_len, hidden_size)
|
||||
// update: (batch, update_seq_len, hidden_size)
|
||||
// valid_seq_len: (batch)
|
||||
// batch_index: (1)
|
||||
// seq_len_axis: (1)
|
||||
// new_max_seq_len: (1)
|
||||
// cur_max_seq_len: (1)
|
||||
// ------------------------------
|
||||
// output: (batch, max_seq_len, hidden_size)
|
||||
|
||||
// split strategy
|
||||
// batch is able to split.
|
||||
// max_seq_len, update_seq_len is not able to split.
|
||||
// hidden_size is able to split.
|
||||
|
||||
Status DecoderKVCacheInfo::CheckStrategy3Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update) {
|
||||
if (strategy_cache.at(1) != 1 || strategy_update.at(1) != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len can't be shard, but got"
|
||||
<< " cache's seq_len strategy: " << strategy_cache.at(1)
|
||||
<< "; update's seq_len strategy: " << strategy_update.at(1);
|
||||
return FAILED;
|
||||
}
|
||||
if (strategy_cache.at(2) != strategy_update.at(2)) {
|
||||
MS_LOG(ERROR) << name_ << " Invalid strategy: The hidden_size must be shard at the same time, but got"
|
||||
<< " strategy_cache's strategy: " << strategy_cache
|
||||
<< ", strategy_update's strategy: " << strategy_update;
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status DecoderKVCacheInfo::CheckStrategy4Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update) {
|
||||
// num_head must be the same strategy.
|
||||
if (strategy_cache.at(1) != strategy_cache.at(1)) {
|
||||
MS_LOG(ERROR) << name_ << " Invalid strategy: The num_head must be shard at the same time, but got"
|
||||
<< " strategy_cache's strategy: " << strategy_cache
|
||||
<< ", strategy_update's strategy: " << strategy_update;
|
||||
return FAILED;
|
||||
}
|
||||
if (strategy_cache.at(2) != 1 || strategy_update.at(2) != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len can't be shard, but got"
|
||||
<< " cache's seq_len strategy: " << strategy_cache.at(2)
|
||||
<< "; update's seq_len strategy: " << strategy_update.at(2);
|
||||
return FAILED;
|
||||
}
|
||||
// hidden_size must be the same strategy.
|
||||
if (strategy_cache.at(3) != strategy_update.at(3)) {
|
||||
MS_LOG(ERROR) << name_ << " Invalid strategy: The hidden_size must be shard at the same time, but got"
|
||||
<< " strategy_cache's strategy: " << strategy_cache
|
||||
<< ", strategy_update's strategy: " << strategy_update;
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status DecoderKVCacheInfo::SetDims(const StrategyPtr &strategy) {
|
||||
auto input_strategys = strategy->GetInputDim();
|
||||
auto strategy_cache = input_strategys.at(0);
|
||||
|
||||
const size_t input_dims4 = 4;
|
||||
const size_t input_dims3 = 3;
|
||||
if (strategy_cache.size() == input_dims4) {
|
||||
is_input_dims_4_ = true;
|
||||
} else if (strategy_cache.size() == input_dims3) {
|
||||
is_input_dims_4_ = false;
|
||||
} else {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status DecoderKVCacheInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
auto input_strategys = strategy->GetInputDim();
|
||||
if (SetDims(strategy) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
auto input_strategys = strategy->GetInputDim();
|
||||
auto strategy_cache = input_strategys.at(0); // (4, 4, 1, 4)
|
||||
auto strategy_update = input_strategys.at(1); // (4, 4, 1, 4)
|
||||
auto strategy_valid_seq_len = input_strategys.at(2); // (4)
|
||||
|
@ -74,69 +153,77 @@ Status DecoderKVCacheInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
if (strategy_cache.at(2) != 1 || strategy_update.at(2) != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len can't be shard, but got"
|
||||
<< " cache's seq_len strategy: " << strategy_cache.at(2)
|
||||
<< "; update's seq_len strategy: " << strategy_update.at(2);
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// batch_size must be the same strategy.
|
||||
// batch must be the same strategy.
|
||||
if (strategy_cache.at(0) != strategy_update.at(0) || strategy_cache.at(0) != strategy_valid_seq_len.at(0)) {
|
||||
MS_LOG(ERROR) << name_ << " Invalid strategy: The batch_size must be shard at the same time, but got"
|
||||
MS_LOG(ERROR) << name_ << " Invalid strategy: The batch must be shard at the same time, but got"
|
||||
<< " strategy_cache's strategy: " << strategy_cache
|
||||
<< ", strategy_update's strategy: " << strategy_update
|
||||
<< ", strategy_valid_seq_len's strategy: " << strategy_valid_seq_len;
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// num_head must be the same strategy.
|
||||
if (strategy_cache.at(1) != strategy_cache.at(1)) {
|
||||
MS_LOG(ERROR) << name_ << " Invalid strategy: The num_head must be shard at the same time, but got"
|
||||
<< " strategy_cache's strategy: " << strategy_cache
|
||||
<< ", strategy_update's strategy: " << strategy_update;
|
||||
return FAILED;
|
||||
if (is_input_dims_4_) {
|
||||
return CheckStrategy4Dims(strategy_cache, strategy_update);
|
||||
}
|
||||
|
||||
// hidden_size must be the same strategy.
|
||||
if (strategy_cache.at(3) != strategy_update.at(3)) {
|
||||
MS_LOG(ERROR) << name_ << " Invalid strategy: The hidden_size must be shard at the same time, but got"
|
||||
<< " strategy_cache's strategy: " << strategy_cache
|
||||
<< ", strategy_update's strategy: " << strategy_update;
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
return CheckStrategy3Dims(strategy_cache, strategy_update);
|
||||
}
|
||||
|
||||
Status DecoderKVCacheInfo::InferDevMatrixShape() {
|
||||
auto input_strategys = strategy()->GetInputDim();
|
||||
auto cache = input_strategys.at(0); // batch_size num_head max_seq_len hidden_size
|
||||
auto update = input_strategys.at(1); // batch_size num_head update_seq_len hidden_size
|
||||
auto cache = input_strategys.at(0); // batch num_head max_seq_len hidden_size
|
||||
auto update = input_strategys.at(1); // batch num_head update_seq_len hidden_size
|
||||
|
||||
if (is_input_dims_4_) {
|
||||
// cache shape (batch num_head max_seq_len hidden_size)
|
||||
// update shape (batch num_head update_seq_len hidden_size)
|
||||
// update_seq_len batch num_head max_seq_len hidden_size
|
||||
// 4 3 2 1 0
|
||||
dev_matrix_shape_ = {update.at(2), cache.at(0), cache.at(1), cache.at(2), cache.at(3)};
|
||||
} else {
|
||||
// cache shape (batch max_seq_len hidden_size)
|
||||
// update shape (batch update_seq_len hidden_size)
|
||||
// update_seq_len batch max_seq_len hidden_size
|
||||
// 3 2 1 0
|
||||
dev_matrix_shape_ = {update.at(2), cache.at(0), cache.at(1), cache.at(2)};
|
||||
}
|
||||
|
||||
// update_seq_len batch_size num_head max_seq_len hidden_size
|
||||
// 4 3 2 1 0
|
||||
dev_matrix_shape_ = {update.at(2), cache.at(0), cache.at(1), cache.at(2), cache.at(3)};
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status DecoderKVCacheInfo::InferTensorMap() {
|
||||
Shape cache_tensor_map{3, 2, 1, 0};
|
||||
Shape update_tensor_map{3, 2, 4, 0};
|
||||
Shape valid_seq_len_tensor_map{3};
|
||||
if (is_input_dims_4_) {
|
||||
Shape cache_tensor_map{3, 2, 1, 0};
|
||||
Shape update_tensor_map{3, 2, 4, 0};
|
||||
Shape valid_seq_len_tensor_map{3};
|
||||
inputs_tensor_map_.emplace_back(cache_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(update_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(valid_seq_len_tensor_map);
|
||||
} else {
|
||||
Shape cache_tensor_map{2, 1, 0};
|
||||
Shape update_tensor_map{2, 3, 0};
|
||||
Shape valid_seq_len_tensor_map{2};
|
||||
inputs_tensor_map_.emplace_back(cache_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(update_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(valid_seq_len_tensor_map);
|
||||
}
|
||||
|
||||
Shape batch_index_tensor_map{-1};
|
||||
Shape seq_lqn_axis_tensor_map{-1};
|
||||
Shape new_max_seq_len_tensor_map{-1};
|
||||
Shape cur_max_seq_len_tensor_map{-1};
|
||||
inputs_tensor_map_.emplace_back(cache_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(update_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(valid_seq_len_tensor_map);
|
||||
|
||||
inputs_tensor_map_.emplace_back(batch_index_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(seq_lqn_axis_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(new_max_seq_len_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(cur_max_seq_len_tensor_map);
|
||||
|
||||
Shape out_tensor_map{3, 2, 1, 0};
|
||||
outputs_tensor_map_.emplace_back(out_tensor_map);
|
||||
if (is_input_dims_4_) {
|
||||
Shape out_tensor_map{3, 2, 1, 0};
|
||||
outputs_tensor_map_.emplace_back(out_tensor_map);
|
||||
} else {
|
||||
Shape out_tensor_map{2, 1, 0};
|
||||
outputs_tensor_map_.emplace_back(out_tensor_map);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
REGISTER(DecoderKVCacheInfo);
|
||||
|
|
|
@ -44,6 +44,12 @@ class DecoderKVCacheInfo : public OperatorInfo {
|
|||
Status InferForwardCommunication() { return SUCCESS; }
|
||||
Status InferTensorMap() override;
|
||||
Status InferDevMatrixShape() override;
|
||||
Status SetDims(const StrategyPtr &strategy);
|
||||
Status CheckStrategy3Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update);
|
||||
Status CheckStrategy4Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update);
|
||||
|
||||
private:
|
||||
bool is_input_dims_4_ = true;
|
||||
};
|
||||
using DecoderKVCacheInfoPtr = std::shared_ptr<DecoderKVCacheInfo>;
|
||||
} // namespace parallel
|
||||
|
|
|
@ -20,29 +20,107 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
// PromptKVCache seven inputs
|
||||
// cache: (batch_size, num_head, max_seq_len, hidden_size)
|
||||
// update: (batch_size, num_head, update_seq_len, hidden_size)
|
||||
// ===== Case 1:
|
||||
// cache and update shape rank is 4.
|
||||
// cache: (batch, num_head, max_seq_len, hidden_size)
|
||||
// update: (batch, num_head, update_seq_len, hidden_size)
|
||||
// valid_seq_len: (batch)
|
||||
// batch_index: (1)
|
||||
// seq_len_axis: (1)
|
||||
// new_max_seq_len: (1)
|
||||
// cur_max_seq_len: (1)
|
||||
// ------------------------------
|
||||
// output: (batch_size, num_head, max_seq_len, hidden_size)
|
||||
// output: (batch, num_head, max_seq_len, hidden_size)
|
||||
|
||||
// split strategy
|
||||
// batch_size is able to split.
|
||||
// batch is able to split.
|
||||
// max_seq_len, update_seq_len is not able to split.
|
||||
// num_head is able to split.
|
||||
// hidden_size is able to split.
|
||||
|
||||
// ====== Case 2:
|
||||
// cache and update shape rank is 3.
|
||||
// cache: (batch, max_seq_len, hidden_size)
|
||||
// update: (batch, update_seq_len, hidden_size)
|
||||
// valid_seq_len: (batch)
|
||||
// batch_index: (1)
|
||||
// seq_len_axis: (1)
|
||||
// new_max_seq_len: (1)
|
||||
// cur_max_seq_len: (1)
|
||||
// ------------------------------
|
||||
// output: (batch, max_seq_len, hidden_size)
|
||||
|
||||
// split strategy
|
||||
// batch is able to split.
|
||||
// max_seq_len, update_seq_len is not able to split.
|
||||
// hidden_size is able to split.
|
||||
|
||||
Status PromptKVCacheInfo::CheckStrategy3Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update) {
|
||||
if (strategy_cache.at(1) != 1 || strategy_update.at(1) != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len can't be shard, but got"
|
||||
<< " cache's seq_len strategy: " << strategy_cache.at(1)
|
||||
<< "; update's seq_len strategy: " << strategy_update.at(1);
|
||||
return FAILED;
|
||||
}
|
||||
if (strategy_cache.at(2) != strategy_update.at(2)) {
|
||||
MS_LOG(ERROR) << name_ << " Invalid strategy: The hidden_size must be shard at the same time, but got"
|
||||
<< " strategy_cache's strategy: " << strategy_cache
|
||||
<< ", strategy_update's strategy: " << strategy_update;
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status PromptKVCacheInfo::CheckStrategy4Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update) {
|
||||
// num_head must be the same strategy.
|
||||
if (strategy_cache.at(1) != strategy_cache.at(1)) {
|
||||
MS_LOG(ERROR) << name_ << " Invalid strategy: The num_head must be shard at the same time, but got"
|
||||
<< " strategy_cache's strategy: " << strategy_cache
|
||||
<< ", strategy_update's strategy: " << strategy_update;
|
||||
return FAILED;
|
||||
}
|
||||
if (strategy_cache.at(2) != 1 || strategy_update.at(2) != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len can't be shard, but got"
|
||||
<< " cache's seq_len strategy: " << strategy_cache.at(2)
|
||||
<< "; update's seq_len strategy: " << strategy_update.at(2);
|
||||
return FAILED;
|
||||
}
|
||||
// hidden_size must be the same strategy.
|
||||
if (strategy_cache.at(3) != strategy_update.at(3)) {
|
||||
MS_LOG(ERROR) << name_ << " Invalid strategy: The hidden_size must be shard at the same time, but got"
|
||||
<< " strategy_cache's strategy: " << strategy_cache
|
||||
<< ", strategy_update's strategy: " << strategy_update;
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status PromptKVCacheInfo::SetDims(const StrategyPtr &strategy) {
|
||||
auto input_strategys = strategy->GetInputDim();
|
||||
auto strategy_cache = input_strategys.at(0);
|
||||
|
||||
const size_t input_dims4 = 4;
|
||||
const size_t input_dims3 = 3;
|
||||
if (strategy_cache.size() == input_dims4) {
|
||||
is_input_dims_4_ = true;
|
||||
} else if (strategy_cache.size() == input_dims3) {
|
||||
is_input_dims_4_ = false;
|
||||
} else {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status PromptKVCacheInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
if (SetDims(strategy) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
auto input_strategys = strategy->GetInputDim();
|
||||
|
||||
auto strategy_cache = input_strategys.at(0); // (4, 4, 1, 4)
|
||||
auto strategy_update = input_strategys.at(1); // (4, 4, 1, 4)
|
||||
auto strategy_valid_seq_len = input_strategys.at(2); // (4)
|
||||
|
@ -75,69 +153,78 @@ Status PromptKVCacheInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
if (strategy_cache.at(2) != 1 || strategy_update.at(2) != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len can't be shard, but got"
|
||||
<< " cache's seq_len strategy: " << strategy_cache.at(2)
|
||||
<< "; update's seq_len strategy: " << strategy_update.at(2);
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// batch_size must be the same strategy.
|
||||
// batch must be the same strategy.
|
||||
if (strategy_cache.at(0) != strategy_update.at(0) || strategy_cache.at(0) != strategy_valid_seq_len.at(0)) {
|
||||
MS_LOG(ERROR) << name_ << " Invalid strategy: The batch_size must be shard at the same time, but got"
|
||||
MS_LOG(ERROR) << name_ << " Invalid strategy: The batch must be shard at the same time, but got"
|
||||
<< " strategy_cache's strategy: " << strategy_cache
|
||||
<< ", strategy_update's strategy: " << strategy_update
|
||||
<< ", strategy_valid_seq_len's strategy: " << strategy_valid_seq_len;
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// num_head must be the same strategy.
|
||||
if (strategy_cache.at(1) != strategy_cache.at(1)) {
|
||||
MS_LOG(ERROR) << name_ << " Invalid strategy: The num_head must be shard at the same time, but got"
|
||||
<< " strategy_cache's strategy: " << strategy_cache
|
||||
<< ", strategy_update's strategy: " << strategy_update;
|
||||
return FAILED;
|
||||
if (is_input_dims_4_) {
|
||||
return CheckStrategy4Dims(strategy_cache, strategy_update);
|
||||
}
|
||||
|
||||
// hidden_size must be the same strategy.
|
||||
if (strategy_cache.at(3) != strategy_update.at(3)) {
|
||||
MS_LOG(ERROR) << name_ << " Invalid strategy: The hidden_size must be shard at the same time, but got"
|
||||
<< " strategy_cache's strategy: " << strategy_cache
|
||||
<< ", strategy_update's strategy: " << strategy_update;
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
return CheckStrategy3Dims(strategy_cache, strategy_update);
|
||||
}
|
||||
|
||||
Status PromptKVCacheInfo::InferDevMatrixShape() {
|
||||
auto input_strategys = strategy()->GetInputDim();
|
||||
auto cache = input_strategys.at(0); // batch_size num_head max_seq_len hidden_size
|
||||
auto update = input_strategys.at(1); // batch_size num_head update_seq_len hidden_size
|
||||
|
||||
// update_seq_len batch_size num_head max_seq_len hidden_size
|
||||
// 4 3 2 1 0
|
||||
dev_matrix_shape_ = {update.at(2), cache.at(0), cache.at(1), cache.at(2), cache.at(3)};
|
||||
auto cache = input_strategys.at(0);
|
||||
auto update = input_strategys.at(1);
|
||||
|
||||
if (is_input_dims_4_) {
|
||||
// cache shape (batch num_head max_seq_len hidden_size)
|
||||
// update shape (batch num_head update_seq_len hidden_size)
|
||||
// update_seq_len batch num_head max_seq_len hidden_size
|
||||
// 4 3 2 1 0
|
||||
dev_matrix_shape_ = {update.at(2), cache.at(0), cache.at(1), cache.at(2), cache.at(3)};
|
||||
} else {
|
||||
// cache shape (batch max_seq_len hidden_size)
|
||||
// update shape (batch update_seq_len hidden_size)
|
||||
// update_seq_len batch max_seq_len hidden_size
|
||||
// 3 2 1 0
|
||||
dev_matrix_shape_ = {update.at(2), cache.at(0), cache.at(1), cache.at(2)};
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status PromptKVCacheInfo::InferTensorMap() {
|
||||
Shape cache_tensor_map{3, 2, 1, 0};
|
||||
Shape update_tensor_map{3, 2, 4, 0};
|
||||
Shape valid_seq_len_tensor_map{3};
|
||||
if (is_input_dims_4_) {
|
||||
Shape cache_tensor_map{3, 2, 1, 0};
|
||||
Shape update_tensor_map{3, 2, 4, 0};
|
||||
Shape valid_seq_len_tensor_map{3};
|
||||
inputs_tensor_map_.emplace_back(cache_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(update_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(valid_seq_len_tensor_map);
|
||||
} else {
|
||||
Shape cache_tensor_map{2, 1, 0};
|
||||
Shape update_tensor_map{2, 3, 0};
|
||||
Shape valid_seq_len_tensor_map{2};
|
||||
inputs_tensor_map_.emplace_back(cache_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(update_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(valid_seq_len_tensor_map);
|
||||
}
|
||||
|
||||
Shape batch_index_tensor_map{-1};
|
||||
Shape seq_lqn_axis_tensor_map{-1};
|
||||
Shape new_max_seq_len_tensor_map{-1};
|
||||
Shape cur_max_seq_len_tensor_map{-1};
|
||||
inputs_tensor_map_.emplace_back(cache_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(update_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(valid_seq_len_tensor_map);
|
||||
|
||||
inputs_tensor_map_.emplace_back(batch_index_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(seq_lqn_axis_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(new_max_seq_len_tensor_map);
|
||||
inputs_tensor_map_.emplace_back(cur_max_seq_len_tensor_map);
|
||||
|
||||
Shape out_tensor_map{3, 2, 1, 0};
|
||||
outputs_tensor_map_.emplace_back(out_tensor_map);
|
||||
if (is_input_dims_4_) {
|
||||
Shape out_tensor_map{3, 2, 1, 0};
|
||||
outputs_tensor_map_.emplace_back(out_tensor_map);
|
||||
} else {
|
||||
Shape out_tensor_map{2, 1, 0};
|
||||
outputs_tensor_map_.emplace_back(out_tensor_map);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
REGISTER(PromptKVCacheInfo);
|
||||
|
|
|
@ -45,6 +45,12 @@ class PromptKVCacheInfo : public OperatorInfo {
|
|||
Status InferForwardCommunication() { return SUCCESS; }
|
||||
Status InferTensorMap() override;
|
||||
Status InferDevMatrixShape() override;
|
||||
Status SetDims(const StrategyPtr &strategy);
|
||||
Status CheckStrategy3Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update);
|
||||
Status CheckStrategy4Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update);
|
||||
|
||||
private:
|
||||
bool is_input_dims_4_ = true;
|
||||
};
|
||||
using PromptKVCacheInfoPtr = std::shared_ptr<PromptKVCacheInfo>;
|
||||
} // namespace parallel
|
||||
|
|
|
@ -613,7 +613,8 @@ bool DfGraphConvertor::NodeInputKeepUpdate(const FuncGraphManagerPtr &manager, c
|
|||
return true;
|
||||
}
|
||||
const auto &node_users = manager->node_users();
|
||||
std::vector<PrimitivePtr> vec{prim::kPrimAssign, prim::kPrimKVCacheMgr, prim::kPrimScatterUpdate};
|
||||
std::vector<PrimitivePtr> vec{prim::kPrimAssign, prim::kPrimKVCacheMgr, prim::kPrimScatterUpdate,
|
||||
prim::kPrimPromptKVCache, prim::kPrimDecoderKVCache};
|
||||
auto user_it = node_users.find(node);
|
||||
if (user_it != node_users.end()) {
|
||||
auto &users = user_it->second;
|
||||
|
|
|
@ -22,36 +22,51 @@ import mindspore.nn as nn
|
|||
import mindspore.ops as ops
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.train.serialization import export
|
||||
from mindspore.ops.operations._inner_ops import DecoderKVCache
|
||||
from mindspore.ops.auto_generate.gen_inner_ops_def import DecoderKVCache
|
||||
|
||||
|
||||
b = 26
|
||||
h = 40
|
||||
s = 32
|
||||
d = 128
|
||||
us = 1
|
||||
ps = s - us
|
||||
|
||||
is_4d = True
|
||||
|
||||
|
||||
class DecoderKVCacheNet(nn.Cell):
|
||||
"""
|
||||
DecoderKVCacheNet.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = ops.Add()
|
||||
self.sub = ops.Sub()
|
||||
self.decoder_k_v_cache = DecoderKVCache()
|
||||
self.seq_len_axis = [2, 0, 0, 0]
|
||||
|
||||
def construct(self, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
|
||||
out = self.decoder_k_v_cache(cache, update, valid_seq_len, batch_index, self.seq_len_axis,
|
||||
out = self.decoder_k_v_cache(cache, update, valid_seq_len, batch_index, seq_len_axis,
|
||||
new_max_seq_len, cur_max_seq_len)
|
||||
add_out = self.add(cache, 1)
|
||||
sub_out = self.sub(add_out, 1)
|
||||
return sub_out
|
||||
|
||||
|
||||
def np_inference(cache, update, valid_seq_len):
|
||||
def np_inference(cache, update, valid_seq_len):
|
||||
"""
|
||||
np_inference
|
||||
"""
|
||||
ans = cache.copy()
|
||||
for b in range(cache.shape[0]):
|
||||
bs_idx = valid_seq_len[b]
|
||||
ans[b, :, bs_idx, :] = update[b, :, 0, :]
|
||||
for b_idx in range(cache.shape[0]):
|
||||
s_idx = valid_seq_len[b_idx]
|
||||
if s_idx < 0:
|
||||
continue
|
||||
if is_4d:
|
||||
ans[b_idx, :, s_idx, :] = update[b_idx, :, 0, :]
|
||||
else:
|
||||
ans[b_idx, s_idx, :] = update[b_idx, 0, :]
|
||||
return ans
|
||||
|
||||
|
||||
|
@ -59,13 +74,20 @@ def create_numpy_inputs():
|
|||
"""
|
||||
create inputs
|
||||
"""
|
||||
cache = np.random.rand(4, 2, 4, 32).astype(np.float16)
|
||||
update = np.random.rand(4, 2, 1, 32).astype(np.float16)
|
||||
valid_seq_len = np.array([0, 1, 2, 3]).astype(np.int64)
|
||||
global is_4d
|
||||
if is_4d:
|
||||
cache_shape = (b, h, s, d)
|
||||
update_shape = (b, h, us, d)
|
||||
else:
|
||||
cache_shape = (b, s, h * d)
|
||||
update_shape = (b, us, h * d)
|
||||
cache = np.random.rand(*cache_shape).astype(np.float16)
|
||||
update = np.random.rand(*update_shape).astype(np.float16)
|
||||
valid_seq_len = np.random.randint(-1, s, size=b).astype(np.int64)
|
||||
batch_index = np.array([1]).astype(np.int64)
|
||||
seq_len_axis = np.array([2]).astype(np.int64)
|
||||
new_max_seq_len = np.array([4]).astype(np.int64)
|
||||
cur_max_seq_len = np.array([1]).astype(np.int64)
|
||||
new_max_seq_len = np.array([s]).astype(np.int64)
|
||||
cur_max_seq_len = np.array([s]).astype(np.int64)
|
||||
return (cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
|
||||
|
||||
|
||||
|
@ -85,13 +107,6 @@ def create_ms_inputs():
|
|||
ms_new_max_seq_len, ms_cur_max_seq_len)
|
||||
|
||||
|
||||
def create_np_inputs(cache, update, valid_seq_len):
|
||||
"""
|
||||
create_np_inputs
|
||||
"""
|
||||
return (cache, update, valid_seq_len)
|
||||
|
||||
|
||||
def create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
|
||||
"""
|
||||
create_lite_inputs
|
||||
|
@ -110,23 +125,26 @@ def inference_decoder_k_v_cache(mindir_model):
|
|||
"""
|
||||
inference model
|
||||
"""
|
||||
cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_numpy_inputs()
|
||||
|
||||
lite_ctx1 = mslite.Context()
|
||||
lite_ctx1.target = ["ascend"]
|
||||
lite_ctx1.ascend.device_id = 1
|
||||
lite_ctx1.ascend.device_id = 0
|
||||
lite_ctx1.ascend.provider = "ge"
|
||||
|
||||
model = mslite.Model()
|
||||
model.build_from_file(mindir_model, mslite.ModelType.MINDIR, lite_ctx1, "", {})
|
||||
model.build_from_file(
|
||||
mindir_model, mslite.ModelType.MINDIR, lite_ctx1, "", {})
|
||||
|
||||
input_lists = list(create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis,
|
||||
new_max_seq_len, cur_max_seq_len))
|
||||
mslite_output = model.predict(input_lists)
|
||||
|
||||
np_cache, np_update, np_valid_seq_len = create_np_inputs(cache, update, valid_seq_len)
|
||||
expect_output = np_inference(np_cache, np_update, np_valid_seq_len)
|
||||
assert np.allclose(mslite_output[0].get_data_to_numpy(), expect_output, 0.001, 0.001)
|
||||
for i in range(50):
|
||||
cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, \
|
||||
cur_max_seq_len = create_numpy_inputs()
|
||||
input_lists = list(create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis,
|
||||
new_max_seq_len, cur_max_seq_len))
|
||||
mslite_output = model.predict(input_lists)
|
||||
expect_output = np_inference(cache, update, valid_seq_len)
|
||||
assert np.allclose(
|
||||
mslite_output[0].get_data_to_numpy(), expect_output, 0.001, 0.001)
|
||||
print(f"decoder_k_v_cache st {i} times: inference success.")
|
||||
|
||||
|
||||
def export_decoder_k_v_cache_model():
|
||||
|
@ -137,7 +155,10 @@ def export_decoder_k_v_cache_model():
|
|||
cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_ms_inputs()
|
||||
|
||||
net = DecoderKVCacheNet()
|
||||
file_name = "decoder_k_v_cache_primitive"
|
||||
if is_4d:
|
||||
file_name = "decoder_k_v_cache_primitive_4d"
|
||||
else:
|
||||
file_name = "decoder_k_v_cache_primitive_3d"
|
||||
|
||||
export(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len,
|
||||
file_name=file_name, file_format='MINDIR')
|
||||
|
@ -146,9 +167,25 @@ def export_decoder_k_v_cache_model():
|
|||
return model_name
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def test_4d():
|
||||
global is_4d
|
||||
is_4d = True
|
||||
model_path = export_decoder_k_v_cache_model()
|
||||
print("decoder_k_v_cache st : export success path: ", model_path)
|
||||
print("decoder_k_v_cache 4d st : export success path: ", model_path)
|
||||
|
||||
inference_decoder_k_v_cache(model_path)
|
||||
print("decoder_k_v_cache st : inference success.")
|
||||
print(f"decoder_k_v_cache 4d st : inference end.")
|
||||
|
||||
def test_3d():
|
||||
global is_4d
|
||||
is_4d = False
|
||||
model_path = export_decoder_k_v_cache_model()
|
||||
print("decoder_k_v_cache 3d st : export success path: ", model_path)
|
||||
|
||||
inference_decoder_k_v_cache(model_path)
|
||||
print(f"decoder_k_v_cache 3d st : inference end.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_3d()
|
||||
test_4d()
|
||||
|
|
|
@ -20,41 +20,65 @@ import numpy as np
|
|||
import mindspore_lite as mslite
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor, context, Parameter
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.train.serialization import export
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops.operations._inner_ops import PromptKVCache
|
||||
from mindspore.ops.auto_generate.gen_inner_ops_def import PromptKVCache
|
||||
|
||||
b = 40
|
||||
h = 4
|
||||
s = 1024
|
||||
d = 32
|
||||
ub = 40
|
||||
us = 512
|
||||
ps = s - us
|
||||
|
||||
is_4d = True
|
||||
|
||||
|
||||
class PromptKVCacheNet(nn.Cell):
|
||||
"""
|
||||
PromptKVCacheNet.
|
||||
"""
|
||||
|
||||
def __init__(self, padding_mode):
|
||||
super().__init__()
|
||||
self.sub = ops.Sub()
|
||||
self.add = ops.Add()
|
||||
self.concat_dim2 = ops.Concat(axis=2)
|
||||
self.prompt_k_v_cache = PromptKVCache(padding_mode)
|
||||
self.pad_update_zero_tensor = Parameter(ops.zeros((1, 2, 3, 32), mstype.float16))
|
||||
|
||||
def construct(self, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
|
||||
update_pad = self.concat_dim2((update, self.pad_update_zero_tensor))
|
||||
out = self.prompt_k_v_cache(cache, update_pad, valid_seq_len, batch_index, seq_len_axis,
|
||||
out = self.prompt_k_v_cache(cache, update, valid_seq_len, batch_index, seq_len_axis,
|
||||
new_max_seq_len, cur_max_seq_len)
|
||||
add_out = self.add(cache, 1)
|
||||
sub_out = self.sub(add_out, 1)
|
||||
return sub_out
|
||||
|
||||
|
||||
def np_inference(cache, update, batch_index):
|
||||
def np_inference(cache, update, valid_seq_len, batch_index):
|
||||
"""
|
||||
np_inference
|
||||
"""
|
||||
zeros_ans = np.zeros(cache.shape, cache.dtype)
|
||||
us = update.shape[2]
|
||||
cache[batch_index, :] = zeros_ans[batch_index, :]
|
||||
cache[batch_index, :, 0:us, :] = update
|
||||
cache_rank = len(cache.shape)
|
||||
if cache_rank == 4:
|
||||
s_ = cache.shape[2]
|
||||
us_ = update.shape[2]
|
||||
elif cache_rank == 3:
|
||||
s_ = cache.shape[1]
|
||||
us_ = update.shape[1]
|
||||
|
||||
for i in range(batch_index.size):
|
||||
b_idx = batch_index[i]
|
||||
s_idx = valid_seq_len[i]
|
||||
if b_idx < 0:
|
||||
continue
|
||||
if s_idx < 0 or s_idx + us_ > s_:
|
||||
continue
|
||||
if cache_rank == 4:
|
||||
cache[b_idx, :, s_idx:s_idx + us_, :] = update[i]
|
||||
elif cache_rank == 3:
|
||||
cache[b_idx, s_idx:s_idx + us_, :] = update[i]
|
||||
|
||||
return cache
|
||||
|
||||
|
||||
|
@ -62,14 +86,19 @@ def create_numpy_inputs():
|
|||
"""
|
||||
create inputs
|
||||
"""
|
||||
cache = np.zeros([4, 2, 4, 32]).astype(np.float16)
|
||||
cache = cache + 9
|
||||
update = np.ones([1, 2, 1, 32]).astype(np.float16)
|
||||
valid_seq_len = np.array([0, 1, 2, 3]).astype(np.int64)
|
||||
batch_index = np.array([1]).astype(np.int64)
|
||||
if is_4d:
|
||||
cache_shape = (b, h, s, d)
|
||||
update_shape = (ub, h, us, d)
|
||||
else:
|
||||
cache_shape = (b, s, h * d)
|
||||
update_shape = (ub, us, h * d)
|
||||
cache = np.random.rand(*cache_shape).astype(np.float16)
|
||||
update = np.random.rand(*update_shape).astype(np.float16)
|
||||
valid_seq_len = np.random.randint(-1, s, size=ub).astype(np.int64)
|
||||
batch_index = np.random.choice(np.arange(-1, b), size=ub, replace=False).astype(np.int64)
|
||||
seq_len_axis = np.array([2]).astype(np.int64)
|
||||
new_max_seq_len = np.array([4]).astype(np.int64)
|
||||
cur_max_seq_len = np.array([1]).astype(np.int64)
|
||||
new_max_seq_len = np.array([s]).astype(np.int64)
|
||||
cur_max_seq_len = np.array([s]).astype(np.int64)
|
||||
return (cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
|
||||
|
||||
|
||||
|
@ -89,13 +118,6 @@ def create_ms_inputs():
|
|||
ms_new_max_seq_len, ms_cur_max_seq_len)
|
||||
|
||||
|
||||
def create_np_inputs(cache, update, batch_index):
|
||||
"""
|
||||
create_np_inputs
|
||||
"""
|
||||
return (cache, update, batch_index)
|
||||
|
||||
|
||||
def create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
|
||||
"""
|
||||
create_lite_inputs
|
||||
|
@ -114,23 +136,26 @@ def inference_prompt_k_v_cache(mindir_model):
|
|||
"""
|
||||
inference model
|
||||
"""
|
||||
cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_numpy_inputs()
|
||||
|
||||
lite_ctx1 = mslite.Context()
|
||||
lite_ctx1.target = ["ascend"]
|
||||
lite_ctx1.ascend.device_id = 1
|
||||
lite_ctx1.ascend.device_id = 0
|
||||
lite_ctx1.ascend.provider = "ge"
|
||||
|
||||
input_lists = list(create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis,
|
||||
new_max_seq_len, cur_max_seq_len))
|
||||
model = mslite.Model()
|
||||
model.build_from_file(mindir_model, mslite.ModelType.MINDIR, lite_ctx1, "", {})
|
||||
model.build_from_file(
|
||||
mindir_model, mslite.ModelType.MINDIR, lite_ctx1, "", {})
|
||||
|
||||
mslite_output = model.predict(input_lists)
|
||||
|
||||
np_cache, np_update, batch_index = create_np_inputs(cache, update, batch_index)
|
||||
expect_output = np_inference(np_cache, np_update, batch_index)
|
||||
assert np.allclose(mslite_output[0].get_data_to_numpy(), expect_output, 0.001, 0.001)
|
||||
for i in range(50):
|
||||
cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, \
|
||||
cur_max_seq_len = create_numpy_inputs()
|
||||
input_lists = list(create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis,
|
||||
new_max_seq_len, cur_max_seq_len))
|
||||
mslite_output = model.predict(input_lists)
|
||||
expect_output = np_inference(cache, update, valid_seq_len, batch_index)
|
||||
assert np.allclose(
|
||||
mslite_output[0].get_data_to_numpy(), expect_output, 0.001, 0.001)
|
||||
print(f"prompt_k_v_cache st {i} times: inference success.")
|
||||
|
||||
|
||||
def export_prompt_k_v_cache_model():
|
||||
|
@ -141,7 +166,10 @@ def export_prompt_k_v_cache_model():
|
|||
cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_ms_inputs()
|
||||
|
||||
net = PromptKVCacheNet("right")
|
||||
file_name = "prompt_k_v_cache_primitive"
|
||||
if is_4d:
|
||||
file_name = "prompt_k_v_cache_primitive_4d"
|
||||
else:
|
||||
file_name = "prompt_k_v_cache_primitive_3d"
|
||||
|
||||
export(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len,
|
||||
file_name=file_name, file_format='MINDIR')
|
||||
|
@ -150,9 +178,24 @@ def export_prompt_k_v_cache_model():
|
|||
return model_name
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def test_4d():
|
||||
global is_4d
|
||||
is_4d = True
|
||||
model_path = export_prompt_k_v_cache_model()
|
||||
print("prompt_k_v_cache st : export success path: ", model_path)
|
||||
print("prompt_k_v_cache 4d st : export success path: ", model_path)
|
||||
|
||||
inference_prompt_k_v_cache(model_path)
|
||||
print("prompt_k_v_cache st : inference success.")
|
||||
print(f"prompt_k_v_cache 4d st : inference end.")
|
||||
|
||||
def test_3d():
|
||||
global is_4d
|
||||
is_4d = False
|
||||
model_path = export_prompt_k_v_cache_model()
|
||||
print("prompt_k_v_cache 3d st : export success path: ", model_path)
|
||||
|
||||
inference_prompt_k_v_cache(model_path)
|
||||
print(f"prompt_k_v_cache 3d st : inference end.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_3d()
|
||||
test_4d()
|
||||
|
|
|
@ -15,6 +15,7 @@ include_directories(${TOP_DIR}/graphengine/910b/inc)
|
|||
include_directories(${TOP_DIR}/graphengine/910b/inc/external)
|
||||
include_directories(${TOP_DIR}/graphengine/910b/third_party/fwkacllib)
|
||||
include_directories(${TOP_DIR}/graphengine/910b/third_party/fwkacllib/inc)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} ops_srcs)
|
||||
opbuild(OPS_SRC ${ops_srcs} OUT_DIR ${ASCEND_AUTOGEN_PATH} INC_DIR ${BUILD_INC_910A_DIR}
|
||||
|
|
|
@ -0,0 +1,235 @@
|
|||
|
||||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "decoder_kv_cache_tiling.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "platform_ascendc.h"
|
||||
|
||||
namespace {
|
||||
constexpr int index0 = 0;
|
||||
constexpr int index1 = 1;
|
||||
constexpr int index2 = 2;
|
||||
constexpr int index3 = 3;
|
||||
constexpr int index4 = 4;
|
||||
constexpr int index5 = 5;
|
||||
constexpr int index6 = 6;
|
||||
|
||||
constexpr int32_t kSize1 = 1;
|
||||
constexpr int32_t kSize2 = 2;
|
||||
constexpr int32_t kSize4 = 4;
|
||||
|
||||
constexpr int64_t kAxisOne = 1;
|
||||
constexpr int64_t kAxisTwo = 2;
|
||||
|
||||
constexpr size_t k910bWS = 16 * 1024 * 1024;
|
||||
} // namespace
|
||||
|
||||
namespace optiling {
|
||||
static ge::graphStatus TilingFunc(gert::TilingContext *context) {
|
||||
auto past_input = context->GetInputDesc(index0);
|
||||
auto dtype = past_input->GetDataType();
|
||||
auto type_size = ge::GetSizeByDataType(dtype);
|
||||
switch (type_size) {
|
||||
case kSize1:
|
||||
context->SetTilingKey(kSize1);
|
||||
break;
|
||||
case kSize2:
|
||||
context->SetTilingKey(kSize2);
|
||||
break;
|
||||
case kSize4:
|
||||
context->SetTilingKey(kSize4);
|
||||
break;
|
||||
default:
|
||||
return ge::GRAPH_PARAM_INVALID;
|
||||
}
|
||||
|
||||
const gert::StorageShape *cur_shape = context->GetInputShape(index1);
|
||||
|
||||
bool is_dim4 = true;
|
||||
const size_t kDim3 = 3;
|
||||
const size_t kDim4 = 4;
|
||||
if (cur_shape->GetStorageShape().GetDimNum() == kDim4) {
|
||||
is_dim4 = true;
|
||||
} else if (cur_shape->GetStorageShape().GetDimNum() == kDim3) {
|
||||
is_dim4 = false;
|
||||
} else {
|
||||
return ge::GRAPH_PARAM_INVALID;
|
||||
}
|
||||
|
||||
int64_t b = cur_shape->GetStorageShape().GetDim(index0);
|
||||
int64_t h = 0;
|
||||
int64_t us = 0;
|
||||
// s need get when run
|
||||
int64_t d = 0;
|
||||
|
||||
if (is_dim4) {
|
||||
// (b, h, us, d) -> (bs, us, d)
|
||||
h = cur_shape->GetStorageShape().GetDim(index1);
|
||||
us = cur_shape->GetStorageShape().GetDim(index2);
|
||||
d = cur_shape->GetStorageShape().GetDim(index3);
|
||||
} else {
|
||||
h = 1;
|
||||
us = cur_shape->GetStorageShape().GetDim(index1);
|
||||
d = cur_shape->GetStorageShape().GetDim(index2);
|
||||
}
|
||||
|
||||
auto platform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
auto aiv_num = platform.GetCoreNumAiv();
|
||||
|
||||
context->SetBlockDim(aiv_num);
|
||||
|
||||
// set workspace for 910B
|
||||
size_t *currentWorkspace = context->GetWorkspaceSizes(1);
|
||||
currentWorkspace[0] = k910bWS;
|
||||
|
||||
TilingData tiling;
|
||||
tiling.set_core_num(aiv_num);
|
||||
tiling.set_b(b);
|
||||
tiling.set_h(h);
|
||||
tiling.set_d(d);
|
||||
tiling.set_us(us);
|
||||
tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
|
||||
context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus CheckSupported(const ge::Operator &op, ge::AscendString &result) {
|
||||
std::string resultStr{};
|
||||
constexpr size_t input_num = 7;
|
||||
if (op.GetInputsSize() != input_num) {
|
||||
resultStr = R"({"ret_code": "1", "reason": "input num is not 7"})";
|
||||
result = ge::AscendString(resultStr.c_str());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
if (op.GetOutputsSize() != 1) {
|
||||
resultStr = R"({"ret_code": "1", "reason": "output num is not 1"})";
|
||||
result = ge::AscendString(resultStr.c_str());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
if (op.GetInputDesc(i).GetFormat() != ge::FORMAT_ND) {
|
||||
resultStr = R"({"ret_code": "1", "reason": "input format is not supported, only support ND."})";
|
||||
result = ge::AscendString(resultStr.c_str());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t input_dim3_num = 3;
|
||||
const int64_t input_dim4_num = 4;
|
||||
if (op.GetInputDesc(index0).GetShape().GetDimNum() != input_dim3_num &&
|
||||
op.GetInputDesc(index0).GetShape().GetDimNum() != input_dim4_num) {
|
||||
resultStr = R"({"ret_code": "1", "reason": "input dim is not supported, cache and update dim must be 4."})";
|
||||
result = ge::AscendString(resultStr.c_str());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
if (op.GetInputDesc(index2).GetShape().GetDimNum() != 1 || op.GetInputDesc(index3).GetShape().GetDimNum() != 1 ||
|
||||
op.GetInputDesc(index4).GetShape().GetDimNum() != 1 || op.GetInputDesc(index5).GetShape().GetDimNum() != 1 ||
|
||||
op.GetInputDesc(index6).GetShape().GetDimNum() != 1) {
|
||||
resultStr =
|
||||
R"({"ret_code": "1", "reason": "input dim is not supported, valid_seq_len, batch_index, seq_len_axis,
|
||||
new_max_seq_len and cur_max_seq_len dim must be 1."})";
|
||||
result = ge::AscendString(resultStr.c_str());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
} // namespace optiling
|
||||
|
||||
namespace ge {
|
||||
static ge::graphStatus InferShape(gert::InferShapeContext *context) {
|
||||
const gert::Shape *x1_shape = context->GetInputShape(0);
|
||||
gert::Shape *y_shape = context->GetOutputShape(0);
|
||||
*y_shape = *x1_shape;
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
ge::graphStatus InferDecoderKvCacheDataType(gert::InferDataTypeContext *context) {
|
||||
const ge::DataType datatype = context->GetInputDataType(0);
|
||||
context->SetOutputDataType(0, datatype);
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
} // namespace ge
|
||||
|
||||
namespace ops {
|
||||
class DecoderKvCache : public OpDef {
|
||||
public:
|
||||
explicit DecoderKvCache(const char *name) : OpDef(name) {
|
||||
this->Input("cache")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_UINT8, ge::DT_UINT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("update")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_UINT8, ge::DT_UINT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("valid_seq_len")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("batch_index")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("seq_len_axis")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("new_max_seq_len")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("cur_max_seq_len")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Output("out")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_UINT8, ge::DT_UINT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
|
||||
this->SetInferShape(ge::InferShape);
|
||||
this->SetInferDataType(ge::InferDecoderKvCacheDataType);
|
||||
|
||||
this->AICore()
|
||||
.SetTiling(optiling::TilingFunc)
|
||||
.AddConfig("ascend910")
|
||||
.AddConfig("ascend910b")
|
||||
.SetCheckSupport(optiling::CheckSupported);
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(DecoderKvCache);
|
||||
} // namespace ops
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DECODER_KV_CACHE_TILING_H
|
||||
#define DECODER_KV_CACHE_TILING_H
|
||||
#include "register/tilingdata_base.h"
|
||||
|
||||
namespace optiling {
|
||||
BEGIN_TILING_DATA_DEF(TilingData)
|
||||
TILING_DATA_FIELD_DEF(int64_t, core_num);
|
||||
TILING_DATA_FIELD_DEF(int64_t, b);
|
||||
TILING_DATA_FIELD_DEF(int64_t, h);
|
||||
TILING_DATA_FIELD_DEF(int64_t, d);
|
||||
TILING_DATA_FIELD_DEF(int64_t, us);
|
||||
END_TILING_DATA_DEF;
|
||||
REGISTER_TILING_DATA_CLASS(DecoderKvCache, TilingData)
|
||||
} // namespace optiling
|
||||
#endif // DECODER_KV_CACHE_TILING_H
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef PLATFORM_ASCENDC_H
|
||||
#define PLATFORM_ASCENDC_H
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace fe {
|
||||
class PlatFormInfos;
|
||||
}
|
||||
|
||||
namespace platform_ascendc {
|
||||
enum class CoreMemType { L0_A = 0, L0_B = 1, L0_C = 2, L1 = 3, L2 = 4, UB = 5, HBM = 6, RESERVED };
|
||||
|
||||
enum class SocVersion { ASCEND910 = 0, ASCEND910B, ASCEND310P, RESERVED_VERSION = 99999 };
|
||||
|
||||
class PlatformAscendC {
|
||||
public:
|
||||
PlatformAscendC() = delete;
|
||||
~PlatformAscendC() {}
|
||||
explicit PlatformAscendC(fe::PlatFormInfos *platformInfo) : platformInfo_(platformInfo) {}
|
||||
/**
|
||||
* Get Core Number
|
||||
* On Ascend910B MIX model, return AICore number
|
||||
* @return core number by core type
|
||||
*/
|
||||
uint32_t GetCoreNum(void) const;
|
||||
/**
|
||||
* Get Core Number AiCore
|
||||
* @return ai_core_num
|
||||
*/
|
||||
uint32_t GetCoreNumAic(void) const;
|
||||
/**
|
||||
* Get Core Number VectorCore
|
||||
* @return vector_core_num
|
||||
*/
|
||||
uint32_t GetCoreNumAiv(void) const;
|
||||
/**
|
||||
* Calc task schedule block dim
|
||||
* @sliceNum number slice of data division
|
||||
* @aicCoreNum value of GetCoreNumAic() if used cube API, otherwise 0
|
||||
* @aivCoreNum value of GetCoreNumAiv() if used vector API, otherwise 0
|
||||
* @return task schedule block dim
|
||||
*/
|
||||
uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum) const;
|
||||
/**
|
||||
* Get Work Space Size
|
||||
* @return work sapce size by chip type
|
||||
*/
|
||||
uint32_t GetLibApiWorkSpaceSize(void) const;
|
||||
void GetCoreMemSize(const CoreMemType &memType, uint64_t &size) const;
|
||||
void GetCoreMemBw(const CoreMemType &memType, uint64_t &bwSize) const;
|
||||
/**
|
||||
* Get Soc Version Enum
|
||||
* @return Enum SocVersion
|
||||
*/
|
||||
SocVersion GetSocVersion(void) const;
|
||||
|
||||
private:
|
||||
fe::PlatFormInfos *platformInfo_;
|
||||
};
|
||||
} // namespace platform_ascendc
|
||||
#endif
|
|
@ -0,0 +1,247 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "prompt_kv_cache_tiling.h"
|
||||
#include "register/op_def_registry.h"
|
||||
#include "platform_ascendc.h"
|
||||
|
||||
namespace {
|
||||
constexpr int index0 = 0;
|
||||
constexpr int index1 = 1;
|
||||
constexpr int index2 = 2;
|
||||
constexpr int index3 = 3;
|
||||
constexpr int index4 = 4;
|
||||
constexpr int index5 = 5;
|
||||
constexpr int index6 = 6;
|
||||
|
||||
constexpr int32_t kSize1 = 1;
|
||||
constexpr int32_t kSize2 = 2;
|
||||
constexpr int32_t kSize4 = 4;
|
||||
|
||||
constexpr int64_t kAxisOne = 1;
|
||||
constexpr int64_t kAxisTwo = 2;
|
||||
|
||||
constexpr size_t k910bWS = 16 * 1024 * 1024;
|
||||
} // namespace
|
||||
|
||||
namespace optiling {
|
||||
static ge::graphStatus TilingFunc(gert::TilingContext *context) {
|
||||
auto past_input = context->GetInputDesc(index0);
|
||||
auto dtype = past_input->GetDataType();
|
||||
auto type_size = ge::GetSizeByDataType(dtype);
|
||||
switch (type_size) {
|
||||
case kSize1:
|
||||
context->SetTilingKey(kSize1);
|
||||
break;
|
||||
case kSize2:
|
||||
context->SetTilingKey(kSize2);
|
||||
break;
|
||||
case kSize4:
|
||||
context->SetTilingKey(kSize4);
|
||||
break;
|
||||
default:
|
||||
return ge::GRAPH_PARAM_INVALID;
|
||||
}
|
||||
|
||||
const gert::StorageShape *cache_shape = context->GetInputShape(index0);
|
||||
const gert::StorageShape *update_shape = context->GetInputShape(index1);
|
||||
|
||||
bool is_dim4 = true;
|
||||
const size_t kDim3 = 3;
|
||||
const size_t kDim4 = 4;
|
||||
if (cache_shape->GetStorageShape().GetDimNum() == kDim4) {
|
||||
is_dim4 = true;
|
||||
} else if (cache_shape->GetStorageShape().GetDimNum() == kDim3) {
|
||||
is_dim4 = false;
|
||||
} else {
|
||||
return ge::GRAPH_PARAM_INVALID;
|
||||
}
|
||||
|
||||
int64_t b = cache_shape->GetStorageShape().GetDim(index0);
|
||||
int64_t ub = update_shape->GetStorageShape().GetDim(index0);
|
||||
// s need get when run
|
||||
int64_t h = 0;
|
||||
int64_t s = 0;
|
||||
int64_t us = 0;
|
||||
int64_t d = 0;
|
||||
|
||||
if (is_dim4) {
|
||||
h = update_shape->GetStorageShape().GetDim(index1);
|
||||
s = cache_shape->GetStorageShape().GetDim(index2);
|
||||
us = update_shape->GetStorageShape().GetDim(index2);
|
||||
d = update_shape->GetStorageShape().GetDim(index3);
|
||||
} else {
|
||||
h = 1;
|
||||
s = cache_shape->GetStorageShape().GetDim(index1);
|
||||
us = update_shape->GetStorageShape().GetDim(index1);
|
||||
d = update_shape->GetStorageShape().GetDim(index2);
|
||||
}
|
||||
|
||||
auto platform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||
auto aiv_num = platform.GetCoreNumAiv();
|
||||
|
||||
context->SetBlockDim(aiv_num);
|
||||
|
||||
// set workspace for 910B
|
||||
size_t *currentWorkspace = context->GetWorkspaceSizes(1);
|
||||
currentWorkspace[0] = k910bWS;
|
||||
|
||||
TilingData tiling;
|
||||
tiling.set_core_num(aiv_num);
|
||||
tiling.set_b(b);
|
||||
tiling.set_h(h);
|
||||
tiling.set_s(s);
|
||||
tiling.set_d(d);
|
||||
tiling.set_ub(ub);
|
||||
tiling.set_us(us);
|
||||
tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
|
||||
context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
static ge::graphStatus CheckSupported(const ge::Operator &op, ge::AscendString &result) {
|
||||
std::string resultStr{};
|
||||
constexpr size_t input_num = 7;
|
||||
if (op.GetInputsSize() != input_num) {
|
||||
resultStr = R"({"ret_code": "1", "reason": "input num is not 7"})";
|
||||
result = ge::AscendString(resultStr.c_str());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
if (op.GetOutputsSize() != 1) {
|
||||
resultStr = R"({"ret_code": "1", "reason": "output num is not 1"})";
|
||||
result = ge::AscendString(resultStr.c_str());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
if (op.GetInputDesc(i).GetFormat() != ge::FORMAT_ND) {
|
||||
resultStr = R"({"ret_code": "1", "reason": "input format is not supported, only support ND."})";
|
||||
result = ge::AscendString(resultStr.c_str());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t input_dim3_num = 3;
|
||||
const int64_t input_dim4_num = 4;
|
||||
if (op.GetInputDesc(index0).GetShape().GetDimNum() != input_dim3_num &&
|
||||
op.GetInputDesc(index0).GetShape().GetDimNum() != input_dim4_num) {
|
||||
resultStr = R"({"ret_code": "1", "reason": "input dim is not supported, cache and update dim must be 4."})";
|
||||
result = ge::AscendString(resultStr.c_str());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
if (op.GetInputDesc(index1).GetShape().GetDimNum() != input_dim3_num &&
|
||||
op.GetInputDesc(index1).GetShape().GetDimNum() != input_dim4_num) {
|
||||
resultStr = R"({"ret_code": "1", "reason": "input dim is not supported, cache and update dim must be 4."})";
|
||||
result = ge::AscendString(resultStr.c_str());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
if (op.GetInputDesc(index2).GetShape().GetDimNum() != 1 || op.GetInputDesc(index3).GetShape().GetDimNum() != 1 ||
|
||||
op.GetInputDesc(index4).GetShape().GetDimNum() != 1 || op.GetInputDesc(index5).GetShape().GetDimNum() != 1 ||
|
||||
op.GetInputDesc(index6).GetShape().GetDimNum() != 1) {
|
||||
resultStr =
|
||||
R"({"ret_code": "1", "reason": "input dim is not supported, valid_seq_len, batch_index, seq_len_axis,
|
||||
new_max_seq_len and cur_max_seq_len dim must be 1."})";
|
||||
result = ge::AscendString(resultStr.c_str());
|
||||
return ge::GRAPH_FAILED;
|
||||
}
|
||||
|
||||
return ge::GRAPH_SUCCESS;
|
||||
}
|
||||
} // namespace optiling
|
||||
|
||||
namespace ge {
|
||||
static ge::graphStatus InferShape(gert::InferShapeContext *context) {
|
||||
const gert::Shape *x1_shape = context->GetInputShape(0);
|
||||
gert::Shape *y_shape = context->GetOutputShape(0);
|
||||
*y_shape = *x1_shape;
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
ge::graphStatus InferPromptKvCacheDataType(gert::InferDataTypeContext *context) {
|
||||
const ge::DataType datatype = context->GetInputDataType(0);
|
||||
context->SetOutputDataType(0, datatype);
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
} // namespace ge
|
||||
|
||||
namespace ops {
|
||||
class PromptKvCache : public OpDef {
|
||||
public:
|
||||
explicit PromptKvCache(const char *name) : OpDef(name) {
|
||||
this->Input("cache")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_UINT8, ge::DT_UINT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("update")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_UINT8, ge::DT_UINT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("valid_seq_len")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("batch_index")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("seq_len_axis")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("new_max_seq_len")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Input("cur_max_seq_len")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
this->Output("out")
|
||||
.ParamType(REQUIRED)
|
||||
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_UINT8, ge::DT_UINT16})
|
||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||
.UnknownShapeFormat(
|
||||
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||
|
||||
this->SetInferShape(ge::InferShape);
|
||||
this->SetInferDataType(ge::InferPromptKvCacheDataType);
|
||||
|
||||
this->AICore()
|
||||
.SetTiling(optiling::TilingFunc)
|
||||
.AddConfig("ascend910")
|
||||
.AddConfig("ascend910b")
|
||||
.SetCheckSupport(optiling::CheckSupported);
|
||||
}
|
||||
};
|
||||
|
||||
OP_ADD(PromptKvCache);
|
||||
} // namespace ops
|
|
@ -0,0 +1,34 @@
|
|||
|
||||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef PROMPT_KV_CACHE_TILING_H
|
||||
#define PROMPT_KV_CACHE_TILING_H
|
||||
#include "register/tilingdata_base.h"
|
||||
|
||||
namespace optiling {
|
||||
BEGIN_TILING_DATA_DEF(TilingData)
|
||||
TILING_DATA_FIELD_DEF(int64_t, core_num);
|
||||
TILING_DATA_FIELD_DEF(int64_t, b);
|
||||
TILING_DATA_FIELD_DEF(int64_t, h);
|
||||
TILING_DATA_FIELD_DEF(int64_t, s);
|
||||
TILING_DATA_FIELD_DEF(int64_t, d);
|
||||
TILING_DATA_FIELD_DEF(int64_t, ub);
|
||||
TILING_DATA_FIELD_DEF(int64_t, us);
|
||||
END_TILING_DATA_DEF;
|
||||
REGISTER_TILING_DATA_CLASS(PromptKvCache, TilingData)
|
||||
} // namespace optiling
|
||||
#endif // PROMPT_KV_CACHE_TILING_H
|
|
@ -0,0 +1,196 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel_operator.h"
|
||||
using namespace AscendC;
|
||||
namespace {
|
||||
constexpr int32_t kBufferNum = 2;
|
||||
constexpr int64_t kUbSize = 192 * 1024;
|
||||
const int64_t kDivisor = 4;
|
||||
static __aicore__ inline int64_t CeilRound(int64_t value, int64_t kDivisor) {
|
||||
if (kDivisor == 0) {
|
||||
return 0;
|
||||
}
|
||||
return (value + kDivisor - 1) / kDivisor * kDivisor;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
class KernelDecoderKvCache {
|
||||
public:
|
||||
__aicore__ inline KernelDecoderKvCache() {}
|
||||
|
||||
__aicore__ inline void GetNewMaxSeqLen(GM_ADDR new_max_seq_len) {
|
||||
new_max_seq_len_gm_.SetGlobalBuffer((__gm__ int64_t *)new_max_seq_len, 4);
|
||||
pipe_.InitBuffer(new_max_seq_len_queue_, 1, CeilRound(1, kDivisor) * sizeof(int64_t));
|
||||
LocalTensor<int64_t> new_max_seq_len_tensor = new_max_seq_len_queue_.AllocTensor<int64_t>();
|
||||
pipe_barrier((pipe_t)PIPE_ALL);
|
||||
DataCopy(new_max_seq_len_tensor, new_max_seq_len_gm_, CeilRound(1, kDivisor));
|
||||
pipe_barrier((pipe_t)PIPE_ALL);
|
||||
s_ = new_max_seq_len_tensor.GetValue(0);
|
||||
new_max_seq_len_queue_.FreeTensor(new_max_seq_len_tensor);
|
||||
}
|
||||
|
||||
__aicore__ inline void GetValidSeqLen(GM_ADDR valid_seq_len, int64_t ub) {
|
||||
int64_t valid_seq_len_ub_size = CeilRound(ub, kDivisor);
|
||||
valid_seq_len_gm_.SetGlobalBuffer((__gm__ int64_t *)valid_seq_len, valid_seq_len_ub_size);
|
||||
pipe_.InitBuffer(valid_seq_len_queue_, 1, valid_seq_len_ub_size * sizeof(int64_t));
|
||||
valid_seq_len_tensor_ = valid_seq_len_queue_.AllocTensor<int64_t>();
|
||||
pipe_barrier((pipe_t)PIPE_ALL);
|
||||
DataCopy(valid_seq_len_tensor_, valid_seq_len_gm_, valid_seq_len_ub_size);
|
||||
pipe_barrier((pipe_t)PIPE_ALL);
|
||||
remain_ub_size_ -= valid_seq_len_ub_size * sizeof(int64_t);
|
||||
}
|
||||
|
||||
__aicore__ inline void SplitBh(int64_t bh) {
|
||||
split_bh_ = 1;
|
||||
former_block_bh_ = bh;
|
||||
while (kBufferNum * former_block_bh_ * us_ * d_ * sizeof(T) >= remain_ub_size_) {
|
||||
split_bh_++;
|
||||
former_block_bh_ = (bh + split_bh_ - 1) / split_bh_;
|
||||
}
|
||||
tail_block_bh_ = bh - (split_bh_ - 1) * former_block_bh_;
|
||||
}
|
||||
|
||||
__aicore__ inline void Update(GM_ADDR cache, GM_ADDR update, LocalTensor<T> update_in_local_tensor) {
|
||||
for (int64_t i = 0; i < split_bh_; i++) {
|
||||
int64_t block_bh;
|
||||
if (i != split_bh_ - 1) {
|
||||
block_bh = former_block_bh_;
|
||||
} else {
|
||||
block_bh = tail_block_bh_;
|
||||
}
|
||||
update_gm_.SetGlobalBuffer((__gm__ T *)update + core_idx_ * update_core_stride_ + i * former_block_bh_ * us_ * d_,
|
||||
block_bh * us_ * d_);
|
||||
|
||||
DataCopy(update_in_local_tensor, update_gm_, block_bh * us_ * d_);
|
||||
pipe_barrier((pipe_t)PIPE_ALL);
|
||||
|
||||
update_queue_.EnQue(update_in_local_tensor);
|
||||
LocalTensor<T> update_in_local_tensor_out = update_queue_.DeQue<T>();
|
||||
|
||||
for (int64_t j = 0; j < block_bh; j++) {
|
||||
int64_t bh_idx = core_idx_ * former_bh_ + i * former_block_bh_ + j;
|
||||
auto b_idx = bh_idx / h_;
|
||||
pipe_barrier((pipe_t)PIPE_ALL);
|
||||
auto s_idx = valid_seq_len_tensor_.GetValue(b_idx);
|
||||
pipe_barrier((pipe_t)PIPE_ALL);
|
||||
if (s_idx < 0 || s_idx >= s_) {
|
||||
continue;
|
||||
}
|
||||
out_gm_.SetGlobalBuffer((__gm__ T *)cache + bh_idx * s_ * d_ + s_idx * d_, us_ * d_);
|
||||
int64_t src_offset = j * us_ * d_;
|
||||
pipe_barrier((pipe_t)PIPE_ALL);
|
||||
DataCopy(out_gm_, update_in_local_tensor_out[src_offset], us_ * d_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void Process(GM_ADDR cache, GM_ADDR update, GM_ADDR valid_seq_len, GM_ADDR batch_index,
|
||||
GM_ADDR new_max_seq_len, GM_ADDR cur_max_seq_len, int64_t core_num, int64_t b,
|
||||
int64_t h, int64_t d, int64_t us) {
|
||||
core_idx_ = GetBlockIdx();
|
||||
former_bh_ = (b * h + core_num - 1) / core_num;
|
||||
core_num_ = (b * h + former_bh_ - 1) / former_bh_;
|
||||
if (core_idx_ >= core_num_) {
|
||||
return;
|
||||
}
|
||||
tail_bh_ = b * h - (core_num_ - 1) * former_bh_;
|
||||
|
||||
b_ = b;
|
||||
h_ = h;
|
||||
d_ = d;
|
||||
us_ = us;
|
||||
|
||||
GetNewMaxSeqLen(new_max_seq_len);
|
||||
GetValidSeqLen(valid_seq_len, b);
|
||||
|
||||
update_core_stride_ = former_bh_ * us_ * d_;
|
||||
cache_core_stride_ = former_bh_ * s_ * d_;
|
||||
|
||||
if (core_idx_ != core_num - 1) {
|
||||
SplitBh(former_bh_);
|
||||
} else {
|
||||
SplitBh(tail_bh_);
|
||||
}
|
||||
|
||||
pipe_.InitBuffer(update_queue_, kBufferNum, former_block_bh_ * us_ * d_ * sizeof(T));
|
||||
LocalTensor<T> update_in_local_tensor = update_queue_.AllocTensor<T>();
|
||||
|
||||
Update(cache, update, update_in_local_tensor);
|
||||
|
||||
valid_seq_len_queue_.FreeTensor(valid_seq_len_tensor_);
|
||||
update_queue_.FreeTensor(update_in_local_tensor);
|
||||
}
|
||||
|
||||
private:
|
||||
// gm
|
||||
GlobalTensor<T> update_gm_;
|
||||
GlobalTensor<int64_t> valid_seq_len_gm_;
|
||||
GlobalTensor<int64_t> new_max_seq_len_gm_;
|
||||
GlobalTensor<T> out_gm_;
|
||||
|
||||
// local
|
||||
LocalTensor<int64_t> valid_seq_len_tensor_;
|
||||
|
||||
TPipe pipe_;
|
||||
// create queues for input, in this case depth is equal to buffer num
|
||||
TQue<QuePosition::VECIN, 1> update_queue_;
|
||||
TQue<QuePosition::VECIN, 1> valid_seq_len_queue_;
|
||||
TQue<QuePosition::VECIN, 1> new_max_seq_len_queue_;
|
||||
|
||||
int64_t remain_ub_size_ = kUbSize;
|
||||
|
||||
int64_t split_bh_ = 0;
|
||||
int64_t former_block_bh_ = 0;
|
||||
int64_t tail_block_bh_ = 0;
|
||||
|
||||
int64_t core_idx_ = 0;
|
||||
int64_t cache_core_stride_ = 0;
|
||||
int64_t cache_block_length = 0;
|
||||
int64_t update_core_stride_ = 0;
|
||||
int64_t update_block_length_ = 0;
|
||||
|
||||
int64_t core_num_ = 0;
|
||||
int64_t former_bh_ = 0;
|
||||
int64_t tail_bh_ = 0;
|
||||
int64_t b_ = 0;
|
||||
int64_t h_ = 0;
|
||||
int64_t s_ = 0;
|
||||
int64_t d_ = 0;
|
||||
int64_t us_ = 0;
|
||||
};
|
||||
|
||||
extern "C" __global__ __aicore__ void decoder_kv_cache(GM_ADDR cache, GM_ADDR update, GM_ADDR valid_seq_len,
|
||||
GM_ADDR batch_index, GM_ADDR seq_len_axis,
|
||||
GM_ADDR new_max_seq_len, GM_ADDR cur_max_seq_len, GM_ADDR out,
|
||||
GM_ADDR workspace, GM_ADDR tiling) {
|
||||
GET_TILING_DATA(tiling_data, tiling);
|
||||
|
||||
if (TILING_KEY_IS(1)) {
|
||||
KernelDecoderKvCache<int8_t> op;
|
||||
op.Process(cache, update, valid_seq_len, batch_index, new_max_seq_len, cur_max_seq_len, tiling_data.core_num,
|
||||
tiling_data.b, tiling_data.h, tiling_data.d, tiling_data.us);
|
||||
} else if (TILING_KEY_IS(2)) {
|
||||
KernelDecoderKvCache<int16_t> op;
|
||||
op.Process(cache, update, valid_seq_len, batch_index, new_max_seq_len, cur_max_seq_len, tiling_data.core_num,
|
||||
tiling_data.b, tiling_data.h, tiling_data.d, tiling_data.us);
|
||||
} else if (TILING_KEY_IS(4)) {
|
||||
KernelDecoderKvCache<int32_t> op;
|
||||
op.Process(cache, update, valid_seq_len, batch_index, new_max_seq_len, cur_max_seq_len, tiling_data.core_num,
|
||||
tiling_data.b, tiling_data.h, tiling_data.d, tiling_data.us);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,217 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_utils.h"
|
||||
using namespace AscendC;
|
||||
|
||||
namespace {
|
||||
constexpr int64_t kBufferNum = 2;
|
||||
constexpr int64_t kUbSize = 192 * 1024;
|
||||
const int64_t kDivisor = 4;
|
||||
static __aicore__ inline int64_t CeilRound(int64_t value, int64_t kDivisor) {
|
||||
if (kDivisor == 0) {
|
||||
return 0;
|
||||
}
|
||||
return (value + kDivisor - 1) / kDivisor * kDivisor;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
class KernelPromptKvCache {
|
||||
public:
|
||||
__aicore__ inline KernelPromptKvCache() {}
|
||||
|
||||
__aicore__ inline void GetIndex(GM_ADDR batch_index, GM_ADDR valid_seq_len, int64_t ub) {
|
||||
int64_t batch_index_ub_num = CeilRound(ub, kDivisor);
|
||||
int64_t valid_seq_len_ub_num = CeilRound(ub, kDivisor);
|
||||
batch_index_gm_.SetGlobalBuffer((__gm__ int64_t *)batch_index, batch_index_ub_num);
|
||||
|
||||
int64_t total_num = batch_index_ub_num + valid_seq_len_ub_num;
|
||||
pipe_.InitBuffer(index_queue_, 1, total_num * sizeof(int64_t));
|
||||
batch_index_tensor_ = index_queue_.AllocTensor<int64_t>();
|
||||
DataCopy(batch_index_tensor_, batch_index_gm_, batch_index_ub_num);
|
||||
|
||||
valid_seq_len_gm_.SetGlobalBuffer((__gm__ int64_t *)valid_seq_len, valid_seq_len_ub_num);
|
||||
valid_seq_len_tensor_ = batch_index_tensor_[batch_index_ub_num];
|
||||
DataCopy(valid_seq_len_tensor_, valid_seq_len_gm_, valid_seq_len_ub_num);
|
||||
|
||||
remain_ub_size_ -= total_num * sizeof(int64_t);
|
||||
}
|
||||
|
||||
__aicore__ inline void GetNewMaxSeqLen(GM_ADDR new_max_seq_len) {
|
||||
new_max_seq_len_gm_.SetGlobalBuffer((__gm__ int64_t *)new_max_seq_len, 4);
|
||||
pipe_.InitBuffer(new_max_seq_len_queue_, 1, CeilRound(1, kDivisor) * sizeof(int64_t));
|
||||
LocalTensor<int64_t> new_max_seq_len_tensor = new_max_seq_len_queue_.AllocTensor<int64_t>();
|
||||
DataCopy(new_max_seq_len_tensor, new_max_seq_len_gm_, CeilRound(1, kDivisor));
|
||||
pipe_barrier((pipe_t)PIPE_ALL);
|
||||
s_ = new_max_seq_len_tensor.GetValue(0);
|
||||
new_max_seq_len_queue_.FreeTensor(new_max_seq_len_tensor);
|
||||
}
|
||||
|
||||
__aicore__ inline void UpdateCache(GM_ADDR cache, GM_ADDR update) {
|
||||
int64_t split_us = 1;
|
||||
int64_t block_us = us_ / split_us;
|
||||
while (kBufferNum * block_us * d_ * sizeof(T) >= remain_ub_size_) {
|
||||
split_us++;
|
||||
block_us = (us_ + split_us - 1) / split_us;
|
||||
}
|
||||
|
||||
int64_t former_block_us = block_us;
|
||||
int64_t tail_block_us = us_ - (split_us - 1) * former_block_us;
|
||||
pipe_.InitBuffer(update_queue_, kBufferNum, block_us * d_ * sizeof(T));
|
||||
|
||||
for (int64_t i = 0; i < each_core_bs_num_; ++i) {
|
||||
int64_t bh_idx = core_idx_ * former_each_core_bs_num_ + i;
|
||||
int64_t ub_idx = bh_idx / h_;
|
||||
int64_t h_idx = bh_idx % h_;
|
||||
pipe_barrier((pipe_t)PIPE_ALL);
|
||||
int64_t cache_b_idx = batch_index_tensor_.GetValue(ub_idx);
|
||||
int64_t s_idx = valid_seq_len_tensor_.GetValue(ub_idx);
|
||||
if (cache_b_idx < 0 || cache_b_idx >= b_) {
|
||||
continue;
|
||||
}
|
||||
if (s_idx < 0 || s_idx + us_ > s_) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int64_t j = 0; j < split_us; ++j) {
|
||||
int64_t u_block_len;
|
||||
if (j == split_us - 1) {
|
||||
u_block_len = tail_block_us * d_;
|
||||
} else {
|
||||
u_block_len = former_block_us * d_;
|
||||
}
|
||||
LocalTensor<T> update_in_local_tensor = update_queue_.AllocTensor<T>();
|
||||
update_gm_.SetGlobalBuffer(
|
||||
(__gm__ T *)update + ub_idx * update_b_stride_ + h_idx * update_h_stride_ + j * former_block_us * d_,
|
||||
u_block_len);
|
||||
out_gm_.SetGlobalBuffer((__gm__ T *)cache + cache_b_idx * cache_b_stride_ + h_idx * cache_h_stride_ +
|
||||
s_idx * d_ + j * former_block_us * d_,
|
||||
u_block_len);
|
||||
pipe_barrier((pipe_t)PIPE_ALL);
|
||||
DataCopy(update_in_local_tensor, update_gm_, u_block_len);
|
||||
update_queue_.EnQue(update_in_local_tensor);
|
||||
LocalTensor<T> update_in_local_tensor_out = update_queue_.DeQue<T>();
|
||||
pipe_barrier((pipe_t)PIPE_ALL);
|
||||
DataCopy(out_gm_, update_in_local_tensor_out, u_block_len);
|
||||
update_queue_.FreeTensor(update_in_local_tensor_out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void Process(GM_ADDR cache, GM_ADDR update, GM_ADDR valid_seq_len, GM_ADDR batch_index,
|
||||
GM_ADDR seq_len_axis, GM_ADDR new_max_seq_len, GM_ADDR cur_max_seq_len,
|
||||
int64_t max_core_num, int64_t b, int64_t h, int64_t s, int64_t d, int64_t ub,
|
||||
int64_t us) {
|
||||
core_idx_ = GetBlockIdx();
|
||||
int64_t bs = ub * h;
|
||||
former_each_core_bs_num_ = (bs + max_core_num - 1) / max_core_num;
|
||||
core_num_ = (bs + former_each_core_bs_num_ - 1) / former_each_core_bs_num_;
|
||||
tail_each_core_bs_num_ = bs - (core_num_ - 1) * former_each_core_bs_num_;
|
||||
|
||||
if (core_idx_ >= core_num_) {
|
||||
return;
|
||||
}
|
||||
if (g_coreType == AIC) {
|
||||
return;
|
||||
}
|
||||
|
||||
GetIndex(batch_index, valid_seq_len, ub);
|
||||
GetNewMaxSeqLen(new_max_seq_len);
|
||||
|
||||
b_ = b * s / s_;
|
||||
h_ = h;
|
||||
d_ = d;
|
||||
ub_ = ub;
|
||||
us_ = us;
|
||||
|
||||
if (core_idx_ != core_num_ - 1) {
|
||||
each_core_bs_num_ = former_each_core_bs_num_;
|
||||
} else {
|
||||
each_core_bs_num_ = tail_each_core_bs_num_;
|
||||
}
|
||||
|
||||
cache_h_stride_ = s_ * d_;
|
||||
cache_b_stride_ = h_ * cache_h_stride_;
|
||||
|
||||
update_h_stride_ = us_ * d_;
|
||||
update_b_stride_ = h_ * update_h_stride_;
|
||||
|
||||
UpdateCache(cache, update);
|
||||
index_queue_.FreeTensor(batch_index_tensor_);
|
||||
}
|
||||
|
||||
private:
|
||||
// gm
|
||||
GlobalTensor<T> update_gm_;
|
||||
GlobalTensor<int64_t> valid_seq_len_gm_;
|
||||
GlobalTensor<int64_t> batch_index_gm_;
|
||||
GlobalTensor<int64_t> new_max_seq_len_gm_;
|
||||
GlobalTensor<T> out_gm_;
|
||||
|
||||
// local gm
|
||||
LocalTensor<int64_t> valid_seq_len_tensor_;
|
||||
LocalTensor<int64_t> batch_index_tensor_;
|
||||
|
||||
TPipe pipe_;
|
||||
TQue<QuePosition::VECIN, 1> update_queue_;
|
||||
TQue<QuePosition::VECIN, 1> index_queue_;
|
||||
TQue<QuePosition::VECIN, 1> new_max_seq_len_queue_;
|
||||
|
||||
int64_t remain_ub_size_ = kUbSize;
|
||||
int64_t core_idx_ = 0;
|
||||
int64_t core_num_ = 0;
|
||||
int64_t each_core_bs_num_ = 0;
|
||||
int64_t former_each_core_bs_num_ = 0;
|
||||
int64_t tail_each_core_bs_num_ = 0;
|
||||
int64_t b_ = 0;
|
||||
int64_t h_ = 0;
|
||||
int64_t s_ = 0;
|
||||
int64_t d_ = 0;
|
||||
int64_t ub_ = 0;
|
||||
int64_t us_ = 0;
|
||||
int64_t ps_ = 0;
|
||||
|
||||
int64_t cache_b_stride_ = 0;
|
||||
int64_t cache_h_stride_ = 0;
|
||||
int64_t update_b_stride_ = 0;
|
||||
int64_t update_h_stride_ = 0;
|
||||
};
|
||||
|
||||
extern "C" __global__ __aicore__ void prompt_kv_cache(GM_ADDR cache, GM_ADDR update, GM_ADDR valid_seq_len,
|
||||
GM_ADDR batch_index, GM_ADDR seq_len_axis,
|
||||
GM_ADDR new_max_seq_len, GM_ADDR cur_max_seq_len, GM_ADDR out,
|
||||
GM_ADDR workspace, GM_ADDR tiling) {
|
||||
GET_TILING_DATA(tiling_data, tiling);
|
||||
|
||||
if (TILING_KEY_IS(1)) {
|
||||
KernelPromptKvCache<int8_t> op;
|
||||
op.Process(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len,
|
||||
tiling_data.core_num, tiling_data.b, tiling_data.h, tiling_data.s, tiling_data.d, tiling_data.ub,
|
||||
tiling_data.us);
|
||||
} else if (TILING_KEY_IS(2)) {
|
||||
KernelPromptKvCache<int16_t> op;
|
||||
op.Process(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len,
|
||||
tiling_data.core_num, tiling_data.b, tiling_data.h, tiling_data.s, tiling_data.d, tiling_data.ub,
|
||||
tiling_data.us);
|
||||
} else if (TILING_KEY_IS(4)) {
|
||||
KernelPromptKvCache<int32_t> op;
|
||||
op.Process(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len,
|
||||
tiling_data.core_num, tiling_data.b, tiling_data.h, tiling_data.s, tiling_data.d, tiling_data.ub,
|
||||
tiling_data.us);
|
||||
}
|
||||
}
|
|
@ -33,9 +33,9 @@ class DecoderKVCacheNet(Cell):
|
|||
new_max_seq_len, cur_max_seq_len)
|
||||
|
||||
|
||||
def test_decoder_k_v_cache_net():
|
||||
def test_decoder_k_v_cache_net_4dims():
|
||||
"""
|
||||
Feature: test decoder_k_v_cache auto parallel
|
||||
Feature: test decoder_k_v_cache auto parallel 4dims
|
||||
Description: auto parallel
|
||||
Expectation: shape is as expected.
|
||||
"""
|
||||
|
@ -56,3 +56,28 @@ def test_decoder_k_v_cache_net():
|
|||
assert validator.check_parameter_shape('cache', [1, 40, 1024, 128])
|
||||
assert validator.check_parameter_shape('update', [1, 40, 1, 128])
|
||||
assert validator.check_parameter_shape('valid_seq_len', [1])
|
||||
|
||||
|
||||
def test_decoder_k_v_cache_net_3dims():
|
||||
"""
|
||||
Feature: test decoder_k_v_cache auto parallel 3dims
|
||||
Description: auto parallel
|
||||
Expectation: shape is as expected.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((4, 1, 1,), (4, 1, 1,), (4,), (1,), (1,), (1,), (1,))
|
||||
net = DecoderKVCacheNet(strategy)
|
||||
cache = Parameter(Tensor(np.ones([4, 1024, 5120]), dtype=ms.float16), "cache")
|
||||
update = Parameter(Tensor(np.ones([4, 1, 5120]), dtype=ms.float16), "update")
|
||||
valid_seq_len = Parameter(Tensor(np.ones([4]), dtype=ms.int64), "valid_seq_len")
|
||||
batch_index = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "batch_index")
|
||||
seq_len_axis = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "seq_len_axis")
|
||||
new_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "new_max_seq_len")
|
||||
cur_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "cur_max_seq_len")
|
||||
net.set_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
|
||||
|
||||
phase = compile_net(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
|
||||
validator = ParallelValidator(net, phase)
|
||||
assert validator.check_parameter_shape('cache', [1, 1024, 5120])
|
||||
assert validator.check_parameter_shape('update', [1, 1, 5120])
|
||||
assert validator.check_parameter_shape('valid_seq_len', [1])
|
||||
|
|
|
@ -33,9 +33,9 @@ class PromptKVCacheNet(Cell):
|
|||
new_max_seq_len, cur_max_seq_len)
|
||||
|
||||
|
||||
def test_prompt_k_v_cache_net():
|
||||
def test_prompt_k_v_cache_net_dim4():
|
||||
"""
|
||||
Feature: test prompt_k_v_cache auto parallel
|
||||
Feature: test prompt_k_v_cache auto parallel 4 dims
|
||||
Description: auto parallel
|
||||
Expectation: shape is as expected.
|
||||
"""
|
||||
|
@ -57,3 +57,29 @@ def test_prompt_k_v_cache_net():
|
|||
assert validator.check_parameter_shape('cache', [1, 40, 1024, 128])
|
||||
assert validator.check_parameter_shape('update', [1, 40, 1, 128])
|
||||
assert validator.check_parameter_shape('valid_seq_len', [1])
|
||||
|
||||
|
||||
def test_prompt_k_v_cache_net_dim3():
|
||||
"""
|
||||
Feature: test prompt_k_v_cache auto parallel 4 dims
|
||||
Description: auto parallel
|
||||
Expectation: shape is as expected.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((4, 1, 1,), (4, 1, 1,), (4,), (1,), (1,), (1,), (1,))
|
||||
padding_mode = "right"
|
||||
net = PromptKVCacheNet(padding_mode, strategy)
|
||||
cache = Parameter(Tensor(np.ones([4, 1024, 5120]), dtype=ms.float16), "cache")
|
||||
update = Parameter(Tensor(np.ones([4, 1, 5120]), dtype=ms.float16), "update")
|
||||
valid_seq_len = Parameter(Tensor(np.ones([4]), dtype=ms.int64), "valid_seq_len")
|
||||
batch_index = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "batch_index")
|
||||
seq_len_axis = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "seq_len_axis")
|
||||
new_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "new_max_seq_len")
|
||||
cur_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "cur_max_seq_len")
|
||||
net.set_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
|
||||
|
||||
phase = compile_net(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
|
||||
validator = ParallelValidator(net, phase)
|
||||
assert validator.check_parameter_shape('cache', [1, 1024, 5120])
|
||||
assert validator.check_parameter_shape('update', [1, 1, 5120])
|
||||
assert validator.check_parameter_shape('valid_seq_len', [1])
|
||||
|
|
Loading…
Reference in New Issue