!23842 add fullyconnected fusion

Merge pull request !23842 from wangyanling/r1.3
i-robot 2021-09-30 09:03:04 +00:00 committed by Gitee
commit 21509c6d8d
8 changed files with 118 additions and 11 deletions


@@ -25,8 +25,9 @@ int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_
     return NNACL_PARAM_INVALID;
   }
   for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) {
-    int cmp_value = MSMIN(a_shape[i], b_shape[i]);
-    if (a_shape[i] != b_shape[i] && cmp_value != 1) {
+    int min_value = MSMIN(a_shape[i], b_shape[i]);
+    int max_value = MSMAX(a_shape[i], b_shape[i]);
+    if (max_value % min_value != 0) {
       return NNACL_INPUT_TENSOR_ERROR;
     }
   }
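
Note: the old check only tolerated mismatched batch dimensions when the smaller one was exactly 1; the relaxed check accepts any pair where the larger extent is an integer multiple of the smaller. A minimal standalone sketch of the new rule (the helper name and signature are invented for illustration, not the NNACL API):

#include <algorithm>
#include <cstddef>

// Hypothetical helper mirroring the relaxed batch-dim compatibility rule.
bool BatchDimsCompatible(const int *a_shape, size_t a_rank, const int *b_shape, size_t b_rank) {
  for (size_t i = 0; i + 2 < a_rank && i + 2 < b_rank; ++i) {
    int min_value = std::min(a_shape[i], b_shape[i]);
    int max_value = std::max(a_shape[i], b_shape[i]);
    if (min_value == 0 || max_value % min_value != 0) {
      return false;  // e.g. batch dims 3 vs 5: rejected before and after the change
    }
  }
  return true;  // e.g. batch dims 2 vs 8: rejected by the old rule, accepted now
}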


@@ -116,15 +116,16 @@ int MatmulFP16CPUKernel::InitBroadcastParams() {
   int out_batch = 1;
   for (size_t i = 0; i < a_shape.size() - kHWDimNumber; ++i) {
     out_batch *= MSMAX(a_shape[i], b_shape[i]);
-    if (a_shape[i] < b_shape[i] && a_shape[i] == 1) {
+    if (a_shape[i] < b_shape[i] && b_shape[i] % a_shape[i] == 0) {
       a_broadcast_ = true;
-    } else if (a_shape[i] > b_shape[i] && b_shape[i] == 1) {
+    } else if (a_shape[i] > b_shape[i] && a_shape[i] % b_shape[i] == 0) {
       b_broadcast_ = true;
     } else if (a_shape[i] != b_shape[i]) {
       MS_LOG(ERROR) << "matmul doesn't support broadcasting between shapes " << a_shape << " and " << b_shape;
       return RET_ERROR;
     }
   }
   params_->batch = out_batch;
   return RET_OK;
 }
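
The fp16 kernel above and the fp32 kernel below apply the same relaxation when choosing which side to broadcast and how many batched matmuls to run: out_batch takes the larger extent per batch dimension, and a side is flagged for broadcast when its extent divides the other's. A hedged sketch of that bookkeeping, with an invented free function standing in for the member function:

#include <algorithm>
#include <vector>

// Sketch of the broadcast bookkeeping in InitBroadcastParams; assumes
// a_shape and b_shape have equal rank, as the kernel does.
bool ComputeBatch(const std::vector<int> &a_shape, const std::vector<int> &b_shape,
                  int *out_batch, bool *a_broadcast, bool *b_broadcast) {
  *out_batch = 1;
  for (size_t i = 0; i + 2 < a_shape.size(); ++i) {  // skip the two matrix dims
    *out_batch *= std::max(a_shape[i], b_shape[i]);
    if (a_shape[i] < b_shape[i] && b_shape[i] % a_shape[i] == 0) {
      *a_broadcast = true;   // A's batches are repeated, e.g. 2 against 8
    } else if (a_shape[i] > b_shape[i] && a_shape[i] % b_shape[i] == 0) {
      *b_broadcast = true;
    } else if (a_shape[i] != b_shape[i]) {
      return false;          // neither extent divides the other, e.g. 3 against 5
    }
  }
  return true;  // e.g. a = {2,1,M,K}, b = {2,4,K,N} yields *out_batch == 8
}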


@@ -104,9 +104,9 @@ int MatmulCPUKernel::InitBroadcastParams() {
   int out_batch = 1;
   for (size_t i = 0; i < a_shape.size() - kHWDimNumber; ++i) {
     out_batch *= MSMAX(a_shape[i], b_shape[i]);
-    if (a_shape[i] < b_shape[i] && a_shape[i] == 1) {
+    if (a_shape[i] < b_shape[i] && b_shape[i] % a_shape[i] == 0) {
       a_broadcast_ = true;
-    } else if (a_shape[i] > b_shape[i] && b_shape[i] == 1) {
+    } else if (a_shape[i] > b_shape[i] && a_shape[i] % b_shape[i] == 0) {
       b_broadcast_ = true;
     } else if (a_shape[i] != b_shape[i]) {
       MS_LOG(ERROR) << "matmul doesn't support broadcasting between shapes " << a_shape << " and " << b_shape;


@@ -106,3 +106,4 @@ hiai_nlu_model_single.pb;3;1,32:1,32:1,32
 fsr_270_mindspore.pb
 fsr_360_mindspore.pb
 fsr_720_mindspore.pb
+tt_raw_h4800_mel80_ms_fe001_ex_20210506_joint_decoder.pb;14;4:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:1,640


@@ -91,3 +91,4 @@ hiai_nlu_model_single.pb;3;1,32:1,32:1,32 3.5
 fsr_270_mindspore.pb 6.0
 fsr_360_mindspore.pb 6.5
 fsr_720_mindspore.pb 2.0
+tt_raw_h4800_mel80_ms_fe001_ex_20210506_joint_decoder.pb;14;4:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:1,640 0.5


@@ -203,3 +203,7 @@ add_uint8.tflite;2
 coco_ssd_mobilenet_v1_1.0.tflite
 hiai_asr_last_e1_cpu_fast_wavenet_batch1_frame1_one_cache_fp32.tflite;2
 hiai_asr_ctc.tflite;2
+# weight quant model
+tt_raw_h4800_mel80_ms_fe001_ex_20210506_encoder.tflite;25;1,15,80:1,15,80:1,15,80:1,15,80:1,15,80:1,15,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,640 8.5
+# weight quant model
+tt_raw_h4800_mel80_ms_fe001_ex_20210506_joint_decoder.tflite;14;4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:1,640:4 10.5
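
For context, each entry in these model lists appears to follow the pattern model_file;input_count;shape_1:shape_2:...:shape_n, where each shape is a comma-separated dimension list and an optional trailing number reads as the accepted accuracy tolerance. A purely hypothetical entry, not taken from the configs above:

some_model.tflite;2;1,32:1,32 0.5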


@@ -1,3 +1,3 @@
 mobilenet.tflite 0.5
-Change_input_transformer_20200831_encoder_fp32.tflite 70
+Change_input_transformer_20200831_encoder_fp32.tflite 85
 Change_input_transformer_20200831_decoder_fp32.tflite 35


@@ -23,9 +23,11 @@
 #include "tools/converter/quant_param_holder.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "securec/include/securec.h"
+#include "src/common/log_adapter.h"
 namespace mindspore::opt {
 namespace {
+constexpr int64_t kFcRightInputDims = 3;
 bool IsStackNode(const BaseRef &n) {
   if (utils::isa<AnfNodePtr>(n)) {
     return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimStack);
@@ -63,10 +65,14 @@ void *GetInputAddr(const AnfNodePtr &node, size_t input_index) {
 }
 STATUS GetRightMatmulInputParamter(const CNodePtr &stack_node, const ParameterPtr &rmatmul_input) {
   MS_ASSERT(stack_node != nullptr);
-  MS_ASSERT(right_matmul_input != nullptr);
+  MS_ASSERT(rmatmul_input != nullptr);
   auto joint_fullconnect_size = stack_node->inputs().size() - 1;
   auto fc = stack_node->input(1)->cast<CNodePtr>();
   auto fc_weight = fc->input(2)->cast<ParameterPtr>();
+  if (fc_weight == nullptr) {
+    MS_LOG(ERROR) << "fully-connected weight is null";
+    return RET_ERROR;
+  }
   auto fc_weight_param = std::dynamic_pointer_cast<tensor::Tensor>(fc_weight->default_param());
   auto tensor_size = fc_weight_param->Size();
   auto rmatmul_input_shape = fc_weight_param->shape();
@@ -157,12 +163,74 @@ const BaseRef BatchMatMulFusion::DefinePattern() const {
   return VectorRef({pack_var, left_fullconnect_var, right_fullconnect_var, other_fullconnect_var});
 }
+bool ConnectTransposeConcat(const AnfNodePtr &node) {
+  auto cnode = node->cast<CNodePtr>();
+  if (cnode == nullptr) {
+    MS_LOG(ERROR) << "cnode is null";
+    return false;
+  }
+  auto right_transpose_node = cnode->input(1);
+  auto right_transpose_cnode = right_transpose_node->cast<CNodePtr>();
+  if (right_transpose_cnode == nullptr) {
+    MS_LOG(ERROR) << "transpose cnode is null";
+    return false;
+  }
+  auto front_node = right_transpose_cnode->input(1);
+  auto front_cnode = front_node->cast<CNodePtr>();
+  if (front_cnode == nullptr) {
+    MS_LOG(ERROR) << "front cnode is null";
+    return false;
+  }
+  if (CheckPrimitiveType(right_transpose_cnode, prim::kPrimTranspose) &&
+      (CheckPrimitiveType(front_cnode, prim::kPrimTranspose) || CheckPrimitiveType(front_cnode, prim::kPrimConcat))) {
+    return true;
+  }
+  return false;
+}
+int ResetReshapeParameters(const AnfNodePtr &reshape_node) {
+  auto reshape_cnode = reshape_node->cast<CNodePtr>();
+  MS_ASSERT(reshape_cnode != nullptr);
+  auto reshape_shape_param = reshape_cnode->input(kInputIndexTwo)->cast<ParameterPtr>();
+  MS_ASSERT(reshape_shape_param != nullptr);
+  auto shape_tensor = std::dynamic_pointer_cast<tensor::Tensor>(reshape_shape_param->default_param());
+  auto rmatmul_input_shape = shape_tensor->shape();
+  std::vector<int64_t> shape(1, 0);
+  if (rmatmul_input_shape.empty()) {
+    MS_LOG(ERROR) << "reshape's shape tensor is empty";
+    return RET_ERROR;
+  } else if (shape[0] < kFcRightInputDims) {
+    shape[0] = rmatmul_input_shape[0] + 1;
+  }
+  auto tensor_info = std::make_shared<tensor::Tensor>(shape_tensor->data_type(), shape);
+  if (tensor_info == nullptr) {
+    MS_LOG(ERROR) << "Create tensor info failed";
+    return RET_ERROR;
+  }
+  int *tensor_data = reinterpret_cast<int *>(tensor_info->data_c());
+  tensor_data[0] = 1;
+  int *reshape_data = reinterpret_cast<int *>(shape_tensor->data_c());
+  for (int64_t i = 1; i < shape[0]; ++i) {
+    tensor_data[i] = reshape_data[i - 1];
+  }
+  lite::InitParameterFromTensorInfo(reshape_shape_param, tensor_info);
+  return RET_OK;
+}
 // slice + fullconnect -> batchmatmul
 const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                             const EquivPtr &) const {
   MS_ASSERT(func_graph != nullptr);
   MS_ASSERT(node != nullptr);
   auto stack_cnode = node->cast<CNodePtr>();
+  if (stack_cnode == nullptr) {
+    MS_LOG(WARNING) << "stack cnode is null";
+    return nullptr;
+  }
   // check that all of the stack node's inputs are fully-connected nodes
   for (size_t i = 1; i < stack_cnode->inputs().size(); i++) {
     auto input_node = stack_cnode->input(i);
@@ -172,10 +240,13 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
     }
   }
   auto fullconnect_node = stack_cnode->input(1);
   MS_ASSERT(fullconnect_node != nullptr);
   auto fullconnect_cnode = fullconnect_node->cast<CNodePtr>();
   MS_ASSERT(fullconnect_cnode->inputs().size() == 3);
   auto left_slice_node = fullconnect_cnode->input(1);
+  if (left_slice_node == nullptr) {
+    MS_LOG(WARNING) << "slice node is null";
+    return nullptr;
+  }
   auto left_slice_cnode = left_slice_node->cast<CNodePtr>();
-  if (!CheckPrimitiveType(left_slice_cnode, prim::kPrimSliceFusion)) {
+  if (!CheckPrimitiveType(left_slice_cnode, prim::kPrimReshape)) {
@@ -189,6 +260,10 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
   }
   auto left_matmul_input = left_slice_cnode->input(1);
   auto right_reshape_node = fullconnect_cnode->input(2);
+  if (left_matmul_input == nullptr || right_reshape_node == nullptr) {
+    MS_LOG(ERROR) << "matmul's input is null";
+    return nullptr;
+  }
   auto matmul_cvalue = BuildMatMulPrim(stack_cnode);
   if (matmul_cvalue == nullptr) {
     MS_LOG(ERROR) << "new MatMul failed";
@@ -199,6 +274,7 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
   std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input};
   // batchmatmul right node may be const
+  bool right_transpose = false;
   if (right_reshape_node->isa<Parameter>()) {
     auto rmatmul_paramter = func_graph->add_parameter();
     if (GetRightMatmulInputParamter(stack_cnode, rmatmul_paramter) != RET_OK) {
@@ -206,24 +282,47 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
       return node;
     }
     auto prim = GetValueNode<PrimitiveCPtr>(matmul_value_node);
     MS_ASSERT(prim != nullptr);
     auto prim_matmul = prim->cast<std::shared_ptr<mindspore::ops::MatMul>>();
-    MS_ASSERT(prim_matmul != nullptr);
+    if (prim_matmul == nullptr) {
+      MS_LOG(ERROR) << "primitive is null";
+      return nullptr;
+    }
     prim_matmul->set_transpose_b(true);
     matmul_inputs.push_back(rmatmul_paramter);
+  } else if (ConnectTransposeConcat(right_reshape_node)) {
+    right_transpose = true;
+    auto ret = ResetReshapeParameters(right_reshape_node);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "reset reshape parameters failed";
+      return nullptr;
+    }
+    matmul_inputs.push_back(right_reshape_node);
   } else {
     auto right_reshape_cnode = right_reshape_node->cast<CNodePtr>();
     MS_ASSERT(right_reshape_cnode->inputs().size() > 1);
     auto right_transpose_node = right_reshape_cnode->input(1);
     auto right_transpose_cnode = right_transpose_node->cast<CNodePtr>();
+    if (right_transpose_cnode == nullptr) {
+      MS_LOG(ERROR) << "transpose cnode is null";
+      return nullptr;
+    }
     auto right_slice_node = right_transpose_cnode->input(1);
     auto right_slice_cnode = right_slice_node->cast<CNodePtr>();
+    if (right_slice_cnode == nullptr) {
+      MS_LOG(ERROR) << "slice cnode is null";
+      return nullptr;
+    }
     auto right_matmul_input = right_slice_cnode->input(1);
     matmul_inputs.push_back(right_matmul_input);
   }
   auto matmul_cnode = func_graph->NewCNode(matmul_inputs);
   matmul_cnode->set_fullname_with_scope("matmul_" + stack_cnode->fullname_with_scope());
   matmul_cnode->set_abstract(stack_cnode->abstract()->Clone());
+  if (right_transpose) {
+    auto matmul_primitive = GetValueNode<std::shared_ptr<ops::MatMul>>(matmul_cnode->input(0));
+    matmul_primitive->set_transpose_b(true);
+  }
   MS_LOG(INFO) << "stack node:" << stack_cnode->fullname_with_scope() << " batchmatmul fusion success";
   return matmul_cnode;
 }
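
Taken together, the pass rewrites a Stack of per-timestep FullConnection nodes into a single BatchMatMul. When the right operand arrives through the transpose/concat path, ResetReshapeParameters widens the Reshape's target shape with a leading 1 so the fused operand keeps a batch axis. A small self-contained sketch of that shape rewrite (hypothetical helper; it mirrors the tensor_data copy loop above):

#include <cstdio>
#include <vector>

// Hypothetical helper showing the effect of ResetReshapeParameters on the
// reshape target: a leading 1 is prepended and the old dims shift right,
// e.g. {7, 64} -> {1, 7, 64}.
std::vector<int> PrependBatchDim(const std::vector<int> &reshape_target) {
  std::vector<int> out(reshape_target.size() + 1);
  out[0] = 1;  // the new batch axis consumed by BatchMatMul
  for (size_t i = 0; i < reshape_target.size(); ++i) {
    out[i + 1] = reshape_target[i];
  }
  return out;
}

int main() {
  for (int d : PrependBatchDim({7, 64})) {
    printf("%d ", d);  // prints: 1 7 64
  }
  return 0;
}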