|
|
|
@ -23,9 +23,11 @@
|
|
|
|
|
#include "tools/converter/quant_param_holder.h"
|
|
|
|
|
#include "tools/optimizer/common/gllo_utils.h"
|
|
|
|
|
#include "securec/include/securec.h"
|
|
|
|
|
#include "src/common/log_adapter.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore::opt {
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr int64_t kFcRightInputDims = 3;
|
|
|
|
|
bool IsStackNode(const BaseRef &n) {
|
|
|
|
|
if (utils::isa<AnfNodePtr>(n)) {
|
|
|
|
|
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimStack);
|
|
|
|
@ -63,10 +65,14 @@ void *GetInputAddr(const AnfNodePtr &node, size_t input_index) {
|
|
|
|
|
}
|
|
|
|
|
STATUS GetRightMatmulInputParamter(const CNodePtr &stack_node, const ParameterPtr &rmatmul_input) {
|
|
|
|
|
MS_ASSERT(stack_node != nullptr);
|
|
|
|
|
MS_ASSERT(right_matmul_input != nullptr);
|
|
|
|
|
MS_ASSERT(rmatmul_input != nullptr);
|
|
|
|
|
auto joint_fullconnect_size = stack_node->inputs().size() - 1;
|
|
|
|
|
auto fc = stack_node->input(1)->cast<CNodePtr>();
|
|
|
|
|
auto fc_weight = fc->input(2)->cast<ParameterPtr>();
|
|
|
|
|
if (fc_weight == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "fully-connected weight is null";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
auto fc_weight_param = std::dynamic_pointer_cast<tensor::Tensor>(fc_weight->default_param());
|
|
|
|
|
auto tensor_size = fc_weight_param->Size();
|
|
|
|
|
auto rmatmul_input_shape = fc_weight_param->shape();
|
|
|
|
@ -157,12 +163,74 @@ const BaseRef BatchMatMulFusion::DefinePattern() const {
|
|
|
|
|
return VectorRef({pack_var, left_fullconnect_var, right_fullconnect_var, other_fullconnect_var});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ConnectTransposeConcat(const AnfNodePtr &node) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
if (cnode == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "cnode is null";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto right_transpose_node = cnode->input(1);
|
|
|
|
|
auto right_transpose_cnode = right_transpose_node->cast<CNodePtr>();
|
|
|
|
|
if (right_transpose_cnode == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "cnode is null";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto front_node = right_transpose_cnode->input(1);
|
|
|
|
|
auto front_cnode = front_node->cast<CNodePtr>();
|
|
|
|
|
if (front_cnode == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "cnode is null";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (CheckPrimitiveType(right_transpose_cnode, prim::kPrimTranspose) &&
|
|
|
|
|
(CheckPrimitiveType(front_cnode, prim::kPrimTranspose) || CheckPrimitiveType(front_cnode, prim::kPrimConcat))) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int ResetReshapeParameters(const AnfNodePtr &reshape_node) {
|
|
|
|
|
auto reshape_cnode = reshape_node->cast<CNodePtr>();
|
|
|
|
|
MS_ASSERT(reshape_cnode != nullptr);
|
|
|
|
|
auto reshape_shape_param = reshape_cnode->input(kInputIndexTwo)->cast<ParameterPtr>();
|
|
|
|
|
MS_ASSERT(reshape_shape_param != nullptr);
|
|
|
|
|
auto shape_tensor = std::dynamic_pointer_cast<tensor::Tensor>(reshape_shape_param->default_param());
|
|
|
|
|
auto rmatmul_input_shape = shape_tensor->shape();
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> shape(1, 0);
|
|
|
|
|
if (rmatmul_input_shape.size() <= 0) {
|
|
|
|
|
MS_LOG(ERROR) << "Create tensor info failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
} else if (shape[0] < kFcRightInputDims) {
|
|
|
|
|
shape[0] = rmatmul_input_shape[0] + 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto tensor_info = std::make_shared<tensor::Tensor>(shape_tensor->data_type(), shape);
|
|
|
|
|
if (tensor_info == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Create tensor info failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int *tensor_data = reinterpret_cast<int *>(tensor_info->data_c());
|
|
|
|
|
tensor_data[0] = 1;
|
|
|
|
|
int *reshape_data = reinterpret_cast<int *>(shape_tensor->data_c());
|
|
|
|
|
for (int64_t i = 1; i < shape[0]; ++i) {
|
|
|
|
|
tensor_data[i] = reshape_data[i - 1];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
lite::InitParameterFromTensorInfo(reshape_shape_param, tensor_info);
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// slice +fullconnect ->batchmatmul
|
|
|
|
|
const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
|
|
|
|
const EquivPtr &) const {
|
|
|
|
|
MS_ASSERT(func_graph != nullptr);
|
|
|
|
|
MS_ASSERT(node != nullptr);
|
|
|
|
|
auto stack_cnode = node->cast<CNodePtr>();
|
|
|
|
|
if (stack_cnode == nullptr) {
|
|
|
|
|
MS_LOG(WARNING) << "stack cnode is null";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
// check stack node all inputs must fullconnect
|
|
|
|
|
for (size_t i = 1; i < stack_cnode->inputs().size(); i++) {
|
|
|
|
|
auto input_node = stack_cnode->input(i);
|
|
|
|
@ -172,10 +240,13 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto fullconnect_node = stack_cnode->input(1);
|
|
|
|
|
MS_ASSERT(fullconnnect_node != nullptr);
|
|
|
|
|
auto fullconnect_cnode = fullconnect_node->cast<CNodePtr>();
|
|
|
|
|
MS_ASSERT(fullconnect_cnode->inputs().size() == 3);
|
|
|
|
|
auto left_slice_node = fullconnect_cnode->input(1);
|
|
|
|
|
if (left_slice_node == nullptr) {
|
|
|
|
|
MS_LOG(WARNING) << "slice node is null";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto left_slice_cnode = left_slice_node->cast<CNodePtr>();
|
|
|
|
|
if (!CheckPrimitiveType(left_slice_cnode, prim::kPrimSliceFusion)) {
|
|
|
|
|
if (!CheckPrimitiveType(left_slice_cnode, prim::kPrimReshape)) {
|
|
|
|
@ -189,6 +260,10 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
|
|
|
|
|
}
|
|
|
|
|
auto left_matmul_input = left_slice_cnode->input(1);
|
|
|
|
|
auto right_reshape_node = fullconnect_cnode->input(2);
|
|
|
|
|
if (left_matmul_input == nullptr || right_reshape_node == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "matmul's input is null";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto matmul_cvalue = BuildMatMulPrim(stack_cnode);
|
|
|
|
|
if (matmul_cvalue == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new MatMul failed";
|
|
|
|
@ -199,6 +274,7 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
|
|
|
|
|
std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input};
|
|
|
|
|
|
|
|
|
|
// batchmatmul right node may be const
|
|
|
|
|
bool right_transpose = false;
|
|
|
|
|
if (right_reshape_node->isa<Parameter>()) {
|
|
|
|
|
auto rmatmul_paramter = func_graph->add_parameter();
|
|
|
|
|
if (GetRightMatmulInputParamter(stack_cnode, rmatmul_paramter) != RET_OK) {
|
|
|
|
@ -206,24 +282,47 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
|
|
|
|
|
return node;
|
|
|
|
|
}
|
|
|
|
|
auto prim = GetValueNode<PrimitiveCPtr>(matmul_value_node);
|
|
|
|
|
MS_ASSERT(prim != nullptr);
|
|
|
|
|
auto prim_matmul = prim->cast<std::shared_ptr<mindspore::ops::MatMul>>();
|
|
|
|
|
MS_ASSERT(prim_matmul != nullptr);
|
|
|
|
|
if (prim_matmul == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "primitive is null";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
prim_matmul->set_transpose_b(true);
|
|
|
|
|
matmul_inputs.push_back(rmatmul_paramter);
|
|
|
|
|
} else if (ConnectTransposeConcat(right_reshape_node)) {
|
|
|
|
|
right_transpose = true;
|
|
|
|
|
auto ret = ResetReshapeParameters(right_reshape_node);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "reset reshape parameters failed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
matmul_inputs.push_back(right_reshape_node);
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
auto right_reshape_cnode = right_reshape_node->cast<CNodePtr>();
|
|
|
|
|
MS_ASSERT(right_reshape_cnode->inputs().size() > 1);
|
|
|
|
|
auto right_transpose_node = right_reshape_cnode->input(1);
|
|
|
|
|
auto right_transpose_cnode = right_transpose_node->cast<CNodePtr>();
|
|
|
|
|
if (right_transpose_cnode == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "transpose cnode is null";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto right_slice_node = right_transpose_cnode->input(1);
|
|
|
|
|
auto right_slice_cnode = right_slice_node->cast<CNodePtr>();
|
|
|
|
|
if (right_slice_cnode == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "slice cnode is null";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto right_matmul_input = right_slice_cnode->input(1);
|
|
|
|
|
matmul_inputs.push_back(right_matmul_input);
|
|
|
|
|
}
|
|
|
|
|
auto matmul_cnode = func_graph->NewCNode(matmul_inputs);
|
|
|
|
|
matmul_cnode->set_fullname_with_scope("matmul_" + stack_cnode->fullname_with_scope());
|
|
|
|
|
matmul_cnode->set_abstract(stack_cnode->abstract()->Clone());
|
|
|
|
|
if (right_transpose) {
|
|
|
|
|
auto matmul_primitive = GetValueNode<std::shared_ptr<ops::MatMul>>(matmul_cnode->input(0));
|
|
|
|
|
matmul_primitive->set_transpose_b(true);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "stack node:" << stack_cnode->fullname_with_scope() << " batchmatmul fusion success";
|
|
|
|
|
return matmul_cnode;
|
|
|
|
|
}
|
|
|
|
|