!23842 add fullyconnected fusion

Merge pull request !23842 from wangyanling/r1.3

commit 21509c6d8d
@@ -25,8 +25,9 @@ int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_
     return NNACL_PARAM_INVALID;
   }
   for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) {
-    int cmp_value = MSMIN(a_shape[i], b_shape[i]);
-    if (a_shape[i] != b_shape[i] && cmp_value != 1) {
+    int min_value = MSMIN(a_shape[i], b_shape[i]);
+    int max_value = MSMAX(a_shape[i], b_shape[i]);
+    if (max_value % min_value != 0) {
       return NNACL_INPUT_TENSOR_ERROR;
     }
   }
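The relaxed shape check above can be exercised in isolation. The sketch below is illustrative only (BatchDimsCompatible is a made-up name, not part of nnacl): the old code accepted a batch-dim pair only when the dims were equal or one of them was 1, while the new code accepts any pair where the larger dim is an exact multiple of the smaller.

#include <algorithm>
#include <cstdio>

// Illustrative re-implementation of the new check; returns true when the two
// batch dims can be broadcast against each other.
static bool BatchDimsCompatible(int a_dim, int b_dim) {
  int min_value = std::min(a_dim, b_dim);
  int max_value = std::max(a_dim, b_dim);
  return max_value % min_value == 0;
}

int main() {
  std::printf("%d\n", BatchDimsCompatible(4, 2));  // 1: newly accepted (4 is a multiple of 2)
  std::printf("%d\n", BatchDimsCompatible(5, 1));  // 1: classic broadcast, accepted before and after
  std::printf("%d\n", BatchDimsCompatible(4, 3));  // 0: still rejected
  return 0;
}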
@@ -116,15 +116,16 @@ int MatmulFP16CPUKernel::InitBroadcastParams() {
   int out_batch = 1;
   for (size_t i = 0; i < a_shape.size() - kHWDimNumber; ++i) {
     out_batch *= MSMAX(a_shape[i], b_shape[i]);
-    if (a_shape[i] < b_shape[i] && a_shape[i] == 1) {
+    if (a_shape[i] < b_shape[i] && b_shape[i] % a_shape[i] == 0) {
       a_broadcast_ = true;
-    } else if (a_shape[i] > b_shape[i] && b_shape[i] == 1) {
+    } else if (a_shape[i] > b_shape[i] && a_shape[i] % b_shape[i] == 0) {
       b_broadcast_ = true;
     } else if (a_shape[i] != b_shape[i]) {
       MS_LOG(ERROR) << "matmul don't support broadcast for dimension " << a_shape << " and " << b_shape;
       return RET_ERROR;
     }
   }
+
   params_->batch = out_batch;
   return RET_OK;
 }
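To see what InitBroadcastParams now computes, here is a hedged, self-contained sketch. Local variables stand in for the kernel's a_shape/b_shape and the a_broadcast_/b_broadcast_ members, and the real loop additionally skips the trailing kHWDimNumber matrix dims; for batch dims {4, 1} vs {2, 3} it yields out_batch = 12 with both operands marked for broadcast.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> a_shape = {4, 1};  // batch dims of operand A
  std::vector<int> b_shape = {2, 3};  // batch dims of operand B
  int out_batch = 1;
  bool a_broadcast = false;
  bool b_broadcast = false;
  for (size_t i = 0; i < a_shape.size(); ++i) {
    out_batch *= std::max(a_shape[i], b_shape[i]);
    if (a_shape[i] < b_shape[i] && b_shape[i] % a_shape[i] == 0) {
      a_broadcast = true;  // A is tiled along this dim
    } else if (a_shape[i] > b_shape[i] && a_shape[i] % b_shape[i] == 0) {
      b_broadcast = true;  // B is tiled along this dim
    } else if (a_shape[i] != b_shape[i]) {
      std::printf("unsupported broadcast at dim %zu\n", i);
      return 1;
    }
  }
  // Prints: out_batch=12 a_broadcast=1 b_broadcast=1
  std::printf("out_batch=%d a_broadcast=%d b_broadcast=%d\n", out_batch, a_broadcast, b_broadcast);
  return 0;
}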
@@ -104,9 +104,9 @@ int MatmulCPUKernel::InitBroadcastParams() {
   int out_batch = 1;
   for (size_t i = 0; i < a_shape.size() - kHWDimNumber; ++i) {
     out_batch *= MSMAX(a_shape[i], b_shape[i]);
-    if (a_shape[i] < b_shape[i] && a_shape[i] == 1) {
+    if (a_shape[i] < b_shape[i] && b_shape[i] % a_shape[i] == 0) {
       a_broadcast_ = true;
-    } else if (a_shape[i] > b_shape[i] && b_shape[i] == 1) {
+    } else if (a_shape[i] > b_shape[i] && a_shape[i] % b_shape[i] == 0) {
       b_broadcast_ = true;
     } else if (a_shape[i] != b_shape[i]) {
       MS_LOG(ERROR) << "matmul don't support broadcast for dimension " << a_shape << " and " << b_shape;
@@ -106,3 +106,4 @@ hiai_nlu_model_single.pb;3;1,32:1,32:1,32
 fsr_270_mindspore.pb
 fsr_360_mindspore.pb
 fsr_720_mindspore.pb
+tt_raw_h4800_mel80_ms_fe001_ex_20210506_joint_decoder.pb;14;4:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:1,640
@@ -91,3 +91,4 @@ hiai_nlu_model_single.pb;3;1,32:1,32:1,32 3.5
 fsr_270_mindspore.pb 6.0
 fsr_360_mindspore.pb 6.5
 fsr_720_mindspore.pb 2.0
+tt_raw_h4800_mel80_ms_fe001_ex_20210506_joint_decoder.pb;14;4:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:1,640 0.5
@@ -203,3 +203,7 @@ add_uint8.tflite;2
 coco_ssd_mobilenet_v1_1.0.tflite
 hiai_asr_last_e1_cpu_fast_wavenet_batch1_frame1_one_cache_fp32.tflite;2
 hiai_asr_ctc.tflite;2
+# weight quant model
+tt_raw_h4800_mel80_ms_fe001_ex_20210506_encoder.tflite;25;1,15,80:1,15,80:1,15,80:1,15,80:1,15,80:1,15,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,31,80:1,640 8.5
+# weight quant model
+tt_raw_h4800_mel80_ms_fe001_ex_20210506_joint_decoder.tflite;14;4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:1,640:4 10.5
@@ -1,3 +1,3 @@
 mobilenet.tflite 0.5
-Change_input_transformer_20200831_encoder_fp32.tflite 70
+Change_input_transformer_20200831_encoder_fp32.tflite 85
 Change_input_transformer_20200831_decoder_fp32.tflite 35
@@ -23,9 +23,11 @@
 #include "tools/converter/quant_param_holder.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "securec/include/securec.h"
+#include "src/common/log_adapter.h"
 
 namespace mindspore::opt {
 namespace {
+constexpr int64_t kFcRightInputDims = 3;
 bool IsStackNode(const BaseRef &n) {
   if (utils::isa<AnfNodePtr>(n)) {
     return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimStack);
@@ -63,10 +65,14 @@ void *GetInputAddr(const AnfNodePtr &node, size_t input_index) {
 }
 STATUS GetRightMatmulInputParamter(const CNodePtr &stack_node, const ParameterPtr &rmatmul_input) {
   MS_ASSERT(stack_node != nullptr);
-  MS_ASSERT(right_matmul_input != nullptr);
+  MS_ASSERT(rmatmul_input != nullptr);
   auto joint_fullconnect_size = stack_node->inputs().size() - 1;
   auto fc = stack_node->input(1)->cast<CNodePtr>();
   auto fc_weight = fc->input(2)->cast<ParameterPtr>();
+  if (fc_weight == nullptr) {
+    MS_LOG(ERROR) << "fully-connected weight is null";
+    return RET_ERROR;
+  }
   auto fc_weight_param = std::dynamic_pointer_cast<tensor::Tensor>(fc_weight->default_param());
   auto tensor_size = fc_weight_param->Size();
   auto rmatmul_input_shape = fc_weight_param->shape();
@@ -157,12 +163,74 @@ const BaseRef BatchMatMulFusion::DefinePattern() const {
   return VectorRef({pack_var, left_fullconnect_var, right_fullconnect_var, other_fullconnect_var});
 }
 
+bool ConnectTransposeConcat(const AnfNodePtr &node) {
+  auto cnode = node->cast<CNodePtr>();
+  if (cnode == nullptr) {
+    MS_LOG(ERROR) << "cnode is null";
+    return false;
+  }
+  auto right_transpose_node = cnode->input(1);
+  auto right_transpose_cnode = right_transpose_node->cast<CNodePtr>();
+  if (right_transpose_cnode == nullptr) {
+    MS_LOG(ERROR) << "cnode is null";
+    return false;
+  }
+  auto front_node = right_transpose_cnode->input(1);
+  auto front_cnode = front_node->cast<CNodePtr>();
+  if (front_cnode == nullptr) {
+    MS_LOG(ERROR) << "cnode is null";
+    return false;
+  }
+  if (CheckPrimitiveType(right_transpose_cnode, prim::kPrimTranspose) &&
+      (CheckPrimitiveType(front_cnode, prim::kPrimTranspose) || CheckPrimitiveType(front_cnode, prim::kPrimConcat))) {
+    return true;
+  }
+  return false;
+}
+
+int ResetReshapeParameters(const AnfNodePtr &reshape_node) {
+  auto reshape_cnode = reshape_node->cast<CNodePtr>();
+  MS_ASSERT(reshape_cnode != nullptr);
+  auto reshape_shape_param = reshape_cnode->input(kInputIndexTwo)->cast<ParameterPtr>();
+  MS_ASSERT(reshape_shape_param != nullptr);
+  auto shape_tensor = std::dynamic_pointer_cast<tensor::Tensor>(reshape_shape_param->default_param());
+  auto rmatmul_input_shape = shape_tensor->shape();
+
+  std::vector<int64_t> shape(1, 0);
+  if (rmatmul_input_shape.size() <= 0) {
+    MS_LOG(ERROR) << "Create tensor info failed";
+    return RET_ERROR;
+  } else if (shape[0] < kFcRightInputDims) {
+    shape[0] = rmatmul_input_shape[0] + 1;
+  }
+
+  auto tensor_info = std::make_shared<tensor::Tensor>(shape_tensor->data_type(), shape);
+  if (tensor_info == nullptr) {
+    MS_LOG(ERROR) << "Create tensor info failed";
+    return RET_ERROR;
+  }
+
+  int *tensor_data = reinterpret_cast<int *>(tensor_info->data_c());
+  tensor_data[0] = 1;
+  int *reshape_data = reinterpret_cast<int *>(shape_tensor->data_c());
+  for (int64_t i = 1; i < shape[0]; ++i) {
+    tensor_data[i] = reshape_data[i - 1];
+  }
+
+  lite::InitParameterFromTensorInfo(reshape_shape_param, tensor_info);
+  return RET_OK;
+}
+
 // slice +fullconnect ->batchmatmul
 const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                             const EquivPtr &) const {
   MS_ASSERT(func_graph != nullptr);
   MS_ASSERT(node != nullptr);
   auto stack_cnode = node->cast<CNodePtr>();
+  if (stack_cnode == nullptr) {
+    MS_LOG(WARNING) << "stack cnode is null";
+    return nullptr;
+  }
   // check stack node all inputs must fullconnect
   for (size_t i = 1; i < stack_cnode->inputs().size(); i++) {
     auto input_node = stack_cnode->input(i);
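The effect of ResetReshapeParameters is easier to see on concrete data. A minimal sketch, assuming a 2-D reshape target of {640, 256} (made-up numbers): the shape tensor is rewritten with a leading 1 prepended, so the right-hand operand keeps kFcRightInputDims (3) dimensions going into BatchMatMul.

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> reshape_data = {640, 256};  // original reshape target (2-D)
  std::vector<int> tensor_data(reshape_data.size() + 1);
  tensor_data[0] = 1;  // prepend a batch dim of 1, as the copy loop above does
  for (size_t i = 1; i < tensor_data.size(); ++i) {
    tensor_data[i] = reshape_data[i - 1];
  }
  for (int d : tensor_data) {
    std::printf("%d ", d);  // prints: 1 640 256
  }
  std::printf("\n");
  return 0;
}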
@@ -172,10 +240,13 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
     }
   }
   auto fullconnect_node = stack_cnode->input(1);
-  MS_ASSERT(fullconnnect_node != nullptr);
   auto fullconnect_cnode = fullconnect_node->cast<CNodePtr>();
   MS_ASSERT(fullconnect_cnode->inputs().size() == 3);
   auto left_slice_node = fullconnect_cnode->input(1);
+  if (left_slice_node == nullptr) {
+    MS_LOG(WARNING) << "slice node is null";
+    return nullptr;
+  }
   auto left_slice_cnode = left_slice_node->cast<CNodePtr>();
   if (!CheckPrimitiveType(left_slice_cnode, prim::kPrimSliceFusion)) {
     if (!CheckPrimitiveType(left_slice_cnode, prim::kPrimReshape)) {
@@ -189,6 +260,10 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
   }
   auto left_matmul_input = left_slice_cnode->input(1);
   auto right_reshape_node = fullconnect_cnode->input(2);
+  if (left_matmul_input == nullptr || right_reshape_node == nullptr) {
+    MS_LOG(ERROR) << "matmul's input is null";
+    return nullptr;
+  }
   auto matmul_cvalue = BuildMatMulPrim(stack_cnode);
   if (matmul_cvalue == nullptr) {
     MS_LOG(ERROR) << "new MatMul failed";
@@ -199,6 +274,7 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
   std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input};
 
   // batchmatmul right node may be const
+  bool right_transpose = false;
   if (right_reshape_node->isa<Parameter>()) {
     auto rmatmul_paramter = func_graph->add_parameter();
     if (GetRightMatmulInputParamter(stack_cnode, rmatmul_paramter) != RET_OK) {
@@ -206,24 +282,47 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
       return node;
     }
     auto prim = GetValueNode<PrimitiveCPtr>(matmul_value_node);
-    MS_ASSERT(prim != nullptr);
    auto prim_matmul = prim->cast<std::shared_ptr<mindspore::ops::MatMul>>();
-    MS_ASSERT(prim_matmul != nullptr);
+    if (prim_matmul == nullptr) {
+      MS_LOG(ERROR) << "primitive is null";
+      return nullptr;
+    }
     prim_matmul->set_transpose_b(true);
     matmul_inputs.push_back(rmatmul_paramter);
+  } else if (ConnectTransposeConcat(right_reshape_node)) {
+    right_transpose = true;
+    auto ret = ResetReshapeParameters(right_reshape_node);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "reset reshape parameters failed";
+      return nullptr;
+    }
+    matmul_inputs.push_back(right_reshape_node);
+
   } else {
     auto right_reshape_cnode = right_reshape_node->cast<CNodePtr>();
     MS_ASSERT(right_reshape_cnode->inputs().size() > 1);
     auto right_transpose_node = right_reshape_cnode->input(1);
     auto right_transpose_cnode = right_transpose_node->cast<CNodePtr>();
+    if (right_transpose_cnode == nullptr) {
+      MS_LOG(ERROR) << "transpose cnode is null";
+      return nullptr;
+    }
     auto right_slice_node = right_transpose_cnode->input(1);
     auto right_slice_cnode = right_slice_node->cast<CNodePtr>();
+    if (right_slice_cnode == nullptr) {
+      MS_LOG(ERROR) << "slice cnode is null";
+      return nullptr;
+    }
     auto right_matmul_input = right_slice_cnode->input(1);
     matmul_inputs.push_back(right_matmul_input);
   }
   auto matmul_cnode = func_graph->NewCNode(matmul_inputs);
   matmul_cnode->set_fullname_with_scope("matmul_" + stack_cnode->fullname_with_scope());
   matmul_cnode->set_abstract(stack_cnode->abstract()->Clone());
+  if (right_transpose) {
+    auto matmul_primitive = GetValueNode<std::shared_ptr<ops::MatMul>>(matmul_cnode->input(0));
+    matmul_primitive->set_transpose_b(true);
+  }
   MS_LOG(INFO) << "stack node:" << stack_cnode->fullname_with_scope() << " batchmatmul fusion success";
   return matmul_cnode;
 }
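For intuition about why the rewrite is sound, the toy check below (illustrative only; shapes and values are made up) shows that stacking per-batch fully-connected outputs y_b = x_b * W_b^T matches a single batched matmul with transpose_b = true, which is what the fused MatMul node computes.

#include <cstdio>

int main() {
  // batch = 2, M = 1, K = 2, N = 2; weights are stored as [N][K] the way a
  // FullConnection weight is, hence transpose_b = true in the fused MatMul.
  double x[2][1][2] = {{{1, 2}}, {{3, 4}}};
  double w[2][2][2] = {{{1, 0}, {0, 1}}, {{2, 0}, {0, 2}}};
  double y[2][1][2] = {};
  for (int b = 0; b < 2; ++b) {        // one FullConnection per batch slice
    for (int m = 0; m < 1; ++m) {
      for (int n = 0; n < 2; ++n) {
        for (int k = 0; k < 2; ++k) {
          y[b][m][n] += x[b][m][k] * w[b][n][k];  // w indexed [n][k]: B transposed
        }
      }
    }
  }
  // Stacking the two results reproduces BatchMatMul(x, w, transpose_b=true).
  std::printf("%g %g | %g %g\n", y[0][0][0], y[0][0][1], y[1][0][0], y[1][0][1]);
  // prints: 1 2 | 6 8
  return 0;
}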