forked from mindspore-Ecosystem/mindspore
code review part8
This commit is contained in:
parent
66e6d3c93e
commit
362f26ab9b
|
@ -38,12 +38,18 @@ constexpr size_t kBottom = 1;
|
|||
constexpr size_t kLeft = 2;
|
||||
constexpr size_t kRight = 3;
|
||||
constexpr size_t kPadDims = 4;
|
||||
constexpr int kPadElementNum = 8;
|
||||
|
||||
void ReplaceParamsAndNodes(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &pad_cnode,
|
||||
const std::string &pattern_name) {
|
||||
auto paddings = pad_cnode->input(kInputIndexTwo)->cast<ParameterPtr>();
|
||||
MS_ASSERT(paddings != nullptr);
|
||||
MS_ASSERT(paddings->default_param() != nullptr);
|
||||
auto pad_list = std::dynamic_pointer_cast<tensor::Tensor>(paddings->default_param());
|
||||
MS_ASSERT(pad_list != nullptr);
|
||||
MS_ASSERT(pad_list->ElementsNum() == kPadElementNum);
|
||||
auto pad_data = static_cast<int32_t *>(pad_list->data_c());
|
||||
MS_ASSERT(pad_data != nullptr);
|
||||
|
||||
std::vector<int64_t> pad_list_data;
|
||||
if (pattern_name == "PadConvPatternName") {
|
||||
|
@ -109,6 +115,21 @@ bool IsPrimitiveProper(const CNodePtr &conv_cnode, const CNodePtr &pad_cnode) {
|
|||
if (!utils::isa<Parameter>(pad_cnode->input(kInputIndexTwo))) {
|
||||
return false;
|
||||
}
|
||||
auto pad_list = pad_cnode->input(kInputIndexTwo)->cast<ParameterPtr>();
|
||||
auto tensor_param = pad_list->default_param();
|
||||
if (tensor_param == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto tensor = tensor_param->cast<tensor::TensorPtr>();
|
||||
if (tensor == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (tensor->data_type() != kNumberTypeInt32 && tensor->data_type() != kNumberTypeInt) {
|
||||
return false;
|
||||
}
|
||||
if (tensor->data_c() == nullptr || tensor->ElementsNum() != kPadElementNum) {
|
||||
return false;
|
||||
}
|
||||
auto pad_primitive = GetValueNode<std::shared_ptr<ops::PadFusion>>(pad_cnode->input(0));
|
||||
MS_ASSERT(pad_primitive != nullptr);
|
||||
if (!pad_primitive->HasAttr(ops::kPaddingMode)) {
|
||||
|
|
|
@ -217,47 +217,10 @@ CNodePtr TfliteRelPosMultiHeadAttentionFusion::CreateRelPosMultiHeadAttentionNod
|
|||
MS_ASSERT(func_graph != nullptr && equiv != nullptr);
|
||||
auto attention_prim = BuildAttentionPrim(equiv);
|
||||
MS_CHECK_TRUE_RET(attention_prim != nullptr, nullptr);
|
||||
auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(kInputSize, kOutputSize);
|
||||
MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
|
||||
auto query_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[query_prim_]));
|
||||
MS_CHECK_TRUE_RET(query_prim != nullptr, nullptr);
|
||||
auto query_quant_param_holder = query_prim->GetAttr("quant_params");
|
||||
if (query_quant_param_holder != nullptr) {
|
||||
quant_params_holder->set_input_quant_param(
|
||||
kWeightQueryIndex, query_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
if (SetQuantParamForAttentionNode(attention_prim, equiv) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "set quant param for attehtion node failed.";
|
||||
return nullptr;
|
||||
}
|
||||
auto key_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[key_prim_]));
|
||||
MS_CHECK_TRUE_RET(key_prim != nullptr, nullptr);
|
||||
auto key_quant_param_holder = key_prim->GetAttr("quant_params");
|
||||
if (key_quant_param_holder != nullptr) {
|
||||
quant_params_holder->set_input_quant_param(
|
||||
kWeightKeyIndex, key_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
}
|
||||
auto value_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[value_prim_]));
|
||||
MS_CHECK_TRUE_RET(value_prim != nullptr, nullptr);
|
||||
auto value_quant_param_holder = value_prim->GetAttr("quant_params");
|
||||
if (value_quant_param_holder != nullptr) {
|
||||
quant_params_holder->set_input_quant_param(
|
||||
kWeightValueIndex, value_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
}
|
||||
|
||||
auto pos_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[pos_prim_]));
|
||||
MS_CHECK_TRUE_RET(pos_prim != nullptr, nullptr);
|
||||
auto pos_quant_param_holder = pos_prim->GetAttr("quant_params");
|
||||
if (pos_quant_param_holder != nullptr) {
|
||||
quant_params_holder->set_input_quant_param(
|
||||
kWeightPosIndex, pos_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
}
|
||||
|
||||
auto output_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[output_prim_]));
|
||||
MS_CHECK_TRUE_RET(output_prim != nullptr, nullptr);
|
||||
auto output_quant_param_holder = output_prim->GetAttr("quant_params");
|
||||
if (output_quant_param_holder != nullptr) {
|
||||
quant_params_holder->set_input_quant_param(
|
||||
kWeightOutputIndex, output_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
}
|
||||
|
||||
attention_prim->AddAttr("quant_params", quant_params_holder);
|
||||
auto value_node = NewValueNode(attention_prim);
|
||||
MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
|
||||
auto input_q = utils::cast<AnfNodePtr>((*equiv)[input_q_]);
|
||||
|
@ -321,6 +284,69 @@ CNodePtr TfliteRelPosMultiHeadAttentionFusion::CreateRelPosMultiHeadAttentionNod
|
|||
return new_node;
|
||||
}
|
||||
|
||||
int TfliteRelPosMultiHeadAttentionFusion::SetQuantParamForAttentionNode(const PrimitivePtr &prim,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_ASSERT(prim != nullptr && equiv != nullptr);
|
||||
auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(kInputSize, kOutputSize);
|
||||
MS_CHECK_TRUE_RET(quant_params_holder != nullptr, lite::RET_ERROR);
|
||||
auto query_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[query_prim_]));
|
||||
MS_CHECK_TRUE_RET(query_prim != nullptr, lite::RET_ERROR);
|
||||
auto query_quant_param_holder = query_prim->GetAttr("quant_params");
|
||||
if (query_quant_param_holder != nullptr) {
|
||||
auto query_quant_param = query_quant_param_holder->cast<lite::QuantParamHolderPtr>();
|
||||
MS_CHECK_TRUE_RET(query_quant_param != nullptr, lite::RET_ERROR);
|
||||
if (query_quant_param->get_input_quant_params().size() > 1) {
|
||||
quant_params_holder->set_input_quant_param(kWeightQueryIndex, query_quant_param->get_input_quant_params().at(1));
|
||||
}
|
||||
}
|
||||
auto key_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[key_prim_]));
|
||||
MS_CHECK_TRUE_RET(key_prim != nullptr, lite::RET_ERROR);
|
||||
auto key_quant_param_holder = key_prim->GetAttr("quant_params");
|
||||
if (key_quant_param_holder != nullptr) {
|
||||
auto key_quant_param = key_quant_param_holder->cast<lite::QuantParamHolderPtr>();
|
||||
MS_CHECK_TRUE_RET(key_quant_param != nullptr, lite::RET_ERROR);
|
||||
if (key_quant_param->get_input_quant_params().size() > 1) {
|
||||
quant_params_holder->set_input_quant_param(kWeightKeyIndex, key_quant_param->get_input_quant_params().at(1));
|
||||
}
|
||||
}
|
||||
auto value_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[value_prim_]));
|
||||
MS_CHECK_TRUE_RET(value_prim != nullptr, lite::RET_ERROR);
|
||||
auto value_quant_param_holder = value_prim->GetAttr("quant_params");
|
||||
if (value_quant_param_holder != nullptr) {
|
||||
auto value_quant_param = value_quant_param_holder->cast<lite::QuantParamHolderPtr>();
|
||||
MS_CHECK_TRUE_RET(value_quant_param != nullptr, lite::RET_ERROR);
|
||||
if (value_quant_param->get_input_quant_params().size() > 1) {
|
||||
quant_params_holder->set_input_quant_param(kWeightValueIndex, value_quant_param->get_input_quant_params().at(1));
|
||||
}
|
||||
}
|
||||
|
||||
auto pos_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[pos_prim_]));
|
||||
MS_CHECK_TRUE_RET(pos_prim != nullptr, lite::RET_ERROR);
|
||||
auto pos_quant_param_holder = pos_prim->GetAttr("quant_params");
|
||||
if (pos_quant_param_holder != nullptr) {
|
||||
auto pos_quant_param = pos_quant_param_holder->cast<lite::QuantParamHolderPtr>();
|
||||
MS_CHECK_TRUE_RET(pos_quant_param != nullptr, lite::RET_ERROR);
|
||||
if (pos_quant_param->get_input_quant_params().size() > 1) {
|
||||
quant_params_holder->set_input_quant_param(kWeightPosIndex, pos_quant_param->get_input_quant_params().at(1));
|
||||
}
|
||||
}
|
||||
|
||||
auto output_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[output_prim_]));
|
||||
MS_CHECK_TRUE_RET(output_prim != nullptr, lite::RET_ERROR);
|
||||
auto output_quant_param_holder = output_prim->GetAttr("quant_params");
|
||||
if (output_quant_param_holder != nullptr) {
|
||||
auto output_quant_param = output_quant_param_holder->cast<lite::QuantParamHolderPtr>();
|
||||
MS_CHECK_TRUE_RET(output_quant_param != nullptr, lite::RET_ERROR);
|
||||
if (output_quant_param->get_input_quant_params().size() > 1) {
|
||||
quant_params_holder->set_input_quant_param(kWeightOutputIndex,
|
||||
output_quant_param->get_input_quant_params().at(1));
|
||||
}
|
||||
}
|
||||
|
||||
prim->AddAttr("quant_params", quant_params_holder);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
const VectorRef TfliteRelPosMultiHeadAttentionFusion::DefineRelativeShiftPattern(const BaseRef &input) const {
|
||||
auto is_pad = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimPadFusion));
|
||||
MS_CHECK_TRUE_RET(is_pad != nullptr, {});
|
||||
|
|
|
@ -50,6 +50,9 @@ class TfliteRelPosMultiHeadAttentionFusion : public MultiHeadAttentionFusion {
|
|||
|
||||
CNodePtr CreateRelPosMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||
const std::string &base_name) const;
|
||||
|
||||
int SetQuantParamForAttentionNode(const PrimitivePtr &prim, const EquivPtr &equiv) const;
|
||||
|
||||
const VectorRef DefineRelativeShiftPattern(const BaseRef &input) const;
|
||||
|
||||
private:
|
||||
|
|
Loading…
Reference in New Issue