bias add fusion

zhang__sss 2021-04-21 16:28:20 +08:00
parent 8487ce0c09
commit 08068e6a7a
3 changed files with 160 additions and 27 deletions


@@ -417,6 +417,15 @@ int CheckIfNodeIsParam(const AnfNodePtr &node) {
return lite::RET_OK;
}
int CheckIfNodeIsParamOrValue(const AnfNodePtr &node) {
if (node == nullptr || (!utils::isa<ParameterPtr>(node) && !utils::isa<ValueNode>(node))) {
MS_LOG(DEBUG) << "The node is not a parameter or a value node.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
return lite::RET_INVALID_OP_ATTR;
}
return lite::RET_OK;
}
int CheckInputSize(const CNodePtr &node, const int size) {
if (static_cast<int>(node->inputs().size()) != size) {
MS_LOG(ERROR) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
@@ -534,6 +543,31 @@ bool IsParamNode(const BaseRef &n) {
return tensor->data_c() != nullptr;
}
bool IsParamOrValueNodeWithData(const BaseRef &n) {
if (utils::isa<ValueNode>(n)) {
auto value_node = utils::cast<ValueNodePtr>(n);
auto value = value_node->value();
if (value == nullptr || !value->isa<tensor::Tensor>()) {
return false;
}
auto tensor = value->cast<tensor::TensorPtr>();
return tensor != nullptr && tensor->data_c() != nullptr;
}
if (utils::isa<ParameterPtr>(n)) {
auto param = utils::cast<ParameterPtr>(n)->default_param();
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param);
if (tensor == nullptr || tensor->data_c() == nullptr) {
return false;
}
return true;
}
return false;
}
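
The new IsParamOrValueNodeWithData predicate above accepts a constant whether it is stored as a Parameter carrying a default tensor or as a ValueNode wrapping a tensor, and returns true only when that tensor actually holds data. A minimal standalone sketch of the same idea in plain C++; Tensor, Parameter, ValueNode, Node, and HasConstData are toy stand-ins for illustration, not the MindSpore types:

#include <memory>
#include <variant>
#include <vector>

// Toy model of the two places a constant tensor can live in the graph.
struct Tensor { std::vector<float> data; };
struct Parameter { std::shared_ptr<Tensor> default_param; };
struct ValueNode { std::shared_ptr<Tensor> value; };
using Node = std::variant<Parameter, ValueNode>;

// Mirrors the shape of IsParamOrValueNodeWithData: true only when the
// node carries real tensor data, whichever representation it uses.
bool HasConstData(const Node &n) {
  if (const auto *p = std::get_if<Parameter>(&n)) {
    return p->default_param != nullptr && !p->default_param->data.empty();
  }
  const auto *v = std::get_if<ValueNode>(&n);
  return v != nullptr && v->value != nullptr && !v->value->data.empty();
}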
bool IsConvNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
auto anf_node = utils::cast<AnfNodePtr>(n);


@@ -63,6 +63,8 @@ int CheckInputSize(const CNodePtr &node, int size);
int CheckIfNodeIsParam(const AnfNodePtr &node);
int CheckIfNodeIsParamOrValue(const AnfNodePtr &node);
int CheckLeastInputSize(const CNodePtr &node, int size);
ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
@@ -70,6 +72,8 @@ ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
bool IsParamNode(const BaseRef &n);
bool IsParamOrValueNodeWithData(const BaseRef &n);
bool IsConvNode(const BaseRef &n);
bool IsPoolingNode(const BaseRef &n);


@@ -39,6 +39,7 @@ bool IsConvExtendNode(const BaseRef &n) {
}
return false;
}
bool IsAddNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
auto anf_node = utils::cast<AnfNodePtr>(n);
@@ -71,6 +72,115 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) {
return 0;
}
}
int GetAddBiasData(const AnfNodePtr &bias_add_weight_node, const int &kernel_nums, float **add_bias_data) {
MS_ASSERT(bias_add_weight_node != nullptr);
MS_ASSERT(add_bias_data != nullptr);
MS_ASSERT(*add_bias_data != nullptr);
float *add_weight_data = nullptr;
ShapeVector add_weight_shape;
if (utils::isa<Parameter>(bias_add_weight_node)) {
auto add_weight_param_node = bias_add_weight_node->cast<ParameterPtr>();
if (!add_weight_param_node->has_default() || add_weight_param_node->default_param() == nullptr) {
MS_LOG(ERROR) << "The bias parameter of " << bias_add_weight_node->fullname_with_scope() << " is nullptr.";
return lite::RET_ERROR;
}
auto add_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(add_weight_param_node->default_param());
if (add_weight_tensor == nullptr) {
MS_LOG(ERROR) << "The bias data of parameter node " << bias_add_weight_node->fullname_with_scope()
<< " is not tensorPtr.";
return lite::RET_ERROR;
}
add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
MS_ASSERT(add_weight_data != nullptr);
add_weight_shape = add_weight_tensor->shape();
} else {
MS_ASSERT(utils::isa<ValueNode>(bias_add_weight_node));
auto add_weight_value_node = bias_add_weight_node->cast<ValueNodePtr>();
auto add_weight_value = add_weight_value_node->value();
MS_ASSERT(add_weight_value != nullptr);
auto add_weight_tensor = add_weight_value->cast<tensor::TensorPtr>();
if (add_weight_tensor == nullptr) {
MS_LOG(ERROR) << "The bias data of value node " << bias_add_weight_node->fullname_with_scope()
<< " is not tensorPtr.";
return lite::RET_ERROR;
}
add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
MS_ASSERT(add_weight_data != nullptr);
auto value_abstract = add_weight_value_node->abstract();
MS_ASSERT(value_abstract != nullptr);
auto value_abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(value_abstract);
MS_ASSERT(value_abstract_tensor != nullptr);
add_weight_shape = utils::cast<abstract::ShapePtr>(value_abstract_tensor->BuildShape())->shape();
}
if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) {
for (int i = 0; i < kernel_nums; i++) {
(*add_bias_data)[i] = *add_weight_data;
}
} else {
if (EOK != memcpy_s(*add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
MS_LOG(ERROR) << "memcpy_s conv_bias_data failed";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
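
The shape test at the end of GetAddBiasData implements BiasAdd broadcast semantics: a scalar bias (empty shape, or shape [1]) is replicated across all kernel_nums output channels, while a per-channel bias is copied through unchanged. A self-contained sketch of just that rule, with std::vector standing in for the tensor plumbing (ExpandBias is an illustrative name, not from this commit):

#include <cstdint>
#include <cstring>
#include <vector>

// Broadcast a scalar BiasAdd weight to every output channel, or copy a
// per-channel weight verbatim; the same rule GetAddBiasData applies.
std::vector<float> ExpandBias(const std::vector<float> &weight,
                              const std::vector<int64_t> &shape, int kernel_nums) {
  std::vector<float> out(kernel_nums);
  if (shape.empty() || (shape.size() == 1 && shape[0] == 1)) {
    for (int i = 0; i < kernel_nums; i++) {
      out[i] = weight[0];  // scalar: one value for all channels
    }
  } else {
    // per-channel: weight is assumed to already hold kernel_nums floats
    std::memcpy(out.data(), weight.data(), kernel_nums * sizeof(float));
  }
  return out;
}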
int GetNewConvBiasData(const AnfNodePtr &conv_bias_node, const int &kernel_nums, const float *add_bias_data) {
MS_ASSERT(add_bias_data != nullptr);
MS_ASSERT(conv_bias_node != nullptr);
if (utils::isa<Parameter>(conv_bias_node)) {
auto conv_bias_param_node = conv_bias_node->cast<ParameterPtr>();
if (!conv_bias_param_node->has_default() || conv_bias_param_node->default_param() == nullptr) {
MS_LOG(ERROR) << "The bias parameter of " << conv_bias_node->fullname_with_scope() << " is nullptr.";
return lite::RET_ERROR;
}
auto conv_bias_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_bias_param_node->default_param());
if (conv_bias_tensor == nullptr || conv_bias_tensor->shape().empty() ||
conv_bias_tensor->shape()[0] != kernel_nums) {
MS_LOG(ERROR) << "conv_bias_node shape error";
return lite::RET_ERROR;
}
auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c());
MS_ASSERT(conv_bias_data != nullptr);
for (int i = 0; i < kernel_nums; i++) {
conv_bias_data[i] += add_bias_data[i];
}
} else {
MS_ASSERT(utils::isa<ValueNode>(conv_bias_node));
auto conv_bias_value_node = conv_bias_node->cast<ValueNodePtr>();
auto conv_bias_value = conv_bias_value_node->value();
MS_ASSERT(conv_bias_value != nullptr);
auto conv_bias_tensor = conv_bias_value->cast<tensor::TensorPtr>();
if (conv_bias_tensor == nullptr) {
MS_LOG(ERROR) << "The bias data of value node " << conv_bias_node->fullname_with_scope() << "is not tensorPtr.";
return lite::RET_ERROR;
}
auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c());
MS_ASSERT(conv_bias_data != nullptr);
for (int i = 0; i < kernel_nums; i++) {
conv_bias_data[i] += add_bias_data[i];
}
}
return lite::RET_OK;
}
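
GetNewConvBiasData leans on the identity that makes the fusion valid: (W*x + conv_bias) + add_bias = W*x + (conv_bias + add_bias), so folding the BiasAdd away only requires adding the two biases per output channel. A tiny illustrative check in plain C++ (the values are chosen to be exactly representable; in general the two orders agree up to float rounding):

#include <cassert>

int main() {
  const float wx = 6.0f;        // stands in for one conv output W*x
  float conv_bias = 0.5f;
  const float add_bias = 0.25f;
  const float unfused = (wx + conv_bias) + add_bias;  // Conv followed by BiasAdd
  conv_bias += add_bias;        // what the pass does to each channel's bias
  const float fused = wx + conv_bias;                 // fused Conv alone
  assert(unfused == fused);
  return 0;
}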
tensor::TensorPtr GetConvWeightTensor(const AnfNodePtr &conv_weight_node) {
tensor::TensorPtr conv_weight_tensor;
if (utils::isa<ValueNode>(conv_weight_node)) {
auto conv_weight_value_node = conv_weight_node->cast<ValueNodePtr>();
auto conv_weight_value = conv_weight_value_node->value();
MS_ASSERT(conv_weight_value != nullptr);
conv_weight_tensor = conv_weight_value->cast<tensor::TensorPtr>();
MS_ASSERT(conv_weight_tensor != nullptr);
} else {
MS_ASSERT(utils::isa<Parameter>(conv_weight_node));
auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
MS_ASSERT(conv_weight_param != nullptr);
conv_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_weight_param);
MS_ASSERT(conv_weight_tensor != nullptr);
}
return conv_weight_tensor;
}
int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(conv_node != nullptr);
@@ -97,45 +207,30 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) {
return lite::RET_MEMORY_FAILED;
}
auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX);
- if (CheckIfNodeIsParam(bias_add_weight) != lite::RET_OK) {
+ if (CheckIfNodeIsParamOrValue(bias_add_weight) != lite::RET_OK) {
delete[] add_bias_data;
return lite::RET_INVALID_OP_ATTR;
}
- auto add_weight_param = bias_add_weight->cast<ParameterPtr>()->default_param();
- auto add_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(add_weight_param);
- auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
- auto add_weight_shape = add_weight_tensor->shape();
- if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) {
- for (int i = 0; i < kernel_nums; i++) {
- add_bias_data[i] = *add_weight_data;
- }
- } else {
- if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
- MS_LOG(ERROR) << "memcpy_s conv_bias_data failed";
- delete[] add_bias_data;
- return lite::RET_MEMORY_FAILED;
- }
+ if (GetAddBiasData(bias_add_weight, kernel_nums, &add_bias_data) != lite::RET_OK) {
+ delete[] add_bias_data;
+ return lite::RET_INVALID_OP_ATTR;
}
if (conv_bias_node != nullptr) {
- if (CheckIfNodeIsParam(conv_bias_node) != lite::RET_OK) {
+ if (CheckIfNodeIsParamOrValue(conv_bias_node) != lite::RET_OK) {
delete[] add_bias_data;
return lite::RET_INVALID_OP_ATTR;
}
- auto conv_bias_param = conv_bias_node->cast<ParameterPtr>()->default_param();
- auto conv_bias_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_bias_param);
- if (conv_bias_tensor->shape().empty() || conv_bias_tensor->shape()[0] != kernel_nums) {
- MS_LOG(ERROR) << "conv_bias_node shape error";
+ if (GetNewConvBiasData(conv_bias_node, kernel_nums, add_bias_data) != lite::RET_OK) {
delete[] add_bias_data;
return lite::RET_INVALID_OP_ATTR;
}
- auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->data_c());
- for (int i = 0; i < kernel_nums; i++) {
- conv_bias_data[i] += add_bias_data[i];
- }
delete[] add_bias_data;
} else {
- auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
- auto conv_weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(conv_weight_param);
+ if (CheckIfNodeIsParamOrValue(conv_weight_node) != lite::RET_OK) {
+ delete[] add_bias_data;
+ return lite::RET_INVALID_OP_ATTR;
+ }
+ tensor::TensorPtr conv_weight_tensor = GetConvWeightTensor(conv_weight_node);
auto conv_new_bias = AddNewBiasNode(add_bias_data, func_graph, kernel_nums, conv_weight_tensor);
conv_new_bias->set_name(conv_node->fullname_with_scope() + "_bias");
conv_node->add_input(conv_new_bias);
@@ -146,7 +241,7 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) {
const BaseRef ConvBiasaddFusion::DefinePattern() const {
auto conv_var = std::make_shared<CondVar>(IsConvExtendNode);
auto add_var = std::make_shared<CondVar>(IsAddNode);
- auto weight_var = std::make_shared<CondVar>(IsParamNode);
+ auto weight_var = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
return VectorRef({add_var, conv_var, weight_var});
}
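
With DefinePattern updated, the pass still matches a BiasAdd applied directly to a convolution output, but the bias operand (weight_var) is now accepted when it is either a Parameter with default data or a constant ValueNode holding a tensor; the GetAddBiasData, GetNewConvBiasData, and GetConvWeightTensor helpers added above handle both representations.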