forked from mindspore-Ecosystem/mindspore

bias add fusion

This commit is contained in:
parent 8487ce0c09
commit 08068e6a7a
@@ -417,6 +417,15 @@ int CheckIfNodeIsParam(const AnfNodePtr &node) {
   return lite::RET_OK;
 }
 
+int CheckIfNodeIsParamOrValue(const AnfNodePtr &node) {
+  if (node == nullptr || (!utils::isa<ParameterPtr>(node) && !utils::isa<ValueNode>(node))) {
+    MS_LOG(DEBUG) << "The node is not a parameter or value node.";
+    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
+    return lite::RET_INVALID_OP_ATTR;
+  }
+  return lite::RET_OK;
+}
+
 int CheckInputSize(const CNodePtr &node, const int size) {
   if (static_cast<int>(node->inputs().size()) != size) {
     MS_LOG(ERROR) << "The input size of node must be " << size << ", but it is " << node->inputs().size();
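Note: the new guard widens CheckIfNodeIsParam so the fusion can also accept a bias weight that is a constant ValueNode instead of a graph Parameter. A minimal self-contained sketch of the same guard shape, with a std::variant standing in for the AnfNode hierarchy (all names here are illustrative stand-ins, not MindSpore API):

    #include <iostream>
    #include <variant>

    struct Parameter {};
    struct ValueNode {};
    struct CNode {};  // any other node kind must be rejected
    using AnfNode = std::variant<Parameter, ValueNode, CNode>;

    enum RetCode { RET_OK = 0, RET_INVALID_OP_ATTR = -1 };

    // Accept only Parameter or ValueNode inputs, the same contract as
    // CheckIfNodeIsParamOrValue above (hypothetical stand-in, not the real helper).
    RetCode CheckNodeIsParamOrValue(const AnfNode *node) {
      if (node == nullptr ||
          (!std::holds_alternative<Parameter>(*node) && !std::holds_alternative<ValueNode>(*node))) {
        std::cerr << "The node is not a parameter or value node.\n";
        return RET_INVALID_OP_ATTR;
      }
      return RET_OK;
    }

    int main() {
      AnfNode p = Parameter{};
      AnfNode c = CNode{};
      std::cout << CheckNodeIsParamOrValue(&p) << "\n";  // 0  (accepted)
      std::cout << CheckNodeIsParamOrValue(&c) << "\n";  // -1 (rejected)
    }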
@@ -534,6 +543,31 @@ bool IsParamNode(const BaseRef &n) {
   return tensor->data_c() != nullptr;
 }
 
+bool IsParamOrValueNodeWithData(const BaseRef &n) {
+  if (utils::isa<ValueNode>(n)) {
+    auto value_node = utils::cast<ValueNodePtr>(n);
+    auto value = value_node->value();
+    if (value->isa<tensor::Tensor>()) {
+      auto tensor = value->cast<tensor::TensorPtr>();
+      if (tensor == nullptr || tensor->data_c() == nullptr) {
+        return false;
+      }
+      return true;
+    } else {
+      return false;
+    }
+  }
+  if (utils::isa<ParameterPtr>(n)) {
+    auto param = utils::cast<ParameterPtr>(n)->default_param();
+    auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param);
+    if (tensor == nullptr || tensor->data_c() == nullptr) {
+      return false;
+    }
+    return true;
+  }
+  return false;
+}
+
 bool IsConvNode(const BaseRef &n) {
   if (utils::isa<AnfNodePtr>(n)) {
     auto anf_node = utils::cast<AnfNodePtr>(n);
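Note: unlike IsParamNode, this predicate also fires on ValueNode constants, but both branches insist that the tensor buffer (data_c()) is already materialized, because the fusion reads and rewrites those bytes at conversion time. A minimal sketch of the same "constant with materialized data" test over a hypothetical node type (illustrative, not MindSpore API):

    #include <memory>
    #include <optional>
    #include <vector>

    // Stand-ins: a node carries a constant tensor only after folding,
    // and the tensor's buffer may or may not be materialized.
    struct Tensor { std::shared_ptr<std::vector<float>> data; };
    struct Node { std::optional<Tensor> const_tensor; };

    // Mirrors IsParamOrValueNodeWithData: match only when the constant's
    // buffer actually exists, never on a shape-only placeholder.
    bool HasMaterializedData(const Node &n) {
      return n.const_tensor.has_value() && n.const_tensor->data != nullptr;
    }

    int main() {
      Node folded{Tensor{std::make_shared<std::vector<float>>(1, 0.5f)}};
      Node placeholder{Tensor{nullptr}};  // tensor known, bytes absent
      Node dynamic{};                     // not a constant at all
      return (HasMaterializedData(folded) && !HasMaterializedData(placeholder) &&
              !HasMaterializedData(dynamic)) ? 0 : 1;
    }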
@@ -63,6 +63,8 @@ int CheckInputSize(const CNodePtr &node, int size);
 
 int CheckIfNodeIsParam(const AnfNodePtr &node);
 
+int CheckIfNodeIsParamOrValue(const AnfNodePtr &node);
+
 int CheckLeastInputSize(const CNodePtr &node, int size);
 
 ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
@@ -70,6 +72,8 @@ ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, in
 
 bool IsParamNode(const BaseRef &n);
 
+bool IsParamOrValueNodeWithData(const BaseRef &n);
+
 bool IsConvNode(const BaseRef &n);
 
 bool IsPoolingNode(const BaseRef &n);
@@ -39,6 +39,7 @@ bool IsConvExtendNode(const BaseRef &n) {
   }
   return false;
 }
+
 bool IsAddNode(const BaseRef &n) {
   if (utils::isa<AnfNodePtr>(n)) {
     auto anf_node = utils::cast<AnfNodePtr>(n);
@@ -71,6 +72,115 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) {
     return 0;
   }
 }
+
+int GetAddBiasData(const AnfNodePtr &bias_add_weight_node, const int &kernel_nums, float **add_bias_data) {
+  MS_ASSERT(bias_add_weight_node != nullptr);
+  MS_ASSERT(add_bias_data != nullptr);
+  MS_ASSERT(*add_bias_data != nullptr);
+  float *add_weight_data = nullptr;
+  ShapeVector add_weight_shape;
+  if (utils::isa<Parameter>(bias_add_weight_node)) {
+    auto add_weight_param_node = bias_add_weight_node->cast<ParameterPtr>();
+    if (!add_weight_param_node->has_default() || add_weight_param_node->default_param() == nullptr) {
+      MS_LOG(ERROR) << "The bias parameter of " << bias_add_weight_node->fullname_with_scope() << " is nullptr.";
+      return lite::RET_ERROR;
+    }
+    auto add_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(add_weight_param_node->default_param());
+    if (add_weight_tensor == nullptr) {
+      MS_LOG(ERROR) << "The bias data of parameter node " << bias_add_weight_node->fullname_with_scope()
+                    << " is not tensorPtr.";
+      return lite::RET_ERROR;
+    }
+    add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
+    MS_ASSERT(add_weight_data != nullptr);
+    add_weight_shape = add_weight_tensor->shape();
+  } else {
+    MS_ASSERT(utils::isa<ValueNode>(bias_add_weight_node));
+    auto add_weight_value_node = bias_add_weight_node->cast<ValueNodePtr>();
+    auto add_weight_value = add_weight_value_node->value();
+    MS_ASSERT(add_weight_value != nullptr);
+    auto add_weight_tensor = add_weight_value->cast<tensor::TensorPtr>();
+    if (add_weight_tensor == nullptr) {
+      MS_LOG(ERROR) << "The bias data of value node " << bias_add_weight_node->fullname_with_scope()
+                    << " is not tensorPtr.";
+      return lite::RET_ERROR;
+    }
+    add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
+    MS_ASSERT(add_weight_data != nullptr);
+    auto value_abstract = add_weight_value_node->abstract();
+    auto value_abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(value_abstract);
+    add_weight_shape = utils::cast<abstract::ShapePtr>(value_abstract_tensor->BuildShape())->shape();
+  }
+  if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) {
+    for (int i = 0; i < kernel_nums; i++) {
+      (*add_bias_data)[i] = *add_weight_data;
+    }
+  } else {
+    if (EOK != memcpy_s(*add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
+      MS_LOG(ERROR) << "memcpy_s conv_bias_data failed";
+      return lite::RET_ERROR;
+    }
+  }
+  return lite::RET_OK;
+}
+
+int GetNewConvBiasData(const AnfNodePtr &conv_bias_node, const int &kernel_nums, const float *add_bias_data) {
+  MS_ASSERT(add_bias_data != nullptr);
+  MS_ASSERT(conv_bias_node != nullptr);
+  if (utils::isa<Parameter>(conv_bias_node)) {
+    auto conv_bias_param_node = conv_bias_node->cast<ParameterPtr>();
+    if (!conv_bias_param_node->has_default() || conv_bias_param_node->default_param() == nullptr) {
+      MS_LOG(ERROR) << "The bias parameter of " << conv_bias_node->fullname_with_scope() << " is nullptr.";
+      return lite::RET_ERROR;
+    }
+    auto conv_bias_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_bias_param_node->default_param());
+    if (conv_bias_tensor == nullptr || conv_bias_tensor->shape().empty() ||
+        conv_bias_tensor->shape()[0] != kernel_nums) {
+      MS_LOG(ERROR) << "conv_bias_node shape error";
+      return lite::RET_ERROR;
+    }
+    auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c());
+    MS_ASSERT(conv_bias_data != nullptr);
+    for (int i = 0; i < kernel_nums; i++) {
+      conv_bias_data[i] += add_bias_data[i];
+    }
+  } else {
+    MS_ASSERT(utils::isa<ValueNode>(conv_bias_node));
+    auto conv_bias_value_node = conv_bias_node->cast<ValueNodePtr>();
+    auto conv_bias_value = conv_bias_value_node->value();
+    MS_ASSERT(conv_bias_value != nullptr);
+    auto conv_bias_tensor = conv_bias_value->cast<tensor::TensorPtr>();
+    if (conv_bias_tensor == nullptr) {
+      MS_LOG(ERROR) << "The bias data of value node " << conv_bias_node->fullname_with_scope() << " is not tensorPtr.";
+      return lite::RET_ERROR;
+    }
+    auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c());
+    MS_ASSERT(conv_bias_data != nullptr);
+    for (int i = 0; i < kernel_nums; i++) {
+      conv_bias_data[i] += add_bias_data[i];
+    }
+  }
+  return lite::RET_OK;
+}
+
+tensor::TensorPtr GetConvWeightTensor(const AnfNodePtr &conv_weight_node) {
+  tensor::TensorPtr conv_weight_tensor;
+  if (utils::isa<ValueNode>(conv_weight_node)) {
+    auto conv_weight_value_node = conv_weight_node->cast<ValueNodePtr>();
+    auto conv_weight_value = conv_weight_value_node->value();
+    MS_ASSERT(conv_weight_value != nullptr);
+    conv_weight_tensor = conv_weight_value->cast<tensor::TensorPtr>();
+    MS_ASSERT(conv_weight_tensor != nullptr);
+  } else {
+    MS_ASSERT(utils::isa<Parameter>(conv_weight_node));
+    auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
+    MS_ASSERT(conv_weight_param != nullptr);
+    conv_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_weight_param);
+    MS_ASSERT(conv_weight_tensor != nullptr);
+  }
+  return conv_weight_tensor;
+}
+
 int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) {
   MS_ASSERT(func_graph != nullptr);
   MS_ASSERT(conv_node != nullptr);
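Note: GetAddBiasData treats an empty shape or a shape of [1] as a scalar and broadcasts it across all kernel_nums output channels; any other shape is copied verbatim with memcpy_s. GetNewConvBiasData then accumulates that buffer into the existing conv bias in place, and GetConvWeightTensor fetches the weight tensor from either node kind. A self-contained sketch of the broadcast-or-copy and accumulate steps (illustrative names; plain std::memcpy in place of the bounds-checked memcpy_s):

    #include <cstring>
    #include <iostream>
    #include <vector>

    // A rank-0 (scalar) bias is modeled as a single-element vector here.
    // Mirrors GetAddBiasData's shape rule: scalar -> broadcast, else copy.
    void FillAddBias(const std::vector<float> &weight, int kernel_nums, float *out) {
      if (weight.size() == 1) {
        for (int i = 0; i < kernel_nums; ++i) out[i] = weight[0];  // broadcast
      } else {
        std::memcpy(out, weight.data(), kernel_nums * sizeof(float));  // sizes must match
      }
    }

    int main() {
      const int kernel_nums = 4;
      std::vector<float> conv_bias{1.f, 2.f, 3.f, 4.f};
      std::vector<float> add_bias(kernel_nums);
      FillAddBias({0.5f}, kernel_nums, add_bias.data());  // scalar BiasAdd weight
      // GetNewConvBiasData's accumulate step: conv_bias[i] += add_bias[i]
      for (int i = 0; i < kernel_nums; ++i) conv_bias[i] += add_bias[i];
      for (float v : conv_bias) std::cout << v << ' ';    // 1.5 2.5 3.5 4.5
      std::cout << '\n';
    }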
@@ -97,45 +207,30 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co
     return lite::RET_MEMORY_FAILED;
   }
   auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX);
-  if (CheckIfNodeIsParam(bias_add_weight) != lite::RET_OK) {
+  if (CheckIfNodeIsParamOrValue(bias_add_weight) != lite::RET_OK) {
     delete[] add_bias_data;
     return lite::RET_INVALID_OP_ATTR;
   }
-  auto add_weight_param = bias_add_weight->cast<ParameterPtr>()->default_param();
-  auto add_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(add_weight_param);
-  auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
-  auto add_weight_shape = add_weight_tensor->shape();
-  if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) {
-    for (int i = 0; i < kernel_nums; i++) {
-      add_bias_data[i] = *add_weight_data;
-    }
-  } else {
-    if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
-      MS_LOG(ERROR) << "memcpy_s conv_bias_data failed";
-      delete[] add_bias_data;
-      return lite::RET_MEMORY_FAILED;
-    }
+  if (GetAddBiasData(bias_add_weight, kernel_nums, &add_bias_data) != lite::RET_OK) {
+    delete[] add_bias_data;
+    return lite::RET_INVALID_OP_ATTR;
   }
   if (conv_bias_node != nullptr) {
-    if (CheckIfNodeIsParam(conv_bias_node) != lite::RET_OK) {
+    if (CheckIfNodeIsParamOrValue(conv_bias_node) != lite::RET_OK) {
       delete[] add_bias_data;
       return lite::RET_INVALID_OP_ATTR;
     }
-    auto conv_bias_param = conv_bias_node->cast<ParameterPtr>()->default_param();
-    auto conv_bias_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_bias_param);
-    if (conv_bias_tensor->shape().empty() || conv_bias_tensor->shape()[0] != kernel_nums) {
-      MS_LOG(ERROR) << "conv_bias_node shape error";
+    if (GetNewConvBiasData(conv_bias_node, kernel_nums, add_bias_data) != lite::RET_OK) {
      delete[] add_bias_data;
      return lite::RET_INVALID_OP_ATTR;
    }
-    auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c());
-    for (int i = 0; i < kernel_nums; i++) {
-      conv_bias_data[i] += add_bias_data[i];
-    }
     delete[] add_bias_data;
   } else {
-    auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
-    auto conv_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_weight_param);
+    if (CheckIfNodeIsParamOrValue(conv_weight_node) != lite::RET_OK) {
+      delete[] add_bias_data;
+      return lite::RET_INVALID_OP_ATTR;
+    }
+    tensor::TensorPtr conv_weight_tensor = GetConvWeightTensor(conv_weight_node);
     auto conv_new_bias = AddNewBiasNode(add_bias_data, func_graph, kernel_nums, conv_weight_tensor);
     conv_new_bias->set_name(conv_node->fullname_with_scope() + "_bias");
     conv_node->add_input(conv_new_bias);
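Note: after this refactor GenConvNewBias is mostly control flow and cleanup of add_bias_data; the arithmetic lives in the helpers above. The fusion itself is valid because a convolution bias and a following BiasAdd combine linearly per output channel: Conv(x; W, b) + c == Conv(x; W, b + c). A tiny check of that identity for one channel of a 1x1 convolution (illustrative only):

    #include <cassert>

    int main() {
      float w = 2.f, x = 3.f, b = 1.f, c = 0.5f;
      float unfused = (w * x + b) + c;  // Conv followed by BiasAdd
      float fused = w * x + (b + c);    // Conv with the pre-summed bias
      assert(unfused == fused);         // exact here: all values are representable
      return 0;
    }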
@@ -146,7 +241,7 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co
 const BaseRef ConvBiasaddFusion::DefinePattern() const {
   auto conv_var = std::make_shared<CondVar>(IsConvExtendNode);
   auto add_var = std::make_shared<CondVar>(IsAddNode);
-  auto weight_var = std::make_shared<CondVar>(IsParamNode);
+  auto weight_var = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
   return VectorRef({add_var, conv_var, weight_var});
 }
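Note: the pattern is written root-first, so VectorRef({add_var, conv_var, weight_var}) matches an Add node whose inputs are a convolution-like node and the bias weight. Swapping IsParamNode for IsParamOrValueNodeWithData widens only the weight slot, roughly:

    Before: matched               After: additionally matches
      Add                           Add
      |-- Conv2D(...)               |-- Conv2D(...)
      `-- Parameter(bias)           `-- ValueNode(bias tensor with data)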