add PromptKVCache and DecoderKVCache Ascend ops

inplace update for GE

fix bug of prompt kv cache

fix bug of kernel sync problem

prompt kvcache: add multi-batch index support
prompt kvcache and decoder kvcache: support multiple data types
decoder kvcache: skip batches whose valid_seq_len is -1
prompt kvcache: skip batches whose batch_index is -1

fix bug of decoder_kv_cache not supporting big batch sizes

prompt_kv_cache: support valid_seq_len

prompt_kv_cache and decoder_kv_cache: support BSD format

make prompt_kv_cache and decoder_kv_cache support BSD format
This commit is contained in:
mengyuanli 2023-10-26 09:43:27 +08:00
parent 2c6ff963f0
commit 58a648eaba
18 changed files with 1521 additions and 161 deletions

View File

@@ -98,6 +98,10 @@
"mindspore/mindspore/lite/src/extendrt/cxx_api/model/model_impl.cc" "whitespace/parens"
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/gemm_mask_avx512/" "runtime/int"
"mindspore/mindspore/lite/python/src/lite_infer_pybind.cc" "runtime/references"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/ascendc/op_host/decoder_kv_cache.cpp" "runtime/references"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/ascendc/op_host/platform_ascendc.h" "runtime/references"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/ascendc/op_host/prompt_kv_cache.cpp" "runtime/references"
# ascend samples
"mindspore/mindspore/lite/tools/kernel_builder/ascend/aicpu/sample/" "build/include_subdir"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "build/include_subdir"

View File

@@ -20,28 +20,107 @@
namespace mindspore {
namespace parallel {
// DecoderKVCache seven inputs
// ===== Case 1:
// cache and update shape rank is 4.
// cache: (batch, num_head, max_seq_len, hidden_size)
// update: (batch, num_head, update_seq_len, hidden_size)
// valid_seq_len: (batch)
// batch_index: (1)
// seq_len_axis: (1)
// new_max_seq_len: (1)
// cur_max_seq_len: (1)
// ------------------------------
// output: (batch, num_head, max_seq_len, hidden_size)
// split strategy
// batch is able to split.
// max_seq_len, update_seq_len is not able to split.
// num_head is able to split.
// hidden_size is able to split.
// ===== Case 2:
// cache and update shape rank is 3.
// cache: (batch, max_seq_len, hidden_size)
// update: (batch, update_seq_len, hidden_size)
// valid_seq_len: (batch)
// batch_index: (1)
// seq_len_axis: (1)
// new_max_seq_len: (1)
// cur_max_seq_len: (1)
// ------------------------------
// output: (batch, max_seq_len, hidden_size)
// split strategy
// batch is able to split.
// max_seq_len, update_seq_len is not able to split.
// hidden_size is able to split.
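// Example (annotation, not part of the op code): with 8 devices, strategy
// cache = (2, 2, 1, 2), update = (2, 2, 1, 2), valid_seq_len = (2) satisfies
// every rule above, while (1, 1, 2, 4) is rejected because it shards
// max_seq_len.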
Status DecoderKVCacheInfo::CheckStrategy3Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update) {
  if (strategy_cache.at(1) != 1 || strategy_update.at(1) != 1) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len can't be sharded, but got"
                  << " cache's seq_len strategy: " << strategy_cache.at(1)
                  << "; update's seq_len strategy: " << strategy_update.at(1);
    return FAILED;
  }
  if (strategy_cache.at(2) != strategy_update.at(2)) {
    MS_LOG(ERROR) << name_ << " Invalid strategy: The hidden_size must be sharded in the same way, but got"
                  << " strategy_cache's strategy: " << strategy_cache
                  << ", strategy_update's strategy: " << strategy_update;
    return FAILED;
  }
  return SUCCESS;
}
Status DecoderKVCacheInfo::CheckStrategy4Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update) {
  // num_head must use the same shard strategy.
  if (strategy_cache.at(1) != strategy_update.at(1)) {
    MS_LOG(ERROR) << name_ << " Invalid strategy: The num_head must be sharded in the same way, but got"
                  << " strategy_cache's strategy: " << strategy_cache
                  << ", strategy_update's strategy: " << strategy_update;
    return FAILED;
  }
  if (strategy_cache.at(2) != 1 || strategy_update.at(2) != 1) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len can't be sharded, but got"
                  << " cache's seq_len strategy: " << strategy_cache.at(2)
                  << "; update's seq_len strategy: " << strategy_update.at(2);
    return FAILED;
  }
  // hidden_size must use the same shard strategy.
  if (strategy_cache.at(3) != strategy_update.at(3)) {
    MS_LOG(ERROR) << name_ << " Invalid strategy: The hidden_size must be sharded in the same way, but got"
                  << " strategy_cache's strategy: " << strategy_cache
                  << ", strategy_update's strategy: " << strategy_update;
    return FAILED;
  }
  return SUCCESS;
}
Status DecoderKVCacheInfo::SetDims(const StrategyPtr &strategy) {
auto input_strategys = strategy->GetInputDim();
auto strategy_cache = input_strategys.at(0);
const size_t input_dims4 = 4;
const size_t input_dims3 = 3;
if (strategy_cache.size() == input_dims4) {
is_input_dims_4_ = true;
} else if (strategy_cache.size() == input_dims3) {
is_input_dims_4_ = false;
} else {
return FAILED;
}
return SUCCESS;
}
Status DecoderKVCacheInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED;
}
if (SetDims(strategy) != SUCCESS) {
return FAILED;
}
auto input_strategys = strategy->GetInputDim();
auto strategy_cache = input_strategys.at(0); // (4, 4, 1, 4)
auto strategy_update = input_strategys.at(1); // (4, 4, 1, 4)
auto strategy_valid_seq_len = input_strategys.at(2); // (4)
@@ -74,69 +153,77 @@ Status DecoderKVCacheInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}
  // batch must use the same shard strategy.
  if (strategy_cache.at(0) != strategy_update.at(0) || strategy_cache.at(0) != strategy_valid_seq_len.at(0)) {
    MS_LOG(ERROR) << name_ << " Invalid strategy: The batch must be sharded in the same way, but got"
                  << " strategy_cache's strategy: " << strategy_cache
                  << ", strategy_update's strategy: " << strategy_update
                  << ", strategy_valid_seq_len's strategy: " << strategy_valid_seq_len;
    return FAILED;
  }
  if (is_input_dims_4_) {
    return CheckStrategy4Dims(strategy_cache, strategy_update);
  }
  return CheckStrategy3Dims(strategy_cache, strategy_update);
}
Status DecoderKVCacheInfo::InferDevMatrixShape() {
auto input_strategys = strategy()->GetInputDim();
auto cache = input_strategys.at(0); // batch num_head max_seq_len hidden_size
auto update = input_strategys.at(1); // batch num_head update_seq_len hidden_size
if (is_input_dims_4_) {
// cache shape (batch num_head max_seq_len hidden_size)
// update shape (batch num_head update_seq_len hidden_size)
// update_seq_len batch num_head max_seq_len hidden_size
// 4 3 2 1 0
dev_matrix_shape_ = {update.at(2), cache.at(0), cache.at(1), cache.at(2), cache.at(3)};
} else {
// cache shape (batch max_seq_len hidden_size)
// update shape (batch update_seq_len hidden_size)
// update_seq_len batch max_seq_len hidden_size
// 3 2 1 0
    dev_matrix_shape_ = {update.at(1), cache.at(0), cache.at(1), cache.at(2)};
}
return SUCCESS;
}
Status DecoderKVCacheInfo::InferTensorMap() {
  if (is_input_dims_4_) {
    Shape cache_tensor_map{3, 2, 1, 0};
    Shape update_tensor_map{3, 2, 4, 0};
    Shape valid_seq_len_tensor_map{3};
    inputs_tensor_map_.emplace_back(cache_tensor_map);
    inputs_tensor_map_.emplace_back(update_tensor_map);
    inputs_tensor_map_.emplace_back(valid_seq_len_tensor_map);
  } else {
    Shape cache_tensor_map{2, 1, 0};
    Shape update_tensor_map{2, 3, 0};
    Shape valid_seq_len_tensor_map{2};
    inputs_tensor_map_.emplace_back(cache_tensor_map);
    inputs_tensor_map_.emplace_back(update_tensor_map);
    inputs_tensor_map_.emplace_back(valid_seq_len_tensor_map);
  }
  Shape batch_index_tensor_map{-1};
  Shape seq_len_axis_tensor_map{-1};
  Shape new_max_seq_len_tensor_map{-1};
  Shape cur_max_seq_len_tensor_map{-1};
  inputs_tensor_map_.emplace_back(batch_index_tensor_map);
  inputs_tensor_map_.emplace_back(seq_len_axis_tensor_map);
  inputs_tensor_map_.emplace_back(new_max_seq_len_tensor_map);
  inputs_tensor_map_.emplace_back(cur_max_seq_len_tensor_map);
  if (is_input_dims_4_) {
    Shape out_tensor_map{3, 2, 1, 0};
    outputs_tensor_map_.emplace_back(out_tensor_map);
  } else {
    Shape out_tensor_map{2, 1, 0};
    outputs_tensor_map_.emplace_back(out_tensor_map);
  }
  return SUCCESS;
}
REGISTER(DecoderKVCacheInfo);
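
To make the dev-matrix wiring above concrete, here is a hedged, standalone sketch (plain C++, no MindSpore dependencies; all names are illustrative) of the 4D strategy rules and the device matrix they produce:

#include <cassert>
#include <vector>

// Hedged sketch of the 4D DecoderKVCache rules above; illustrative only.
// Strategies are per-dimension shard counts: (batch, num_head, max_seq_len, hidden_size).
bool CheckStrategy4DimsSketch(const std::vector<int> &cache, const std::vector<int> &update) {
  if (cache[1] != update[1]) return false;            // num_head: same shard count
  if (cache[2] != 1 || update[2] != 1) return false;  // seq_len: never sharded
  return cache[3] == update[3];                       // hidden_size: same shard count
}

int main() {
  std::vector<int> cache{2, 2, 1, 2}, update{2, 2, 1, 2};
  assert(CheckStrategy4DimsSketch(cache, update));
  // dev matrix = {update_seq_len, batch, num_head, max_seq_len, hidden_size},
  // numbered 4..0 from the left; the tensor maps index it from the right.
  std::vector<int> dev_matrix{update[2], cache[0], cache[1], cache[2], cache[3]};
  assert((dev_matrix == std::vector<int>{1, 2, 2, 1, 2}));
  return 0;
}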

View File

@@ -44,6 +44,12 @@ class DecoderKVCacheInfo : public OperatorInfo {
Status InferForwardCommunication() { return SUCCESS; }
Status InferTensorMap() override;
Status InferDevMatrixShape() override;
Status SetDims(const StrategyPtr &strategy);
Status CheckStrategy3Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update);
Status CheckStrategy4Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update);
private:
bool is_input_dims_4_ = true;
};
using DecoderKVCacheInfoPtr = std::shared_ptr<DecoderKVCacheInfo>;
} // namespace parallel

View File

@@ -20,29 +20,107 @@
namespace mindspore {
namespace parallel {
// PromptKVCache seven inputs
// ===== Case 1:
// cache and update shape rank is 4.
// cache: (batch, num_head, max_seq_len, hidden_size)
// update: (batch, num_head, update_seq_len, hidden_size)
// valid_seq_len: (batch)
// batch_index: (1)
// seq_len_axis: (1)
// new_max_seq_len: (1)
// cur_max_seq_len: (1)
// ------------------------------
// output: (batch, num_head, max_seq_len, hidden_size)
// split strategy
// batch is able to split.
// max_seq_len, update_seq_len is not able to split.
// num_head is able to split.
// hidden_size is able to split.
// ====== Case 2:
// cache and update shape rank is 3.
// cache: (batch, max_seq_len, hidden_size)
// update: (batch, update_seq_len, hidden_size)
// valid_seq_len: (batch)
// batch_index: (1)
// seq_len_axis: (1)
// new_max_seq_len: (1)
// cur_max_seq_len: (1)
// ------------------------------
// output: (batch, max_seq_len, hidden_size)
// split strategy
// batch is able to split.
// max_seq_len, update_seq_len is not able to split.
// hidden_size is able to split.
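// Example (annotation, not part of the op code): with update holding 2
// batches, batch_index (2, 5) writes update[0] into cache batch 2 and
// update[1] into cache batch 5, each starting at its valid_seq_len offset;
// a batch_index of -1 skips that row.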
Status PromptKVCacheInfo::CheckStrategy3Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update) {
  if (strategy_cache.at(1) != 1 || strategy_update.at(1) != 1) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len can't be sharded, but got"
                  << " cache's seq_len strategy: " << strategy_cache.at(1)
                  << "; update's seq_len strategy: " << strategy_update.at(1);
    return FAILED;
  }
  if (strategy_cache.at(2) != strategy_update.at(2)) {
    MS_LOG(ERROR) << name_ << " Invalid strategy: The hidden_size must be sharded in the same way, but got"
                  << " strategy_cache's strategy: " << strategy_cache
                  << ", strategy_update's strategy: " << strategy_update;
    return FAILED;
  }
  return SUCCESS;
}
Status PromptKVCacheInfo::CheckStrategy4Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update) {
  // num_head must use the same shard strategy.
  if (strategy_cache.at(1) != strategy_update.at(1)) {
    MS_LOG(ERROR) << name_ << " Invalid strategy: The num_head must be sharded in the same way, but got"
                  << " strategy_cache's strategy: " << strategy_cache
                  << ", strategy_update's strategy: " << strategy_update;
    return FAILED;
  }
  if (strategy_cache.at(2) != 1 || strategy_update.at(2) != 1) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy: The seq_len can't be sharded, but got"
                  << " cache's seq_len strategy: " << strategy_cache.at(2)
                  << "; update's seq_len strategy: " << strategy_update.at(2);
    return FAILED;
  }
  // hidden_size must use the same shard strategy.
  if (strategy_cache.at(3) != strategy_update.at(3)) {
    MS_LOG(ERROR) << name_ << " Invalid strategy: The hidden_size must be sharded in the same way, but got"
                  << " strategy_cache's strategy: " << strategy_cache
                  << ", strategy_update's strategy: " << strategy_update;
    return FAILED;
  }
  return SUCCESS;
}
Status PromptKVCacheInfo::SetDims(const StrategyPtr &strategy) {
auto input_strategys = strategy->GetInputDim();
auto strategy_cache = input_strategys.at(0);
const size_t input_dims4 = 4;
const size_t input_dims3 = 3;
if (strategy_cache.size() == input_dims4) {
is_input_dims_4_ = true;
} else if (strategy_cache.size() == input_dims3) {
is_input_dims_4_ = false;
} else {
return FAILED;
}
return SUCCESS;
}
Status PromptKVCacheInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED;
}
if (SetDims(strategy) != SUCCESS) {
return FAILED;
}
auto input_strategys = strategy->GetInputDim();
auto strategy_cache = input_strategys.at(0); // (4, 4, 1, 4)
auto strategy_update = input_strategys.at(1); // (4, 4, 1, 4)
auto strategy_valid_seq_len = input_strategys.at(2); // (4)
@@ -75,69 +153,78 @@ Status PromptKVCacheInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}
  // batch must use the same shard strategy.
  if (strategy_cache.at(0) != strategy_update.at(0) || strategy_cache.at(0) != strategy_valid_seq_len.at(0)) {
    MS_LOG(ERROR) << name_ << " Invalid strategy: The batch must be sharded in the same way, but got"
                  << " strategy_cache's strategy: " << strategy_cache
                  << ", strategy_update's strategy: " << strategy_update
                  << ", strategy_valid_seq_len's strategy: " << strategy_valid_seq_len;
    return FAILED;
  }
  if (is_input_dims_4_) {
    return CheckStrategy4Dims(strategy_cache, strategy_update);
  }
  return CheckStrategy3Dims(strategy_cache, strategy_update);
}
Status PromptKVCacheInfo::InferDevMatrixShape() {
auto input_strategys = strategy()->GetInputDim();
auto cache = input_strategys.at(0);
auto update = input_strategys.at(1);
if (is_input_dims_4_) {
// cache shape (batch num_head max_seq_len hidden_size)
// update shape (batch num_head update_seq_len hidden_size)
// update_seq_len batch num_head max_seq_len hidden_size
// 4 3 2 1 0
dev_matrix_shape_ = {update.at(2), cache.at(0), cache.at(1), cache.at(2), cache.at(3)};
} else {
// cache shape (batch max_seq_len hidden_size)
// update shape (batch update_seq_len hidden_size)
// update_seq_len batch max_seq_len hidden_size
// 3 2 1 0
    dev_matrix_shape_ = {update.at(1), cache.at(0), cache.at(1), cache.at(2)};
}
return SUCCESS;
}
Status PromptKVCacheInfo::InferTensorMap() {
  if (is_input_dims_4_) {
    Shape cache_tensor_map{3, 2, 1, 0};
    Shape update_tensor_map{3, 2, 4, 0};
    Shape valid_seq_len_tensor_map{3};
    inputs_tensor_map_.emplace_back(cache_tensor_map);
    inputs_tensor_map_.emplace_back(update_tensor_map);
    inputs_tensor_map_.emplace_back(valid_seq_len_tensor_map);
  } else {
    Shape cache_tensor_map{2, 1, 0};
    Shape update_tensor_map{2, 3, 0};
    Shape valid_seq_len_tensor_map{2};
    inputs_tensor_map_.emplace_back(cache_tensor_map);
    inputs_tensor_map_.emplace_back(update_tensor_map);
    inputs_tensor_map_.emplace_back(valid_seq_len_tensor_map);
  }
  Shape batch_index_tensor_map{-1};
  Shape seq_len_axis_tensor_map{-1};
  Shape new_max_seq_len_tensor_map{-1};
  Shape cur_max_seq_len_tensor_map{-1};
  inputs_tensor_map_.emplace_back(batch_index_tensor_map);
  inputs_tensor_map_.emplace_back(seq_len_axis_tensor_map);
  inputs_tensor_map_.emplace_back(new_max_seq_len_tensor_map);
  inputs_tensor_map_.emplace_back(cur_max_seq_len_tensor_map);
  if (is_input_dims_4_) {
    Shape out_tensor_map{3, 2, 1, 0};
    outputs_tensor_map_.emplace_back(out_tensor_map);
  } else {
    Shape out_tensor_map{2, 1, 0};
    outputs_tensor_map_.emplace_back(out_tensor_map);
  }
  return SUCCESS;
}
REGISTER(PromptKVCacheInfo);
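
For reference, a hedged C++ sketch (illustrative names, no MindSpore or AscendC dependencies) of the BSD-format cache update these shards must preserve:

#include <cassert>
#include <cstdint>
#include <vector>

// Hedged reference for the 3D (BSD) PromptKVCache semantics: update row i is
// copied into cache batch batch_index[i] at offset valid_seq_len[i] along the
// sequence axis; negative or out-of-range entries are skipped.
void PromptKvCacheRef(std::vector<float> *cache, const std::vector<float> &update,
                      const std::vector<int64_t> &valid_seq_len,
                      const std::vector<int64_t> &batch_index,
                      int64_t s, int64_t us, int64_t hd) {
  for (size_t i = 0; i < batch_index.size(); ++i) {
    const int64_t b_idx = batch_index[i], s_idx = valid_seq_len[i];
    if (b_idx < 0 || s_idx < 0 || s_idx + us > s) continue;
    for (int64_t j = 0; j < us * hd; ++j) {
      (*cache)[(b_idx * s + s_idx) * hd + j] = update[i * us * hd + j];
    }
  }
}

int main() {
  const int64_t s = 8, us = 2, hd = 4;
  std::vector<float> cache(2 * s * hd, 0.0f), update(us * hd, 1.0f);
  PromptKvCacheRef(&cache, update, {3}, {1}, s, us, hd);
  assert(cache[(1 * s + 3) * hd] == 1.0f);  // written slice starts at batch 1, seq 3
  return 0;
}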

View File

@@ -45,6 +45,12 @@ class PromptKVCacheInfo : public OperatorInfo {
Status InferForwardCommunication() { return SUCCESS; }
Status InferTensorMap() override;
Status InferDevMatrixShape() override;
Status SetDims(const StrategyPtr &strategy);
Status CheckStrategy3Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update);
Status CheckStrategy4Dims(const Dimensions &strategy_cache, const Dimensions &strategy_update);
private:
bool is_input_dims_4_ = true;
};
using PromptKVCacheInfoPtr = std::shared_ptr<PromptKVCacheInfo>;
} // namespace parallel

View File

@@ -613,7 +613,8 @@ bool DfGraphConvertor::NodeInputKeepUpdate(const FuncGraphManagerPtr &manager, c
return true;
}
const auto &node_users = manager->node_users();
std::vector<PrimitivePtr> vec{prim::kPrimAssign, prim::kPrimKVCacheMgr, prim::kPrimScatterUpdate};
std::vector<PrimitivePtr> vec{prim::kPrimAssign, prim::kPrimKVCacheMgr, prim::kPrimScatterUpdate,
prim::kPrimPromptKVCache, prim::kPrimDecoderKVCache};
auto user_it = node_users.find(node);
if (user_it != node_users.end()) {
auto &users = user_it->second;

View File

@@ -22,36 +22,51 @@ import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, context
from mindspore.train.serialization import export
from mindspore.ops.auto_generate.gen_inner_ops_def import DecoderKVCache
b = 26
h = 40
s = 32
d = 128
us = 1
ps = s - us
is_4d = True
class DecoderKVCacheNet(nn.Cell):
    """
    DecoderKVCacheNet.
    """

    def __init__(self):
        super().__init__()
        self.add = ops.Add()
        self.sub = ops.Sub()
        self.decoder_k_v_cache = DecoderKVCache()

    def construct(self, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
        out = self.decoder_k_v_cache(cache, update, valid_seq_len, batch_index, seq_len_axis,
                                     new_max_seq_len, cur_max_seq_len)
        add_out = self.add(cache, 1)
        sub_out = self.sub(add_out, 1)
        return sub_out
def np_inference(cache, update, valid_seq_len):
    """
    np_inference
    """
    ans = cache.copy()
    for b_idx in range(cache.shape[0]):
        s_idx = valid_seq_len[b_idx]
        if s_idx < 0:
            continue
        if is_4d:
            ans[b_idx, :, s_idx, :] = update[b_idx, :, 0, :]
        else:
            ans[b_idx, s_idx, :] = update[b_idx, 0, :]
    return ans
@@ -59,13 +74,20 @@ def create_numpy_inputs():
"""
create inputs
"""
global is_4d
if is_4d:
cache_shape = (b, h, s, d)
update_shape = (b, h, us, d)
else:
cache_shape = (b, s, h * d)
update_shape = (b, us, h * d)
cache = np.random.rand(*cache_shape).astype(np.float16)
update = np.random.rand(*update_shape).astype(np.float16)
valid_seq_len = np.random.randint(-1, s, size=b).astype(np.int64)
batch_index = np.array([1]).astype(np.int64)
seq_len_axis = np.array([2]).astype(np.int64)
new_max_seq_len = np.array([s]).astype(np.int64)
cur_max_seq_len = np.array([s]).astype(np.int64)
return (cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
@@ -85,13 +107,6 @@ def create_ms_inputs():
ms_new_max_seq_len, ms_cur_max_seq_len)
def create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
"""
create_lite_inputs
@@ -110,23 +125,26 @@ def inference_decoder_k_v_cache(mindir_model):
"""
inference model
"""
    lite_ctx1 = mslite.Context()
    lite_ctx1.target = ["ascend"]
    lite_ctx1.ascend.device_id = 0
    lite_ctx1.ascend.provider = "ge"

    model = mslite.Model()
    model.build_from_file(
        mindir_model, mslite.ModelType.MINDIR, lite_ctx1, "", {})

    for i in range(50):
        cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, \
            cur_max_seq_len = create_numpy_inputs()
        input_lists = list(create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis,
                                              new_max_seq_len, cur_max_seq_len))
        mslite_output = model.predict(input_lists)
        expect_output = np_inference(cache, update, valid_seq_len)
        assert np.allclose(
            mslite_output[0].get_data_to_numpy(), expect_output, 0.001, 0.001)
        print(f"decoder_k_v_cache st {i} times: inference success.")
def export_decoder_k_v_cache_model():
@@ -137,7 +155,10 @@ def export_decoder_k_v_cache_model():
cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_ms_inputs()
net = DecoderKVCacheNet()
file_name = "decoder_k_v_cache_primitive"
if is_4d:
file_name = "decoder_k_v_cache_primitive_4d"
else:
file_name = "decoder_k_v_cache_primitive_3d"
export(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len,
file_name=file_name, file_format='MINDIR')
@@ -146,9 +167,25 @@ def export_decoder_k_v_cache_model():
return model_name
def test_4d():
    global is_4d
    is_4d = True
    model_path = export_decoder_k_v_cache_model()
    print("decoder_k_v_cache 4d st : export success path: ", model_path)
    inference_decoder_k_v_cache(model_path)
    print("decoder_k_v_cache 4d st : inference end.")


def test_3d():
    global is_4d
    is_4d = False
    model_path = export_decoder_k_v_cache_model()
    print("decoder_k_v_cache 3d st : export success path: ", model_path)
    inference_decoder_k_v_cache(model_path)
    print("decoder_k_v_cache 3d st : inference end.")


if __name__ == '__main__':
    test_3d()
    test_4d()

View File

@@ -20,41 +20,65 @@ import numpy as np
import mindspore_lite as mslite
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, context
from mindspore.train.serialization import export
import mindspore.common.dtype as mstype
from mindspore.ops.auto_generate.gen_inner_ops_def import PromptKVCache
b = 40
h = 4
s = 1024
d = 32
ub = 40
us = 512
ps = s - us
is_4d = True
class PromptKVCacheNet(nn.Cell):
"""
PromptKVCacheNet.
"""
    def __init__(self, padding_mode):
        super().__init__()
        self.sub = ops.Sub()
        self.add = ops.Add()
        self.prompt_k_v_cache = PromptKVCache(padding_mode)

    def construct(self, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
        out = self.prompt_k_v_cache(cache, update, valid_seq_len, batch_index, seq_len_axis,
                                    new_max_seq_len, cur_max_seq_len)
        add_out = self.add(cache, 1)
        sub_out = self.sub(add_out, 1)
        return sub_out
def np_inference(cache, update, valid_seq_len, batch_index):
    """
    np_inference
    """
    cache_rank = len(cache.shape)
    if cache_rank == 4:
        s_ = cache.shape[2]
        us_ = update.shape[2]
    else:  # cache_rank == 3
        s_ = cache.shape[1]
        us_ = update.shape[1]
    for i in range(batch_index.size):
        b_idx = batch_index[i]
        s_idx = valid_seq_len[i]
        if b_idx < 0:
            continue
        if s_idx < 0 or s_idx + us_ > s_:
            continue
        if cache_rank == 4:
            cache[b_idx, :, s_idx:s_idx + us_, :] = update[i]
        else:
            cache[b_idx, s_idx:s_idx + us_, :] = update[i]
    return cache
@@ -62,14 +86,19 @@ def create_numpy_inputs():
"""
create inputs
"""
if is_4d:
cache_shape = (b, h, s, d)
update_shape = (ub, h, us, d)
else:
cache_shape = (b, s, h * d)
update_shape = (ub, us, h * d)
cache = np.random.rand(*cache_shape).astype(np.float16)
update = np.random.rand(*update_shape).astype(np.float16)
valid_seq_len = np.random.randint(-1, s, size=ub).astype(np.int64)
batch_index = np.random.choice(np.arange(-1, b), size=ub, replace=False).astype(np.int64)
seq_len_axis = np.array([2]).astype(np.int64)
new_max_seq_len = np.array([s]).astype(np.int64)
cur_max_seq_len = np.array([s]).astype(np.int64)
return (cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
@@ -89,13 +118,6 @@ def create_ms_inputs():
ms_new_max_seq_len, ms_cur_max_seq_len)
def create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len):
"""
create_lite_inputs
@@ -114,23 +136,26 @@ def inference_prompt_k_v_cache(mindir_model):
"""
inference model
"""
    lite_ctx1 = mslite.Context()
    lite_ctx1.target = ["ascend"]
    lite_ctx1.ascend.device_id = 0
    lite_ctx1.ascend.provider = "ge"

    model = mslite.Model()
    model.build_from_file(
        mindir_model, mslite.ModelType.MINDIR, lite_ctx1, "", {})

    for i in range(50):
        cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, \
            cur_max_seq_len = create_numpy_inputs()
        input_lists = list(create_lite_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis,
                                              new_max_seq_len, cur_max_seq_len))
        mslite_output = model.predict(input_lists)
        expect_output = np_inference(cache, update, valid_seq_len, batch_index)
        assert np.allclose(
            mslite_output[0].get_data_to_numpy(), expect_output, 0.001, 0.001)
        print(f"prompt_k_v_cache st {i} times: inference success.")
def export_prompt_k_v_cache_model():
@@ -141,7 +166,10 @@ def export_prompt_k_v_cache_model():
cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len = create_ms_inputs()
net = PromptKVCacheNet("right")
file_name = "prompt_k_v_cache_primitive"
if is_4d:
file_name = "prompt_k_v_cache_primitive_4d"
else:
file_name = "prompt_k_v_cache_primitive_3d"
export(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len,
file_name=file_name, file_format='MINDIR')
@@ -150,9 +178,24 @@ def export_prompt_k_v_cache_model():
return model_name
def test_4d():
    global is_4d
    is_4d = True
    model_path = export_prompt_k_v_cache_model()
    print("prompt_k_v_cache 4d st : export success path: ", model_path)
    inference_prompt_k_v_cache(model_path)
    print("prompt_k_v_cache 4d st : inference end.")


def test_3d():
    global is_4d
    is_4d = False
    model_path = export_prompt_k_v_cache_model()
    print("prompt_k_v_cache 3d st : export success path: ", model_path)
    inference_prompt_k_v_cache(model_path)
    print("prompt_k_v_cache 3d st : inference end.")


if __name__ == '__main__':
    test_3d()
    test_4d()

View File

@@ -15,6 +15,7 @@ include_directories(${TOP_DIR}/graphengine/910b/inc)
include_directories(${TOP_DIR}/graphengine/910b/inc/external)
include_directories(${TOP_DIR}/graphengine/910b/third_party/fwkacllib)
include_directories(${TOP_DIR}/graphengine/910b/third_party/fwkacllib/inc)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} ops_srcs)
opbuild(OPS_SRC ${ops_srcs} OUT_DIR ${ASCEND_AUTOGEN_PATH} INC_DIR ${BUILD_INC_910A_DIR}

View File

@@ -0,0 +1,235 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "decoder_kv_cache_tiling.h"
#include "register/op_def_registry.h"
#include "platform_ascendc.h"
namespace {
constexpr int index0 = 0;
constexpr int index1 = 1;
constexpr int index2 = 2;
constexpr int index3 = 3;
constexpr int index4 = 4;
constexpr int index5 = 5;
constexpr int index6 = 6;
constexpr int32_t kSize1 = 1;
constexpr int32_t kSize2 = 2;
constexpr int32_t kSize4 = 4;
constexpr int64_t kAxisOne = 1;
constexpr int64_t kAxisTwo = 2;
constexpr size_t k910bWS = 16 * 1024 * 1024;
} // namespace
namespace optiling {
static ge::graphStatus TilingFunc(gert::TilingContext *context) {
auto past_input = context->GetInputDesc(index0);
auto dtype = past_input->GetDataType();
auto type_size = ge::GetSizeByDataType(dtype);
switch (type_size) {
case kSize1:
context->SetTilingKey(kSize1);
break;
case kSize2:
context->SetTilingKey(kSize2);
break;
case kSize4:
context->SetTilingKey(kSize4);
break;
default:
return ge::GRAPH_PARAM_INVALID;
}
const gert::StorageShape *cur_shape = context->GetInputShape(index1);
bool is_dim4 = true;
const size_t kDim3 = 3;
const size_t kDim4 = 4;
if (cur_shape->GetStorageShape().GetDimNum() == kDim4) {
is_dim4 = true;
} else if (cur_shape->GetStorageShape().GetDimNum() == kDim3) {
is_dim4 = false;
} else {
return ge::GRAPH_PARAM_INVALID;
}
int64_t b = cur_shape->GetStorageShape().GetDim(index0);
int64_t h = 0;
int64_t us = 0;
  // s is only known at run time (the kernel reads it from new_max_seq_len)
int64_t d = 0;
if (is_dim4) {
    // (b, h, us, d) -> (b*h, us, d)
h = cur_shape->GetStorageShape().GetDim(index1);
us = cur_shape->GetStorageShape().GetDim(index2);
d = cur_shape->GetStorageShape().GetDim(index3);
} else {
h = 1;
us = cur_shape->GetStorageShape().GetDim(index1);
d = cur_shape->GetStorageShape().GetDim(index2);
}
auto platform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
auto aiv_num = platform.GetCoreNumAiv();
context->SetBlockDim(aiv_num);
// set workspace for 910B
size_t *currentWorkspace = context->GetWorkspaceSizes(1);
currentWorkspace[0] = k910bWS;
TilingData tiling;
tiling.set_core_num(aiv_num);
tiling.set_b(b);
tiling.set_h(h);
tiling.set_d(d);
tiling.set_us(us);
tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus CheckSupported(const ge::Operator &op, ge::AscendString &result) {
std::string resultStr{};
constexpr size_t input_num = 7;
if (op.GetInputsSize() != input_num) {
resultStr = R"({"ret_code": "1", "reason": "input num is not 7"})";
result = ge::AscendString(resultStr.c_str());
return ge::GRAPH_FAILED;
}
if (op.GetOutputsSize() != 1) {
resultStr = R"({"ret_code": "1", "reason": "output num is not 1"})";
result = ge::AscendString(resultStr.c_str());
return ge::GRAPH_FAILED;
}
for (size_t i = 0; i < input_num; ++i) {
if (op.GetInputDesc(i).GetFormat() != ge::FORMAT_ND) {
resultStr = R"({"ret_code": "1", "reason": "input format is not supported, only support ND."})";
result = ge::AscendString(resultStr.c_str());
return ge::GRAPH_FAILED;
}
}
const int64_t input_dim3_num = 3;
const int64_t input_dim4_num = 4;
if (op.GetInputDesc(index0).GetShape().GetDimNum() != input_dim3_num &&
op.GetInputDesc(index0).GetShape().GetDimNum() != input_dim4_num) {
resultStr = R"({"ret_code": "1", "reason": "input dim is not supported, cache and update dim must be 4."})";
result = ge::AscendString(resultStr.c_str());
return ge::GRAPH_FAILED;
}
if (op.GetInputDesc(index2).GetShape().GetDimNum() != 1 || op.GetInputDesc(index3).GetShape().GetDimNum() != 1 ||
op.GetInputDesc(index4).GetShape().GetDimNum() != 1 || op.GetInputDesc(index5).GetShape().GetDimNum() != 1 ||
op.GetInputDesc(index6).GetShape().GetDimNum() != 1) {
    resultStr =
      R"({"ret_code": "1", "reason": "input dim is not supported, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len and cur_max_seq_len dim must be 1."})";
result = ge::AscendString(resultStr.c_str());
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
} // namespace optiling
namespace ge {
static ge::graphStatus InferShape(gert::InferShapeContext *context) {
const gert::Shape *x1_shape = context->GetInputShape(0);
gert::Shape *y_shape = context->GetOutputShape(0);
*y_shape = *x1_shape;
return GRAPH_SUCCESS;
}
ge::graphStatus InferDecoderKvCacheDataType(gert::InferDataTypeContext *context) {
const ge::DataType datatype = context->GetInputDataType(0);
context->SetOutputDataType(0, datatype);
return GRAPH_SUCCESS;
}
} // namespace ge
namespace ops {
class DecoderKvCache : public OpDef {
public:
explicit DecoderKvCache(const char *name) : OpDef(name) {
this->Input("cache")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_UINT8, ge::DT_UINT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("update")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_UINT8, ge::DT_UINT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("valid_seq_len")
.ParamType(REQUIRED)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("batch_index")
.ParamType(REQUIRED)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("seq_len_axis")
.ParamType(REQUIRED)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("new_max_seq_len")
.ParamType(REQUIRED)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("cur_max_seq_len")
.ParamType(REQUIRED)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("out")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_UINT8, ge::DT_UINT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->SetInferShape(ge::InferShape);
this->SetInferDataType(ge::InferDecoderKvCacheDataType);
this->AICore()
.SetTiling(optiling::TilingFunc)
.AddConfig("ascend910")
.AddConfig("ascend910b")
.SetCheckSupport(optiling::CheckSupported);
}
};
OP_ADD(DecoderKvCache);
} // namespace ops
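
The tiling key above is simply the element byte width, so every dtype in the same width class shares one kernel instantiation. A hedged sketch (plain C++, illustrative names) of that dispatch rule:

#include <cassert>

// Hedged sketch of the TilingFunc dtype dispatch above: int8/uint8 share
// key 1, float16/int16/uint16 share key 2, float/int32 share key 4.
int TilingKeyForTypeSize(int type_size) {
  switch (type_size) {
    case 1:
    case 2:
    case 4:
      return type_size;
    default:
      return -1;  // corresponds to ge::GRAPH_PARAM_INVALID
  }
}

int main() {
  assert(TilingKeyForTypeSize(2) == 2);   // e.g. float16
  assert(TilingKeyForTypeSize(8) == -1);  // e.g. an int64 cache is rejected
  return 0;
}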

View File

@@ -0,0 +1,31 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DECODER_KV_CACHE_TILING_H
#define DECODER_KV_CACHE_TILING_H
#include "register/tilingdata_base.h"
namespace optiling {
BEGIN_TILING_DATA_DEF(TilingData)
TILING_DATA_FIELD_DEF(int64_t, core_num);
TILING_DATA_FIELD_DEF(int64_t, b);
TILING_DATA_FIELD_DEF(int64_t, h);
TILING_DATA_FIELD_DEF(int64_t, d);
TILING_DATA_FIELD_DEF(int64_t, us);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(DecoderKvCache, TilingData)
} // namespace optiling
#endif // DECODER_KV_CACHE_TILING_H

View File

@@ -0,0 +1,77 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PLATFORM_ASCENDC_H
#define PLATFORM_ASCENDC_H
#include <cstdint>
namespace fe {
class PlatFormInfos;
}
namespace platform_ascendc {
enum class CoreMemType { L0_A = 0, L0_B = 1, L0_C = 2, L1 = 3, L2 = 4, UB = 5, HBM = 6, RESERVED };
enum class SocVersion { ASCEND910 = 0, ASCEND910B, ASCEND310P, RESERVED_VERSION = 99999 };
class PlatformAscendC {
public:
PlatformAscendC() = delete;
~PlatformAscendC() {}
explicit PlatformAscendC(fe::PlatFormInfos *platformInfo) : platformInfo_(platformInfo) {}
/**
* Get Core Number
* On Ascend910B MIX model, return AICore number
* @return core number by core type
*/
uint32_t GetCoreNum(void) const;
/**
* Get Core Number AiCore
* @return ai_core_num
*/
uint32_t GetCoreNumAic(void) const;
/**
* Get Core Number VectorCore
* @return vector_core_num
*/
uint32_t GetCoreNumAiv(void) const;
/**
* Calc task schedule block dim
* @sliceNum number of slices the data is divided into
* @aicCoreNum value of GetCoreNumAic() if cube APIs are used, otherwise 0
* @aivCoreNum value of GetCoreNumAiv() if vector APIs are used, otherwise 0
* @return task schedule block dim
*/
uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum) const;
/**
* Get Work Space Size
* @return work space size by chip type
*/
uint32_t GetLibApiWorkSpaceSize(void) const;
void GetCoreMemSize(const CoreMemType &memType, uint64_t &size) const;
void GetCoreMemBw(const CoreMemType &memType, uint64_t &bwSize) const;
/**
* Get Soc Version Enum
* @return Enum SocVersion
*/
SocVersion GetSocVersion(void) const;
private:
fe::PlatFormInfos *platformInfo_;
};
} // namespace platform_ascendc
#endif

View File

@@ -0,0 +1,247 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "prompt_kv_cache_tiling.h"
#include "register/op_def_registry.h"
#include "platform_ascendc.h"
namespace {
constexpr int index0 = 0;
constexpr int index1 = 1;
constexpr int index2 = 2;
constexpr int index3 = 3;
constexpr int index4 = 4;
constexpr int index5 = 5;
constexpr int index6 = 6;
constexpr int32_t kSize1 = 1;
constexpr int32_t kSize2 = 2;
constexpr int32_t kSize4 = 4;
constexpr int64_t kAxisOne = 1;
constexpr int64_t kAxisTwo = 2;
constexpr size_t k910bWS = 16 * 1024 * 1024;
} // namespace
namespace optiling {
static ge::graphStatus TilingFunc(gert::TilingContext *context) {
auto past_input = context->GetInputDesc(index0);
auto dtype = past_input->GetDataType();
auto type_size = ge::GetSizeByDataType(dtype);
switch (type_size) {
case kSize1:
context->SetTilingKey(kSize1);
break;
case kSize2:
context->SetTilingKey(kSize2);
break;
case kSize4:
context->SetTilingKey(kSize4);
break;
default:
return ge::GRAPH_PARAM_INVALID;
}
const gert::StorageShape *cache_shape = context->GetInputShape(index0);
const gert::StorageShape *update_shape = context->GetInputShape(index1);
bool is_dim4 = true;
const size_t kDim3 = 3;
const size_t kDim4 = 4;
if (cache_shape->GetStorageShape().GetDimNum() == kDim4) {
is_dim4 = true;
} else if (cache_shape->GetStorageShape().GetDimNum() == kDim3) {
is_dim4 = false;
} else {
return ge::GRAPH_PARAM_INVALID;
}
int64_t b = cache_shape->GetStorageShape().GetDim(index0);
int64_t ub = update_shape->GetStorageShape().GetDim(index0);
int64_t h = 0;
int64_t s = 0;
int64_t us = 0;
int64_t d = 0;
if (is_dim4) {
h = update_shape->GetStorageShape().GetDim(index1);
s = cache_shape->GetStorageShape().GetDim(index2);
us = update_shape->GetStorageShape().GetDim(index2);
d = update_shape->GetStorageShape().GetDim(index3);
} else {
h = 1;
s = cache_shape->GetStorageShape().GetDim(index1);
us = update_shape->GetStorageShape().GetDim(index1);
d = update_shape->GetStorageShape().GetDim(index2);
}
auto platform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
auto aiv_num = platform.GetCoreNumAiv();
context->SetBlockDim(aiv_num);
// set workspace for 910B
size_t *currentWorkspace = context->GetWorkspaceSizes(1);
currentWorkspace[0] = k910bWS;
TilingData tiling;
tiling.set_core_num(aiv_num);
tiling.set_b(b);
tiling.set_h(h);
tiling.set_s(s);
tiling.set_d(d);
tiling.set_ub(ub);
tiling.set_us(us);
tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus CheckSupported(const ge::Operator &op, ge::AscendString &result) {
std::string resultStr{};
constexpr size_t input_num = 7;
if (op.GetInputsSize() != input_num) {
resultStr = R"({"ret_code": "1", "reason": "input num is not 7"})";
result = ge::AscendString(resultStr.c_str());
return ge::GRAPH_FAILED;
}
if (op.GetOutputsSize() != 1) {
resultStr = R"({"ret_code": "1", "reason": "output num is not 1"})";
result = ge::AscendString(resultStr.c_str());
return ge::GRAPH_FAILED;
}
for (size_t i = 0; i < input_num; ++i) {
if (op.GetInputDesc(i).GetFormat() != ge::FORMAT_ND) {
resultStr = R"({"ret_code": "1", "reason": "input format is not supported, only support ND."})";
result = ge::AscendString(resultStr.c_str());
return ge::GRAPH_FAILED;
}
}
const int64_t input_dim3_num = 3;
const int64_t input_dim4_num = 4;
if (op.GetInputDesc(index0).GetShape().GetDimNum() != input_dim3_num &&
op.GetInputDesc(index0).GetShape().GetDimNum() != input_dim4_num) {
resultStr = R"({"ret_code": "1", "reason": "input dim is not supported, cache and update dim must be 4."})";
result = ge::AscendString(resultStr.c_str());
return ge::GRAPH_FAILED;
}
if (op.GetInputDesc(index1).GetShape().GetDimNum() != input_dim3_num &&
op.GetInputDesc(index1).GetShape().GetDimNum() != input_dim4_num) {
resultStr = R"({"ret_code": "1", "reason": "input dim is not supported, cache and update dim must be 4."})";
result = ge::AscendString(resultStr.c_str());
return ge::GRAPH_FAILED;
}
if (op.GetInputDesc(index2).GetShape().GetDimNum() != 1 || op.GetInputDesc(index3).GetShape().GetDimNum() != 1 ||
op.GetInputDesc(index4).GetShape().GetDimNum() != 1 || op.GetInputDesc(index5).GetShape().GetDimNum() != 1 ||
op.GetInputDesc(index6).GetShape().GetDimNum() != 1) {
    resultStr =
      R"({"ret_code": "1", "reason": "input dim is not supported, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len and cur_max_seq_len dim must be 1."})";
result = ge::AscendString(resultStr.c_str());
return ge::GRAPH_FAILED;
}
return ge::GRAPH_SUCCESS;
}
} // namespace optiling
namespace ge {
static ge::graphStatus InferShape(gert::InferShapeContext *context) {
const gert::Shape *x1_shape = context->GetInputShape(0);
gert::Shape *y_shape = context->GetOutputShape(0);
*y_shape = *x1_shape;
return GRAPH_SUCCESS;
}
ge::graphStatus InferPromptKvCacheDataType(gert::InferDataTypeContext *context) {
const ge::DataType datatype = context->GetInputDataType(0);
context->SetOutputDataType(0, datatype);
return GRAPH_SUCCESS;
}
} // namespace ge
namespace ops {
class PromptKvCache : public OpDef {
public:
explicit PromptKvCache(const char *name) : OpDef(name) {
this->Input("cache")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_UINT8, ge::DT_UINT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("update")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_UINT8, ge::DT_UINT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("valid_seq_len")
.ParamType(REQUIRED)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("batch_index")
.ParamType(REQUIRED)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("seq_len_axis")
.ParamType(REQUIRED)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("new_max_seq_len")
.ParamType(REQUIRED)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("cur_max_seq_len")
.ParamType(REQUIRED)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("out")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_UINT8, ge::DT_UINT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat(
{ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->SetInferShape(ge::InferShape);
this->SetInferDataType(ge::InferPromptKvCacheDataType);
this->AICore()
.SetTiling(optiling::TilingFunc)
.AddConfig("ascend910")
.AddConfig("ascend910b")
.SetCheckSupport(optiling::CheckSupported);
}
};
OP_ADD(PromptKvCache);
} // namespace ops
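
A hedged sketch (plain C++, illustrative names) of how the prompt TilingFunc above folds the 3D BSD layout into the 4D BNSD tiling fields by treating num_head as 1:

#include <cassert>
#include <cstdint>
#include <vector>

struct UpdateFields { int64_t ub, h, us, d; };

// Rank 4: (ub, h, us, d); rank 3 (BSD): (ub, us, h*d) with num_head folded in.
UpdateFields NormalizeUpdateShape(const std::vector<int64_t> &shape) {
  if (shape.size() == 4) {
    return {shape[0], shape[1], shape[2], shape[3]};
  }
  return {shape[0], 1, shape[1], shape[2]};
}

int main() {
  UpdateFields f = NormalizeUpdateShape({40, 512, 128});
  assert(f.ub == 40 && f.h == 1 && f.us == 512 && f.d == 128);
  return 0;
}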

View File

@@ -0,0 +1,34 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PROMPT_KV_CACHE_TILING_H
#define PROMPT_KV_CACHE_TILING_H
#include "register/tilingdata_base.h"
namespace optiling {
BEGIN_TILING_DATA_DEF(TilingData)
TILING_DATA_FIELD_DEF(int64_t, core_num);
TILING_DATA_FIELD_DEF(int64_t, b);
TILING_DATA_FIELD_DEF(int64_t, h);
TILING_DATA_FIELD_DEF(int64_t, s);
TILING_DATA_FIELD_DEF(int64_t, d);
TILING_DATA_FIELD_DEF(int64_t, ub);
TILING_DATA_FIELD_DEF(int64_t, us);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(PromptKvCache, TilingData)
} // namespace optiling
#endif // PROMPT_KV_CACHE_TILING_H

View File

@@ -0,0 +1,196 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel_operator.h"
using namespace AscendC;
namespace {
constexpr int32_t kBufferNum = 2;
constexpr int64_t kUbSize = 192 * 1024;
const int64_t kDivisor = 4;
static __aicore__ inline int64_t CeilRound(int64_t value, int64_t divisor) {
  if (divisor == 0) {
    return 0;
  }
  return (value + divisor - 1) / divisor * divisor;
}
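// Example (annotation): CeilRound(1, 4) = 4, CeilRound(9, 4) = 12; int64
// copies below are padded to multiples of 4 elements (32 bytes) to match
// DataCopy's block granularity.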
} // namespace
template <typename T>
class KernelDecoderKvCache {
public:
__aicore__ inline KernelDecoderKvCache() {}
__aicore__ inline void GetNewMaxSeqLen(GM_ADDR new_max_seq_len) {
new_max_seq_len_gm_.SetGlobalBuffer((__gm__ int64_t *)new_max_seq_len, 4);
pipe_.InitBuffer(new_max_seq_len_queue_, 1, CeilRound(1, kDivisor) * sizeof(int64_t));
LocalTensor<int64_t> new_max_seq_len_tensor = new_max_seq_len_queue_.AllocTensor<int64_t>();
pipe_barrier((pipe_t)PIPE_ALL);
DataCopy(new_max_seq_len_tensor, new_max_seq_len_gm_, CeilRound(1, kDivisor));
pipe_barrier((pipe_t)PIPE_ALL);
s_ = new_max_seq_len_tensor.GetValue(0);
new_max_seq_len_queue_.FreeTensor(new_max_seq_len_tensor);
}
__aicore__ inline void GetValidSeqLen(GM_ADDR valid_seq_len, int64_t ub) {
int64_t valid_seq_len_ub_size = CeilRound(ub, kDivisor);
valid_seq_len_gm_.SetGlobalBuffer((__gm__ int64_t *)valid_seq_len, valid_seq_len_ub_size);
pipe_.InitBuffer(valid_seq_len_queue_, 1, valid_seq_len_ub_size * sizeof(int64_t));
valid_seq_len_tensor_ = valid_seq_len_queue_.AllocTensor<int64_t>();
pipe_barrier((pipe_t)PIPE_ALL);
DataCopy(valid_seq_len_tensor_, valid_seq_len_gm_, valid_seq_len_ub_size);
pipe_barrier((pipe_t)PIPE_ALL);
remain_ub_size_ -= valid_seq_len_ub_size * sizeof(int64_t);
}
__aicore__ inline void SplitBh(int64_t bh) {
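    // Grow split_bh_ until a double-buffered chunk (former_block_bh_ rows of
    // us_ * d_ elements) fits into the unified buffer left after the index
    // tensors.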
split_bh_ = 1;
former_block_bh_ = bh;
while (kBufferNum * former_block_bh_ * us_ * d_ * sizeof(T) >= remain_ub_size_) {
split_bh_++;
former_block_bh_ = (bh + split_bh_ - 1) / split_bh_;
}
tail_block_bh_ = bh - (split_bh_ - 1) * former_block_bh_;
}
__aicore__ inline void Update(GM_ADDR cache, GM_ADDR update, LocalTensor<T> update_in_local_tensor) {
for (int64_t i = 0; i < split_bh_; i++) {
int64_t block_bh;
if (i != split_bh_ - 1) {
block_bh = former_block_bh_;
} else {
block_bh = tail_block_bh_;
}
update_gm_.SetGlobalBuffer((__gm__ T *)update + core_idx_ * update_core_stride_ + i * former_block_bh_ * us_ * d_,
block_bh * us_ * d_);
DataCopy(update_in_local_tensor, update_gm_, block_bh * us_ * d_);
pipe_barrier((pipe_t)PIPE_ALL);
update_queue_.EnQue(update_in_local_tensor);
LocalTensor<T> update_in_local_tensor_out = update_queue_.DeQue<T>();
for (int64_t j = 0; j < block_bh; j++) {
int64_t bh_idx = core_idx_ * former_bh_ + i * former_block_bh_ + j;
auto b_idx = bh_idx / h_;
pipe_barrier((pipe_t)PIPE_ALL);
auto s_idx = valid_seq_len_tensor_.GetValue(b_idx);
pipe_barrier((pipe_t)PIPE_ALL);
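        // A valid_seq_len of -1 (or any out-of-range value) marks a batch
        // that must not be updated, so skip it.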
if (s_idx < 0 || s_idx >= s_) {
continue;
}
out_gm_.SetGlobalBuffer((__gm__ T *)cache + bh_idx * s_ * d_ + s_idx * d_, us_ * d_);
int64_t src_offset = j * us_ * d_;
pipe_barrier((pipe_t)PIPE_ALL);
DataCopy(out_gm_, update_in_local_tensor_out[src_offset], us_ * d_);
}
}
}
__aicore__ inline void Process(GM_ADDR cache, GM_ADDR update, GM_ADDR valid_seq_len, GM_ADDR batch_index,
GM_ADDR new_max_seq_len, GM_ADDR cur_max_seq_len, int64_t core_num, int64_t b,
int64_t h, int64_t d, int64_t us) {
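    // Split the b * h rows across cores: the first core_num_ - 1 active cores
    // take former_bh_ rows each, the last takes the tail_bh_ remainder.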
core_idx_ = GetBlockIdx();
former_bh_ = (b * h + core_num - 1) / core_num;
core_num_ = (b * h + former_bh_ - 1) / former_bh_;
if (core_idx_ >= core_num_) {
return;
}
tail_bh_ = b * h - (core_num_ - 1) * former_bh_;
b_ = b;
h_ = h;
d_ = d;
us_ = us;
GetNewMaxSeqLen(new_max_seq_len);
GetValidSeqLen(valid_seq_len, b);
update_core_stride_ = former_bh_ * us_ * d_;
cache_core_stride_ = former_bh_ * s_ * d_;
    // Compare against the recomputed active-core count core_num_ rather than
    // the launched core_num: otherwise, whenever core_num_ < core_num, the
    // last active core would process a full former_bh_ chunk instead of the
    // tail_bh_ remainder and read past the valid rows.
    if (core_idx_ != core_num_ - 1) {
SplitBh(former_bh_);
} else {
SplitBh(tail_bh_);
}
pipe_.InitBuffer(update_queue_, kBufferNum, former_block_bh_ * us_ * d_ * sizeof(T));
LocalTensor<T> update_in_local_tensor = update_queue_.AllocTensor<T>();
Update(cache, update, update_in_local_tensor);
valid_seq_len_queue_.FreeTensor(valid_seq_len_tensor_);
update_queue_.FreeTensor(update_in_local_tensor);
}
private:
// gm
GlobalTensor<T> update_gm_;
GlobalTensor<int64_t> valid_seq_len_gm_;
GlobalTensor<int64_t> new_max_seq_len_gm_;
GlobalTensor<T> out_gm_;
// local
LocalTensor<int64_t> valid_seq_len_tensor_;
TPipe pipe_;
// create queues for input, in this case depth is equal to buffer num
TQue<QuePosition::VECIN, 1> update_queue_;
TQue<QuePosition::VECIN, 1> valid_seq_len_queue_;
TQue<QuePosition::VECIN, 1> new_max_seq_len_queue_;
int64_t remain_ub_size_ = kUbSize;
int64_t split_bh_ = 0;
int64_t former_block_bh_ = 0;
int64_t tail_block_bh_ = 0;
int64_t core_idx_ = 0;
int64_t cache_core_stride_ = 0;
  int64_t cache_block_length_ = 0;
int64_t update_core_stride_ = 0;
int64_t update_block_length_ = 0;
int64_t core_num_ = 0;
int64_t former_bh_ = 0;
int64_t tail_bh_ = 0;
int64_t b_ = 0;
int64_t h_ = 0;
int64_t s_ = 0;
int64_t d_ = 0;
int64_t us_ = 0;
};
extern "C" __global__ __aicore__ void decoder_kv_cache(GM_ADDR cache, GM_ADDR update, GM_ADDR valid_seq_len,
GM_ADDR batch_index, GM_ADDR seq_len_axis,
GM_ADDR new_max_seq_len, GM_ADDR cur_max_seq_len, GM_ADDR out,
GM_ADDR workspace, GM_ADDR tiling) {
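  // The tiling key is the element size in bytes, so any type of matching
  // width shares a path: float16 data, for example, goes through the
  // int16_t instantiation (the copy never inspects the values).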
GET_TILING_DATA(tiling_data, tiling);
if (TILING_KEY_IS(1)) {
KernelDecoderKvCache<int8_t> op;
op.Process(cache, update, valid_seq_len, batch_index, new_max_seq_len, cur_max_seq_len, tiling_data.core_num,
tiling_data.b, tiling_data.h, tiling_data.d, tiling_data.us);
} else if (TILING_KEY_IS(2)) {
KernelDecoderKvCache<int16_t> op;
op.Process(cache, update, valid_seq_len, batch_index, new_max_seq_len, cur_max_seq_len, tiling_data.core_num,
tiling_data.b, tiling_data.h, tiling_data.d, tiling_data.us);
} else if (TILING_KEY_IS(4)) {
KernelDecoderKvCache<int32_t> op;
op.Process(cache, update, valid_seq_len, batch_index, new_max_seq_len, cur_max_seq_len, tiling_data.core_num,
tiling_data.b, tiling_data.h, tiling_data.d, tiling_data.us);
}
}


@ -0,0 +1,217 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel_operator.h"
#include "kernel_utils.h"
using namespace AscendC;
namespace {
constexpr int64_t kBufferNum = 2;
constexpr int64_t kUbSize = 192 * 1024;
constexpr int64_t kDivisor = 4;
// Round value up to a multiple of divisor; int64 copies to UB need 4-element
// (32-byte) aligned lengths.
static __aicore__ inline int64_t CeilRound(int64_t value, int64_t divisor) {
  if (divisor == 0) {
    return 0;
  }
  return (value + divisor - 1) / divisor * divisor;
}
} // namespace
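// Copies an update tensor of shape (ub, h, us, d) into the cache rows picked
// by batch_index, starting at each batch's valid_seq_len offset along the seq
// axis. The ub * h (batch, head) pairs are split across cores.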
template <typename T>
class KernelPromptKvCache {
public:
__aicore__ inline KernelPromptKvCache() {}
__aicore__ inline void GetIndex(GM_ADDR batch_index, GM_ADDR valid_seq_len, int64_t ub) {
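    // Stage batch_index and valid_seq_len through one UB allocation: the
    // first batch_index_ub_num elements hold batch_index, the remainder
    // holds valid_seq_len.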
int64_t batch_index_ub_num = CeilRound(ub, kDivisor);
int64_t valid_seq_len_ub_num = CeilRound(ub, kDivisor);
batch_index_gm_.SetGlobalBuffer((__gm__ int64_t *)batch_index, batch_index_ub_num);
int64_t total_num = batch_index_ub_num + valid_seq_len_ub_num;
pipe_.InitBuffer(index_queue_, 1, total_num * sizeof(int64_t));
batch_index_tensor_ = index_queue_.AllocTensor<int64_t>();
DataCopy(batch_index_tensor_, batch_index_gm_, batch_index_ub_num);
valid_seq_len_gm_.SetGlobalBuffer((__gm__ int64_t *)valid_seq_len, valid_seq_len_ub_num);
valid_seq_len_tensor_ = batch_index_tensor_[batch_index_ub_num];
DataCopy(valid_seq_len_tensor_, valid_seq_len_gm_, valid_seq_len_ub_num);
remain_ub_size_ -= total_num * sizeof(int64_t);
}
__aicore__ inline void GetNewMaxSeqLen(GM_ADDR new_max_seq_len) {
    new_max_seq_len_gm_.SetGlobalBuffer((__gm__ int64_t *)new_max_seq_len, CeilRound(1, kDivisor));
pipe_.InitBuffer(new_max_seq_len_queue_, 1, CeilRound(1, kDivisor) * sizeof(int64_t));
LocalTensor<int64_t> new_max_seq_len_tensor = new_max_seq_len_queue_.AllocTensor<int64_t>();
DataCopy(new_max_seq_len_tensor, new_max_seq_len_gm_, CeilRound(1, kDivisor));
pipe_barrier((pipe_t)PIPE_ALL);
s_ = new_max_seq_len_tensor.GetValue(0);
new_max_seq_len_queue_.FreeTensor(new_max_seq_len_tensor);
}
__aicore__ inline void UpdateCache(GM_ADDR cache, GM_ADDR update) {
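    // Split each (batch, head) update row of us_ * d_ elements into split_us
    // blocks so a double-buffered block fits into the remaining unified
    // buffer.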
int64_t split_us = 1;
int64_t block_us = us_ / split_us;
while (kBufferNum * block_us * d_ * sizeof(T) >= remain_ub_size_) {
split_us++;
block_us = (us_ + split_us - 1) / split_us;
}
int64_t former_block_us = block_us;
int64_t tail_block_us = us_ - (split_us - 1) * former_block_us;
pipe_.InitBuffer(update_queue_, kBufferNum, block_us * d_ * sizeof(T));
for (int64_t i = 0; i < each_core_bs_num_; ++i) {
int64_t bh_idx = core_idx_ * former_each_core_bs_num_ + i;
int64_t ub_idx = bh_idx / h_;
int64_t h_idx = bh_idx % h_;
pipe_barrier((pipe_t)PIPE_ALL);
int64_t cache_b_idx = batch_index_tensor_.GetValue(ub_idx);
int64_t s_idx = valid_seq_len_tensor_.GetValue(ub_idx);
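      // A batch_index of -1 marks a batch to skip; writes that would run past
      // the (possibly shrunken) max seq len are skipped as well.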
if (cache_b_idx < 0 || cache_b_idx >= b_) {
continue;
}
if (s_idx < 0 || s_idx + us_ > s_) {
continue;
}
for (int64_t j = 0; j < split_us; ++j) {
int64_t u_block_len;
if (j == split_us - 1) {
u_block_len = tail_block_us * d_;
} else {
u_block_len = former_block_us * d_;
}
LocalTensor<T> update_in_local_tensor = update_queue_.AllocTensor<T>();
update_gm_.SetGlobalBuffer(
(__gm__ T *)update + ub_idx * update_b_stride_ + h_idx * update_h_stride_ + j * former_block_us * d_,
u_block_len);
out_gm_.SetGlobalBuffer((__gm__ T *)cache + cache_b_idx * cache_b_stride_ + h_idx * cache_h_stride_ +
s_idx * d_ + j * former_block_us * d_,
u_block_len);
pipe_barrier((pipe_t)PIPE_ALL);
DataCopy(update_in_local_tensor, update_gm_, u_block_len);
update_queue_.EnQue(update_in_local_tensor);
LocalTensor<T> update_in_local_tensor_out = update_queue_.DeQue<T>();
pipe_barrier((pipe_t)PIPE_ALL);
DataCopy(out_gm_, update_in_local_tensor_out, u_block_len);
update_queue_.FreeTensor(update_in_local_tensor_out);
}
}
}
__aicore__ inline void Process(GM_ADDR cache, GM_ADDR update, GM_ADDR valid_seq_len, GM_ADDR batch_index,
GM_ADDR seq_len_axis, GM_ADDR new_max_seq_len, GM_ADDR cur_max_seq_len,
int64_t max_core_num, int64_t b, int64_t h, int64_t s, int64_t d, int64_t ub,
int64_t us) {
core_idx_ = GetBlockIdx();
int64_t bs = ub * h;
former_each_core_bs_num_ = (bs + max_core_num - 1) / max_core_num;
core_num_ = (bs + former_each_core_bs_num_ - 1) / former_each_core_bs_num_;
tail_each_core_bs_num_ = bs - (core_num_ - 1) * former_each_core_bs_num_;
if (core_idx_ >= core_num_) {
return;
}
if (g_coreType == AIC) {
return;
}
GetIndex(batch_index, valid_seq_len, ub);
GetNewMaxSeqLen(new_max_seq_len);
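    // The cache's total capacity b * s is fixed, so under the new max seq len
    // s_ the effective batch count appears to be b * s / s_ (inferred from
    // the stride computation below).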
b_ = b * s / s_;
h_ = h;
d_ = d;
ub_ = ub;
us_ = us;
if (core_idx_ != core_num_ - 1) {
each_core_bs_num_ = former_each_core_bs_num_;
} else {
each_core_bs_num_ = tail_each_core_bs_num_;
}
cache_h_stride_ = s_ * d_;
cache_b_stride_ = h_ * cache_h_stride_;
update_h_stride_ = us_ * d_;
update_b_stride_ = h_ * update_h_stride_;
UpdateCache(cache, update);
index_queue_.FreeTensor(batch_index_tensor_);
}
private:
// gm
GlobalTensor<T> update_gm_;
GlobalTensor<int64_t> valid_seq_len_gm_;
GlobalTensor<int64_t> batch_index_gm_;
GlobalTensor<int64_t> new_max_seq_len_gm_;
GlobalTensor<T> out_gm_;
// local gm
LocalTensor<int64_t> valid_seq_len_tensor_;
LocalTensor<int64_t> batch_index_tensor_;
TPipe pipe_;
TQue<QuePosition::VECIN, 1> update_queue_;
TQue<QuePosition::VECIN, 1> index_queue_;
TQue<QuePosition::VECIN, 1> new_max_seq_len_queue_;
int64_t remain_ub_size_ = kUbSize;
int64_t core_idx_ = 0;
int64_t core_num_ = 0;
int64_t each_core_bs_num_ = 0;
int64_t former_each_core_bs_num_ = 0;
int64_t tail_each_core_bs_num_ = 0;
int64_t b_ = 0;
int64_t h_ = 0;
int64_t s_ = 0;
int64_t d_ = 0;
int64_t ub_ = 0;
int64_t us_ = 0;
int64_t ps_ = 0;
int64_t cache_b_stride_ = 0;
int64_t cache_h_stride_ = 0;
int64_t update_b_stride_ = 0;
int64_t update_h_stride_ = 0;
};
extern "C" __global__ __aicore__ void prompt_kv_cache(GM_ADDR cache, GM_ADDR update, GM_ADDR valid_seq_len,
GM_ADDR batch_index, GM_ADDR seq_len_axis,
GM_ADDR new_max_seq_len, GM_ADDR cur_max_seq_len, GM_ADDR out,
GM_ADDR workspace, GM_ADDR tiling) {
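  // As in decoder_kv_cache, the tiling key selects the instantiation by
  // element width (1, 2, or 4 bytes).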
GET_TILING_DATA(tiling_data, tiling);
if (TILING_KEY_IS(1)) {
KernelPromptKvCache<int8_t> op;
op.Process(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len,
tiling_data.core_num, tiling_data.b, tiling_data.h, tiling_data.s, tiling_data.d, tiling_data.ub,
tiling_data.us);
} else if (TILING_KEY_IS(2)) {
KernelPromptKvCache<int16_t> op;
op.Process(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len,
tiling_data.core_num, tiling_data.b, tiling_data.h, tiling_data.s, tiling_data.d, tiling_data.ub,
tiling_data.us);
} else if (TILING_KEY_IS(4)) {
KernelPromptKvCache<int32_t> op;
op.Process(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len,
tiling_data.core_num, tiling_data.b, tiling_data.h, tiling_data.s, tiling_data.d, tiling_data.ub,
tiling_data.us);
}
}


@ -33,9 +33,9 @@ class DecoderKVCacheNet(Cell):
new_max_seq_len, cur_max_seq_len)
def test_decoder_k_v_cache_net_4dims():
"""
    Feature: test decoder_k_v_cache auto parallel 4dims
Description: auto parallel
Expectation: shape is as expected.
"""
@ -56,3 +56,28 @@ def test_decoder_k_v_cache_net():
assert validator.check_parameter_shape('cache', [1, 40, 1024, 128])
assert validator.check_parameter_shape('update', [1, 40, 1, 128])
assert validator.check_parameter_shape('valid_seq_len', [1])
def test_decoder_k_v_cache_net_3dims():
"""
Feature: test decoder_k_v_cache auto parallel 3dims
Description: auto parallel
Expectation: shape is as expected.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy = ((4, 1, 1,), (4, 1, 1,), (4,), (1,), (1,), (1,), (1,))
net = DecoderKVCacheNet(strategy)
cache = Parameter(Tensor(np.ones([4, 1024, 5120]), dtype=ms.float16), "cache")
update = Parameter(Tensor(np.ones([4, 1, 5120]), dtype=ms.float16), "update")
valid_seq_len = Parameter(Tensor(np.ones([4]), dtype=ms.int64), "valid_seq_len")
batch_index = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "batch_index")
seq_len_axis = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "seq_len_axis")
new_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "new_max_seq_len")
cur_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "cur_max_seq_len")
net.set_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
phase = compile_net(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
validator = ParallelValidator(net, phase)
assert validator.check_parameter_shape('cache', [1, 1024, 5120])
assert validator.check_parameter_shape('update', [1, 1, 5120])
assert validator.check_parameter_shape('valid_seq_len', [1])


@ -33,9 +33,9 @@ class PromptKVCacheNet(Cell):
new_max_seq_len, cur_max_seq_len)
def test_prompt_k_v_cache_net_dim4():
"""
    Feature: test prompt_k_v_cache auto parallel 4 dims
Description: auto parallel
Expectation: shape is as expected.
"""
@ -57,3 +57,29 @@ def test_prompt_k_v_cache_net():
assert validator.check_parameter_shape('cache', [1, 40, 1024, 128])
assert validator.check_parameter_shape('update', [1, 40, 1, 128])
assert validator.check_parameter_shape('valid_seq_len', [1])
def test_prompt_k_v_cache_net_dim3():
"""
    Feature: test prompt_k_v_cache auto parallel 3 dims
Description: auto parallel
Expectation: shape is as expected.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy = ((4, 1, 1,), (4, 1, 1,), (4,), (1,), (1,), (1,), (1,))
padding_mode = "right"
net = PromptKVCacheNet(padding_mode, strategy)
cache = Parameter(Tensor(np.ones([4, 1024, 5120]), dtype=ms.float16), "cache")
update = Parameter(Tensor(np.ones([4, 1, 5120]), dtype=ms.float16), "update")
valid_seq_len = Parameter(Tensor(np.ones([4]), dtype=ms.int64), "valid_seq_len")
batch_index = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "batch_index")
seq_len_axis = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "seq_len_axis")
new_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "new_max_seq_len")
cur_max_seq_len = Parameter(Tensor(np.ones([1]), dtype=ms.int64), "cur_max_seq_len")
net.set_inputs(cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
phase = compile_net(net, cache, update, valid_seq_len, batch_index, seq_len_axis, new_max_seq_len, cur_max_seq_len)
validator = ParallelValidator(net, phase)
assert validator.check_parameter_shape('cache', [1, 1024, 5120])
assert validator.check_parameter_shape('update', [1, 1, 5120])
assert validator.check_parameter_shape('valid_seq_len', [1])