add anf pass

This commit is contained in:
zhengjun10 2020-08-01 19:37:29 +08:00
parent d7bc28dcde
commit 1561462925
15 changed files with 699 additions and 175 deletions

View File

@ -22,11 +22,8 @@
using PrimitiveTValuePtr = std::shared_ptr<mindspore::lite::PrimitiveTValue>;
namespace mindspore {
namespace opt {
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
if (utils::isa<Var>(a) && utils::isa<Parameter>(b)) {
return true;
} else if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
auto a_node = utils::cast<AnfNodePtr>(a);
auto b_node = utils::cast<AnfNodePtr>(b);
MS_EXCEPTION_IF_NULL(a_node);
@ -80,6 +77,7 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
auto b_value_node_ptr = b.m_ptr->cast<PrimitiveTValuePtr>();
return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type;
}
return a == b;
}
@ -202,22 +200,72 @@ void CheckIfVarIsNull(const VarPtr &var) {
}
}
void CheckIfNodeIsParam(const AnfNodePtr &node) {
if (node != nullptr && !utils::isa<ParameterPtr>(node)) {
MS_LOG(EXCEPTION) << "The Node is not param.";
}
}
void CheckInputSize(const CNodePtr &node, const int size) {
if (node->inputs().size() != size) {
MS_LOG(EXCEPTION) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
}
}
schema::PrimitiveType GetCNodeType(const CNodePtr &node) {
auto value_primitive = node->input(0);
auto value_node = value_primitive->cast<ValueNodePtr>();
MS_ASSERT(value_node != nullptr);
auto value = value_node->value();
MS_ASSERT(value != nullptr);
void CheckLeastInputSize(const CNodePtr &node, const int size) {
if (node->inputs().size() < size) {
MS_LOG(EXCEPTION) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
}
}
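// Creates a new bias Parameter with shape {kernel_num} that stores bias_data directly;
// its element type follows the given conv weight tensor.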
AnfNodePtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
const ParamValueLitePtr &weight_tensor) {
auto bias_parameter = func_graph->add_parameter();
MS_ASSERT(bias_parameter != nullptr);
std::vector<int> shape = {kernel_num};
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(weight_tensor->tensor_type()), shape);
bias_parameter->set_abstract(abstract_tensor);
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
MS_ASSERT(param_value != nullptr);
param_value->set_tensor_addr(bias_data);
param_value->set_tensor_size(kernel_num * sizeof(float) / sizeof(uint8_t));
bias_parameter->set_default_param(param_value);
return bias_parameter;
}
schema::PrimitiveType GetCNodeType(const BaseRef &n) {
ValueNodePtr value_node;
if (utils::isa<CNodePtr>(n)) {
auto in = utils::cast<CNodePtr>(n);
value_node = in->input(0)->cast<ValueNodePtr>();
} else if (utils::isa<ValueNodePtr>(n)) {
value_node = utils::cast<ValueNodePtr>(n);
} else {
MS_LOG(EXCEPTION) << "only value node or cnode has type";
return schema::PrimitiveType_NONE;
}
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
MS_ASSERT(value != nullptr);
if (utils::isa<PrimitiveTValuePtr>(value)) {
auto primitive = value->cast<PrimitiveTValuePtr>();
MS_ASSERT(primitive != nullptr);
return primitive->GetPrimitiveT()->value.type;
}
return schema::PrimitiveType_NONE;
}
bool IsParamNode(const BaseRef &n) {
return utils::isa<ParameterPtr>(n);
}
bool IsConvNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D;
}
return false;
}
} // namespace opt
} // namespace mindspore

View File

@ -17,12 +17,16 @@
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_
#include <mindspore/lite/src/ir/primitive_t_value.h>
#include <memory>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "src/common/utils.h"
#include "src/gllo/common/pattern_engine.h"
#include "schema/inner/model_generated.h"
#include "src/param_value_lite.h"
using PrimitiveTValuePtr = std::shared_ptr<mindspore::lite::PrimitiveTValue>;
namespace mindspore {
namespace opt {
@ -41,11 +45,20 @@ void CheckIfCNodeIsNull(const CNodePtr &node);
void CheckIfVarIsNull(const VarPtr &var);
void CheckInputSize(const CNodePtr &node, const int size);
void CheckInputSize(const CNodePtr &node, int size);
schema::PrimitiveType GetCNodeType(const CNodePtr &node);
void CheckIfNodeIsParam(const AnfNodePtr &node);
void CheckLeastInputSize(const CNodePtr &node, int size);
AnfNodePtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
const ParamValueLitePtr &weight_tensor);
schema::PrimitiveType GetCNodeType(const BaseRef &node);
bool IsParamNode(const BaseRef &n);
bool IsConvNode(const BaseRef &n);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_

View File

@ -14,23 +14,24 @@
* limitations under the License.
*/
#include "src/gllo/fusion/conv_activation_fusion.h"
#include "mindspore/lite/src/gllo/fusion/conv_activation_fusion.h"
#include <memory>
#include "schema/inner/model_generated.h"
#include "src/ir/primitive_t_value.h"
#include "mindspore/lite/schema/inner/model_generated.h"
#include "mindspore/lite/src/ir/primitive_t_value.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "src/gllo/common/utils.h"
#include "mindspore/lite/src/gllo/common/utils.h"
namespace mindspore {
namespace opt {
namespace mindspore::opt {
namespace {
constexpr size_t kActivationInputsLength = 2;
}
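// Pattern: an activation node whose first input is a Conv2D/DepthwiseConv2D node (IsConvNode);
// the concrete activation and conv types are validated again in Process.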
const BaseRef ConvActivationFusion::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
// conv2d may have 2 or 3 inputs, so the detailed match is done in Process
auto conv_var = std::make_shared<CondVar>(IsConvNode);
auto prim = new schema::PrimitiveT();
prim->value.type = primitive_type;
auto prim_value = std::make_shared<lite::PrimitiveTValue>(prim);
return VectorRef({prim_value, X});
return VectorRef({prim_value, conv_var});
}
const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
@ -41,7 +42,7 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
CheckIfAnfNodeIsNull(node);
auto act_node = node->cast<CNodePtr>();
CheckIfCNodeIsNull(act_node);
CheckInputSize(act_node, 2);
CheckInputSize(act_node, kActivationInputsLength);
auto act_primitive = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(act_node->input(0));
if (act_primitive->GetPrimitiveT()->value.AsActivation()->type != activation_type) {
@ -52,13 +53,18 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
if (pre_node != nullptr && pre_node->isa<CNode>()) {
auto conv_node = pre_node->cast<CNodePtr>();
auto node_type = GetCNodeType(conv_node);
if (node_type == schema::PrimitiveType_Conv2D || node_type == schema::PrimitiveType_DepthwiseConv2D) {
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0));
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0));
MS_ASSERT(primitiveT_value);
if (node_type == schema::PrimitiveType_Conv2D) {
primitiveT_value->GetPrimitiveT()->value.AsConv2D()->activationType = activation_type;
return pre_node;
} else if (node_type == schema::PrimitiveType_DepthwiseConv2D) {
primitiveT_value->GetPrimitiveT()->value.AsDepthwiseConv2D()->activationType = activation_type;
return pre_node;
} else {
MS_LOG(EXCEPTION) << "conv activation pass match only conv2d or depthwise_conv2d ";
}
}
return node;
}
} // namespace opt
} // namespace mindspore
} // namespace mindspore::opt

View File

@ -17,16 +17,17 @@
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_
#include "src/gllo/common/optimizer.h"
#include <string>
#include "mindspore/lite/src/gllo/common/optimizer.h"
namespace mindspore {
namespace opt {
class ConvActivationFusion : public PatternProcessPass {
public:
explicit ConvActivationFusion(bool multigraph = true,
explicit ConvActivationFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion",
schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU,
schema::ActivationType activation = schema::ActivationType_LEAKY_RELU) : primitive_type(
primitive), activation_type(activation), PatternProcessPass("conv_activation_fusion", multigraph) {}
primitive), activation_type(activation), PatternProcessPass(name, multigraph) {}
~ConvActivationFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

View File

@ -13,35 +13,116 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/gllo/fusion/conv_biasadd_fusion.h"
#include "mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h"
#include <mindspore/lite/src/param_value_lite.h>
#include <memory>
#include "schema/inner/model_generated.h"
#include "src/ir/primitive_t_value.h"
#include "mindspore/lite/schema/inner/model_generated.h"
#include "mindspore/lite/src/ir/primitive_t_value.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "src/gllo/common/utils.h"
#include "mindspore/lite/src/gllo/common/utils.h"
namespace mindspore {
namespace opt {
namespace mindspore::opt {
namespace {
constexpr size_t kAddInputsLength = 3;
constexpr size_t kAddWEIGHTINDEX = 2;
constexpr size_t kConvWeightIndex = 2;
constexpr size_t kConvBiasIndex = 3;
constexpr size_t kConvNoBiasLen = 3;
constexpr size_t kConvWithBiasLen = 4;
bool IsConvExtendNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D
|| type == schema::PrimitiveType_DeConv2D;
}
return false;
}
bool IsAddNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_Add || type == schema::PrimitiveType_BiasAdd;
}
return false;
}
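// Returns the conv output-channel count: channelOut for Conv2D/DeConv2D,
// channelIn * channelMultiplier for DepthwiseConv2D; 0 for unsupported ops.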
int Get_Kenrnel_nums(const CNodePtr &conv_node) {
MS_ASSERT(conv_node != nullptr);
auto value_primitive = conv_node->input(0);
auto value_node = value_primitive->cast<ValueNodePtr>();
MS_ASSERT(value_node != nullptr);
auto value = value_node->value();
MS_ASSERT(value != nullptr);
auto primitive = value->cast<PrimitiveTValuePtr>();
MS_ASSERT(primitive != nullptr);
auto type = primitive->GetPrimitiveT()->value.type;
if (type == schema::PrimitiveType_Conv2D) {
return primitive->GetPrimitiveT()->value.AsConv2D()->channelOut;
} else if (type == schema::PrimitiveType_DepthwiseConv2D) {
return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier
* primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn;
} else if (type == schema::PrimitiveType_DeConv2D) {
return primitive->GetPrimitiveT()->value.AsDeConv2D()->channelOut;
} else {
MS_LOG(ERROR) << "Unsupported opType, " << type;
return 0;
}
}
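// Folds the Add/BiasAdd weight into the conv bias: if the conv already carries a bias,
// the add weight is accumulated into it; otherwise a fresh bias parameter holding the
// add weight (broadcast if scalar) is appended to the conv node.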
void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) {
AnfNodePtr conv_bias_node = nullptr;
AnfNodePtr conv_weight_node = nullptr;
if (conv_node->inputs().size() == kConvNoBiasLen) {
conv_weight_node = conv_node->input(kConvWeightIndex);
} else if (conv_node->inputs().size() == kConvWithBiasLen) {
conv_weight_node = conv_node->input(kConvWeightIndex);
conv_bias_node = conv_node->input(kConvBiasIndex);
} else {
MS_LOG(EXCEPTION) << "conv node:" << conv_node->DebugString() << "inputs size must 3 or 4";
}
auto kernel_nums = Get_Kenrnel_nums(conv_node);
if (kernel_nums <= 0) {
MS_LOG(EXCEPTION) << "kernel num less than 0";
}
auto add_bias_data = new(std::nothrow) float[kernel_nums];
auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX);
CheckIfNodeIsParam(bias_add_weight);
auto add_weight_param = bias_add_weight->cast<ParameterPtr>()->default_param();
auto add_weight_tensor = std::dynamic_pointer_cast<ParamValueLite>(add_weight_param);
auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->tensor_addr());
if (add_weight_tensor->tensor_shape().empty()) {
// a scalar add weight is broadcast to every output channel
for (int i = 0; i < kernel_nums; i++) {
add_bias_data[i] = *add_weight_data;
}
} else {
if (0 != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
MS_LOG(EXCEPTION) << "memcpy_s add_bias_data failed";
}
}
if (conv_bias_node != nullptr) {
CheckIfNodeIsParam(conv_bias_node);
auto conv_bias_param = conv_bias_node->cast<ParameterPtr>()->default_param();
auto conv_bias_tensor = std::dynamic_pointer_cast<ParamValueLite>(conv_bias_param);
if (conv_bias_tensor->tensor_shape().empty() || conv_bias_tensor->tensor_shape()[0] != kernel_nums) {
MS_LOG(EXCEPTION) << "conv_bias_node shape error";
}
auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->tensor_addr());
for (size_t i = 0; i < kernel_nums; i++) {
conv_bias_data[i] += add_bias_data[i];
}
delete[] add_bias_data;
} else {
auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
auto conv_weight_tensor = std::dynamic_pointer_cast<ParamValueLite>(conv_weight_param);
auto conv_new_bias = AddNewBiasNode(add_bias_data, func_graph, kernel_nums, conv_weight_tensor);
conv_node->add_input(conv_new_bias);
}
}
} // namespace
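// Pattern: Add/BiasAdd(conv, weight) where conv is Conv2D/DepthwiseConv2D/DeConv2D
// and weight is a Parameter node.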
const BaseRef ConvBiasaddFusion::DefinePattern() const {
MS_LOG(DEBUG) << "Enter pattern";
VarPtr X = std::make_shared<Var>();
VarPtr W = std::make_shared<Var>();
VarPtr B = std::make_shared<Var>();
CheckIfVarIsNull(X);
CheckIfVarIsNull(W);
CheckIfVarIsNull(B);
auto prim1 = new schema::PrimitiveT();
prim1->value.type = schema::PrimitiveType_BiasAdd;
auto prim11 = std::make_shared<lite::PrimitiveTValue>(prim1);
auto prim2 = new schema::PrimitiveT();
prim2->value.type = schema::PrimitiveType_Conv2D;
auto prim22 = std::make_shared<lite::PrimitiveTValue>(prim2);
return VectorRef({prim11, VectorRef({prim22, X, W}), B});
auto conv_var = std::make_shared<CondVar>(IsConvExtendNode);
auto add_var = std::make_shared<CondVar>(IsAddNode);
auto weight_var = std::make_shared<CondVar>(IsParamNode);
return VectorRef({add_var, conv_var, weight_var});
}
const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
@ -50,24 +131,28 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons
CheckIfFuncGraphIsNull(func_graph);
CheckIfAnfNodeIsNull(node);
auto cnode = node->cast<CNodePtr>();
CheckIfCNodeIsNull(cnode);
CheckInputSize(cnode, 3); // [op, conv_node, bias_node]
auto add_node = node->cast<CNodePtr>();
CheckIfCNodeIsNull(add_node);
CheckInputSize(add_node, kAddInputsLength);
AnfNodePtr conv_node_anf = cnode->input(1);
AnfNodePtr conv_node_anf = add_node->input(1);
CheckIfAnfNodeIsNull(conv_node_anf);
auto conv_node = conv_node_anf->cast<CNodePtr>();
CheckIfCNodeIsNull(conv_node);
CheckInputSize(conv_node, 3); // [op, X, W]
conv_node->add_input(cnode->input(2));
auto primitive = (lite::PrimitiveTValue *)(conv_node->input(0)->cast<ValueNodePtr>()->value().get());
primitive->GetPrimitiveT()->value.AsConv2D()->hasBias = true;
return conv_node_anf;
GenConvNewBias(func_graph, conv_node, add_node);
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0));
MS_ASSERT(primitiveT_value);
auto type = primitiveT_value->GetPrimitiveT()->value.type;
if (type == schema::PrimitiveType_Conv2D) {
primitiveT_value->GetPrimitiveT()->value.AsConv2D()->hasBias = true;
} else if (type == schema::PrimitiveType_DepthwiseConv2D) {
primitiveT_value->GetPrimitiveT()->value.AsDepthwiseConv2D()->hasBias = true;
} else if (type == schema::PrimitiveType_DeConv2D) {
primitiveT_value->GetPrimitiveT()->value.AsDeConv2D()->hasBias = true;
} else {
MS_LOG(EXCEPTION) << "Unsupported opType, " << type;
}
return conv_node;
}
} // namespace opt
} // namespace mindspore
} // namespace mindspore::opt

View File

@ -17,7 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
#include "src/gllo/common/optimizer.h"
#include "mindspore/lite/src/gllo/common/optimizer.h"
namespace mindspore {
namespace opt {

View File

@ -0,0 +1,141 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/lite/src/gllo/fusion/conv_bn_fusion.h"
#include <mindspore/lite/src/param_value_lite.h>
#include <memory>
#include "mindspore/lite/schema/inner/model_generated.h"
#include "mindspore/lite/src/ir/primitive_t_value.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "mindspore/lite/src/gllo/common/utils.h"
namespace mindspore::opt {
namespace {
constexpr size_t kCaffeBNMeanIndex = 2;
constexpr size_t kCaffeBNVarIndex = 3;
constexpr size_t kTFBNScaleIndex = 2;
constexpr size_t kTFBNBiasIndex = 3;
constexpr size_t kTFBNMeanIndex = 4;
constexpr size_t kTFBNVarIndex = 5;
constexpr const float EPS = 1e-8;
constexpr const float EPS_DEFAULT_FLOAT = 1e-5;
constexpr const float POW_NUM = 0.5;
bool IsBatchNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_CaffeBatchNorm || type == schema::PrimitiveType_FusedBatchNorm;
}
return false;
}
void CalTransale(const AnfNodePtr &bn_scale_node, const AnfNodePtr &bn_var_node, float *trans_scale, float eps,
int kernel_num) {
auto bn_var_param = bn_var_node->cast<ParameterPtr>()->default_param();
auto bn_var_tensor = std::dynamic_pointer_cast<ParamValueLite>(bn_var_param);
auto bn_var_data = reinterpret_cast<float *>(bn_var_tensor->tensor_addr());
// cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps)
if (memcpy_s(trans_scale, kernel_num * sizeof(float), bn_var_data, kernel_num * sizeof(float)) != 0) {
MS_LOG(EXCEPTION) << "memcpy_s transScale error";
return;
}
// 1/sqrt(variance + eps)
for (int32_t i = 0; i < kernel_num; i++) {
float tmp = trans_scale[i] + eps;
tmp = pow(tmp, POW_NUM);
trans_scale[i] = 1 / tmp;
}
if (bn_scale_node != nullptr) {
auto bn_scale_param = bn_scale_node->cast<ParameterPtr>()->default_param();
auto bn_scale_tensor = std::dynamic_pointer_cast<ParamValueLite>(bn_scale_param);
auto bn_scale_data = reinterpret_cast<float *>(bn_scale_tensor->tensor_addr());
// scale/sqrt(variance + eps)
for (int32_t i = 0; i < kernel_num; i++) {
trans_scale[i] *= bn_scale_data[i];
}
}
}
void CalTransBias(const AnfNodePtr &bn_mean_node, const AnfNodePtr &bn_bias_node, const float *trans_scale,
float *trans_bias, int kernel_num) {
auto bn_mean_param = bn_mean_node->cast<ParameterPtr>()->default_param();
auto bn_mean_tensor = std::dynamic_pointer_cast<ParamValueLite>(bn_mean_param);
auto bn_mean_data = reinterpret_cast<float *>(bn_mean_tensor->tensor_addr());
// cal transBias, tf : -scale*mean/sqrt(variance + eps) + bias; caffe : -mean/sqrt(variance + eps)
// -mean/sqrt(variance + eps)
for (int32_t i = 0; i < kernel_num; i++) {
trans_bias[i] = -bn_mean_data[i] * trans_scale[i];
}
if (bn_bias_node != nullptr) {
auto bn_bias_param = bn_bias_node->cast<ParameterPtr>()->default_param();
auto bn_bias_tensor = std::dynamic_pointer_cast<ParamValueLite>(bn_bias_param);
auto bn_bias_data = reinterpret_cast<float *>(bn_bias_tensor->tensor_addr());
// -scale*mean/sqrt(variance + eps) + bias
for (int32_t i = 0; i < kernel_num; i++) {
trans_bias[i] += bn_bias_data[i];
}
}
}
} // namespace
const BaseRef ConvBatchNormFusion::DefinePattern() const {
auto conv_var = std::make_shared<CondVar>(IsConvNode);
auto bn_var = std::make_shared<CondVar>(IsBatchNode);
auto bn_mean_var = std::make_shared<CondVar>(IsParamNode);
auto bn_variable_var = std::make_shared<CondVar>(IsParamNode);
auto bn_other_var = std::make_shared<SeqVar>();
return VectorRef({bn_var, conv_var, bn_mean_var, bn_variable_var, bn_other_var});
}
// BatchNorm weight Tensor definition:
// caffe
// estimated_mean --0
// estimated_variance --1
// tensorflow
// scale -- 0
// bias --1
// estimated_mean --2
// estimated_variance --3
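// With BN computing y = scale * (x - mean) / sqrt(variance + eps) + bias, the fold is:
//   trans_scale = scale / sqrt(variance + eps)   (caffe has no scale, so scale == 1)
//   trans_bias  = bias - mean * trans_scale      (caffe has no bias,  so bias  == 0)
// ConvTransformFusion then applies trans_scale/trans_bias to the conv weight and bias.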
const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale,
float *trans_bias) const {
MS_ASSERT(bn_node != nullptr);
AnfNodePtr bn_mean_node = nullptr;
AnfNodePtr bn_variance_node = nullptr;
AnfNodePtr bn_scale_node = nullptr;
AnfNodePtr bn_bias_node = nullptr;
float eps = 0;
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(bn_node->input(0));
if (GetCNodeType(bn_node) == schema::PrimitiveType_CaffeBatchNorm) {
bn_mean_node = bn_node->input(kCaffeBNMeanIndex);
bn_variance_node = bn_node->input(kCaffeBNVarIndex);
CheckIfNodeIsParam(bn_mean_node);
CheckIfNodeIsParam(bn_variance_node);
eps = primitiveT_value->GetPrimitiveT()->value.AsCaffeBatchNorm()->epsilon;
} else if (GetCNodeType(bn_node) == schema::PrimitiveType_FusedBatchNorm) {
bn_scale_node = bn_node->input(kTFBNScaleIndex);
bn_bias_node = bn_node->input(kTFBNBiasIndex);
bn_mean_node = bn_node->input(kTFBNMeanIndex);
bn_variance_node = bn_node->input(kTFBNVarIndex);
eps = primitiveT_value->GetPrimitiveT()->value.AsFusedBatchNorm()->epsilon;
} else {
MS_LOG(EXCEPTION) << "not caffe or tf batchnorm op.";
}
CheckIfNodeIsParam(bn_mean_node);
CheckIfNodeIsParam(bn_variance_node);
if (eps < EPS) {
eps = EPS_DEFAULT_FLOAT;
}
CalTransale(bn_scale_node, bn_variance_node, trans_scale, eps, kernel_num);
CalTransBias(bn_mean_node, bn_bias_node, trans_scale, trans_bias, kernel_num);
}
} // namespace mindspore::opt

View File

@ -0,0 +1,31 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
#include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h"
namespace mindspore::opt {
class ConvBatchNormFusion : public ConvTransformFusion {
public:
explicit ConvBatchNormFusion(bool multigraph = true) : ConvTransformFusion(multigraph, "conv_batchnorm_fusion") {}
~ConvBatchNormFusion() override = default;
const BaseRef DefinePattern() const override;
const void InitTransParam(const CNodePtr &, int, float *, float *) const override;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_

View File

@ -14,78 +14,51 @@
* limitations under the License.
*/
#include "src/gllo/fusion/conv_scale_fusion.h"
#include "mindspore/lite/src/gllo/fusion/conv_scale_fusion.h"
#include <memory>
#include "schema/inner/model_generated.h"
#include "src/ir/primitive_t_value.h"
#include "src/param_value_lite.h"
#include "mindspore/lite/src/param_value_lite.h"
#include "mindspore/lite/schema/inner/model_generated.h"
#include "mindspore/lite/src/ir/primitive_t_value.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "src/gllo/common/utils.h"
#include "mindspore/lite/src/gllo/common/utils.h"
#include "include/errorcode.h"
namespace mindspore {
namespace opt {
namespace mindspore::opt {
namespace {
constexpr size_t kScaleWeightIndex = 2;
constexpr size_t kScaleBiasIndex = 3;
constexpr size_t kScaleNoBiasLen = 3;
constexpr size_t kScaleWithBiasLen = 4;
bool IsScaleNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_Scale;
}
return false;
}
} // namespace
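// Pattern: Scale(conv, weight[, bias]); InitTransParam below copies the Scale weight into
// trans_scale and the optional Scale bias into trans_bias.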
const BaseRef ConvScaleFusion::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
// conv2d inputs may be 2 or 3 inputs,match move to process
auto prim = new schema::PrimitiveT();
prim->value.type = schema::PrimitiveType_Scale;
auto prim_value = std::make_shared<lite::PrimitiveTValue>(prim);
auto conv_var = std::make_shared<CondVar>(IsConvNode);
auto bn_var = std::make_shared<CondVar>(IsScaleNode);
auto weight_var = std::make_shared<CondVar>(IsParamNode);
auto bias_var = std::make_shared<SeqVar>();
return VectorRef({prim_value, X});
return VectorRef({bn_var, conv_var, weight_var, bias_var});
}
const AnfNodePtr ConvScaleFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_LOG(DEBUG) << "conv activation pass process";
CheckIfFuncGraphIsNull(func_graph);
CheckIfAnfNodeIsNull(node);
auto scale_node = node->cast<CNodePtr>();
CheckIfCNodeIsNull(scale_node);
CheckInputSize(scale_node, 2);
AnfNodePtr pre_node = scale_node->input(1);
CheckIfAnfNodeIsNull(pre_node);
if (pre_node != nullptr && pre_node->isa<CNode>()) {
auto conv_node = pre_node->cast<CNodePtr>();
auto node_type = GetCNodeType(conv_node);
if (node_type == schema::PrimitiveType_Conv2D || node_type == schema::PrimitiveType_DepthwiseConv2D) {
return DoFusion(conv_node, scale_node);
}
}
return node;
}
const AnfNodePtr ConvScaleFusion::DoFusion(const CNodePtr &conv_node, const CNodePtr &scale_node) const {
if (scale_node->inputs().size() == 3) {
GetTransParam(scale_node->input(2), nullptr);
} else if (scale_node->inputs().size() == 4) {
// todo add bias fusion zhengjun10
GetTransParam(scale_node->input(2), scale_node->input(3));
const void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kernel_num, float *trans_scale,
float *trans_bias) const {
MS_ASSERT(scale_node != nullptr);
AnfNodePtr scale_weight_node;
AnfNodePtr scale_bias_node;
if (scale_node->inputs().size() == kScaleNoBiasLen) {
scale_weight_node = scale_node->input(kScaleWeightIndex);
} else if (scale_node->inputs().size() == kScaleWithBiasLen) {
scale_weight_node = scale_node->input(kScaleWeightIndex);
scale_bias_node = scale_node->input(kScaleBiasIndex);
} else {
MS_LOG(ERROR) << "scale inputs size is error:" << scale_node->DebugString();
return nullptr;
MS_LOG(EXCEPTION) << "Scale should has 2 or 3 input tensors, current inputs is" << scale_node->inputs().size();
}
AnfNodePtr conv_weight_node;
if (conv_node->inputs().size() == 3) {
conv_weight_node = conv_node->input(2);
} else {
MS_LOG(ERROR) << "scale inputs size is error:" << scale_node->DebugString();
return nullptr;
}
auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
auto weight_value = std::dynamic_pointer_cast<ParamValueLite>(conv_weight_param);
auto old_conv_weight = reinterpret_cast<const float *>(weight_value->tensor_addr());
auto new_conv_weight = new(std::nothrow) float[weight_value->tensor_shape_size()];
CalNewWeightTensor(old_conv_weight, new_conv_weight, weight_value->tensor_shape_size());
weight_value->set_tensor_addr(new_conv_weight);
return conv_node;
}
const lite::STATUS ConvScaleFusion::GetTransParam(const AnfNodePtr &scale_weight_node,
const AnfNodePtr &scale_bias_node) const {
if (!scale_weight_node->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "scale weight node not paramter node";
}
@ -96,31 +69,17 @@ const lite::STATUS ConvScaleFusion::GetTransParam(const AnfNodePtr &scale_weight
auto weight_value = std::dynamic_pointer_cast<ParamValueLite>(scale_weight_param);
auto weight_data = reinterpret_cast<const float *>(weight_value->tensor_addr());
if (0 != memcpy_s(trans_scale, kernel_nums * sizeof(float), weight_data, kernel_nums * sizeof(float))) {
MS_LOG(ERROR) << "memcpy_s transScale failed";
return lite::RET_ERROR;
if (0 != memcpy_s(trans_scale, kernel_num * sizeof(float), weight_data, kernel_num * sizeof(float))) {
MS_LOG(EXCEPTION) << "memcpy_s transScale failed";
}
return lite::RET_OK;
}
const lite::STATUS ConvScaleFusion::CalNewWeightTensor(const float *oldWeightTensor, float *newWeightTensor,
const size_t tensor_shape_size) const {
MS_ASSERT(oldWeightTensor != nullptr);
if (0 != memset_s(newWeightTensor, tensor_shape_size * sizeof(float), 0, tensor_shape_size * sizeof(float))) {
MS_LOG(ERROR) << "memset newWeightData failed";
return lite::RET_ERROR;
}
if (kernel_nums == 0) {
MS_LOG(ERROR) << "kernel nums is 0";
return lite::RET_ERROR;
}
auto kernel_size = tensor_shape_size / kernel_nums;
for (size_t i = 0; i < kernel_nums; i++) {
for (size_t j = 0; j < kernel_size; j++) {
newWeightTensor[i * kernel_size + j] = oldWeightTensor[i * kernel_size + j] * trans_scale[i];
if (scale_bias_node != nullptr) {
auto scale_bias_param = scale_bias_node->cast<ParameterPtr>()->default_param();
auto bias_value = std::dynamic_pointer_cast<ParamValueLite>(scale_bias_param);
auto bias_data = reinterpret_cast<const float *>(bias_value->tensor_addr());
if (0 != memcpy_s(trans_bias, kernel_num * sizeof(float), bias_data, kernel_num * sizeof(float))) {
MS_LOG(EXCEPTION) << "memcpy_s transScale failed";
}
}
return lite::RET_OK;
}
} // namespace opt
} // namespace mindspore
} // namespace mindspore::opt

View File

@ -17,24 +17,15 @@
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
#include "src/gllo/common/optimizer.h"
#include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h"
namespace mindspore {
namespace opt {
class ConvScaleFusion : public PatternProcessPass {
namespace mindspore::opt {
class ConvScaleFusion : public ConvTransformFusion {
public:
explicit ConvScaleFusion(bool multigraph = true) : PatternProcessPass("conv_scale_fusion", multigraph) {}
explicit ConvScaleFusion(bool multigraph = true) : ConvTransformFusion(multigraph, "conv_scale_fusion") {}
~ConvScaleFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
const AnfNodePtr DoFusion(const CNodePtr &, const CNodePtr &) const;
const lite::STATUS GetTransParam(const AnfNodePtr &, const AnfNodePtr &) const;
const lite::STATUS CalNewWeightTensor(const float *, float *, const size_t) const;
private:
float *trans_scale = nullptr;
int kernel_nums = 0;
const void InitTransParam(const CNodePtr &, int, float *, float *) const override;
};
} // namespace opt
} // namespace mindspore
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_

View File

@ -0,0 +1,201 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h"
#include <memory>
#include "mindspore/lite/src/param_value_lite.h"
#include "mindspore/lite/schema/inner/model_generated.h"
#include "mindspore/lite/src/ir/primitive_t_value.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "mindspore/lite/src/gllo/common/utils.h"
#include "include/errorcode.h"
namespace mindspore::opt {
namespace {
constexpr size_t kConvWeightIndex = 2;
constexpr size_t kConvBiasIndex = 3;
constexpr size_t kConvNoBiasLen = 3;
constexpr size_t kConvWithBiasLen = 4;
int Get_Kenrnel_nums(const CNodePtr &conv_node) {
MS_ASSERT(conv_node != nullptr);
auto value_primitive = conv_node->input(0);
auto value_node = value_primitive->cast<ValueNodePtr>();
MS_ASSERT(value_node != nullptr);
auto value = value_node->value();
MS_ASSERT(value != nullptr);
auto primitive = value->cast<PrimitiveTValuePtr>();
MS_ASSERT(primitive != nullptr);
auto type = primitive->GetPrimitiveT()->value.type;
if (type == schema::PrimitiveType_Conv2D) {
return primitive->GetPrimitiveT()->value.AsConv2D()->channelOut;
} else if (type == schema::PrimitiveType_DepthwiseConv2D) {
return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier
* primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn;
} else {
MS_LOG(ERROR) << "Unsupported opType, " << type;
return 0;
}
}
} // namespace
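// Common scale/BN folding: the subclass fills trans_scale/trans_bias, then the conv weight
// becomes weight[i][j] * trans_scale[i] and the conv bias becomes
// bias[i] * trans_scale[i] + trans_bias[i] (a bias parameter is created when absent).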
const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_LOG(DEBUG) << "conv activation pass process";
CheckIfFuncGraphIsNull(func_graph);
CheckIfAnfNodeIsNull(node);
// the transform node is a Scale or BatchNorm node
auto transform_node = node->cast<CNodePtr>();
CheckIfCNodeIsNull(transform_node);
CheckLeastInputSize(transform_node, 2);
auto pre_node = transform_node->input(1);
auto conv_node = pre_node->cast<CNodePtr>();
int kernel_nums = Get_Kenrnel_nums(conv_node);
if (kernel_nums <= 0) {
MS_LOG(ERROR) << "Unsupported conv node, " << conv_node->DebugString();
return node;
}
auto trans_scale = new(std::nothrow) float[kernel_nums];
auto trans_bias = new(std::nothrow) float[kernel_nums];
GenTransParam(transform_node, kernel_nums, trans_scale, trans_bias);
GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias);
delete[] trans_bias;
delete[] trans_scale;
return pre_node;
}
const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums,
float *trans_scale, float *trans_bias) const {
if (trans_scale == nullptr) {
MS_LOG(EXCEPTION) << "new transScale failed";
}
if (trans_bias == nullptr) {
MS_LOG(EXCEPTION) << "new transBias failed";
}
if (0 != memset_s(trans_scale, kernel_nums * sizeof(float), 0, kernel_nums * sizeof(float))) {
MS_LOG(EXCEPTION) << "memset transScale failed";
}
if (0 != memset_s(trans_bias, kernel_nums * sizeof(float), 0, kernel_nums * sizeof(float))) {
MS_LOG(EXCEPTION) << "memset transBias failed";
}
InitTransParam(transform_node, kernel_nums, trans_scale, trans_bias);
}
const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node,
int kernel_num, const float *trans_scale, const float *trans_bias)
const {
MS_ASSERT(conv_node != nullptr);
AnfNodePtr conv_weight_node = nullptr;
AnfNodePtr conv_bias_node = nullptr;
if (conv_node->inputs().size() == kConvNoBiasLen) {
conv_weight_node = conv_node->input(kConvWeightIndex);
} else if (conv_node->inputs().size() == kConvWithBiasLen) {
conv_weight_node = conv_node->input(kConvWeightIndex);
conv_bias_node = conv_node->input(kConvBiasIndex);
} else {
MS_LOG(ERROR) << "conv node:" << conv_node->DebugString() << "inputs size must 3 or 4";
return;
}
if (!conv_weight_node->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "scale weight node not paramter node";
}
if (conv_bias_node != nullptr && !conv_bias_node->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "scale bias node not paramter node";
}
auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
auto weight_tensor = std::dynamic_pointer_cast<ParamValueLite>(conv_weight_param);
auto weight_data = reinterpret_cast<float *>(weight_tensor->tensor_addr());
if (kernel_num <= 0) {
MS_LOG(EXCEPTION) << "kernel num less than 0";
}
auto kernel_size = weight_tensor->tensor_shape_size() / kernel_num;
CalNewWeightTensor(weight_data, kernel_num, kernel_size, trans_scale);
float *bias_data = nullptr;
// bias_flag is true when the conv node already has a bias input
bool bias_flag = false;
if (conv_bias_node != nullptr) {
auto bias_weight_param = conv_bias_node->cast<ParameterPtr>()->default_param();
auto bias_tensor = std::dynamic_pointer_cast<ParamValueLite>(bias_weight_param);
bias_data = reinterpret_cast<float *>(bias_tensor->tensor_addr());
bias_flag = true;
} else {
bias_data = new(std::nothrow) float[kernel_num];
}
CalNewBiasTensor(bias_data, kernel_num, bias_flag, trans_scale, trans_bias);
if (!bias_flag) {
auto bias_node = AddNewBiasNode(bias_data, func_graph, kernel_num, weight_tensor);
conv_node->add_input(bias_node);
}
}
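// Scales every output-channel slice of the conv weight in place:
// weight[i * kernel_size + j] *= trans_scale[i].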
const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size,
const float *trans_scale) const {
MS_ASSERT(weight_data != nullptr);
auto tmp_weight_data = new(std::nothrow) float[kernel_num * kernel_size];
MS_ASSERT(tmp_weight_data != nullptr);
auto data_size = kernel_num * kernel_size * sizeof(float);
if (0 != memset_s(tmp_weight_data, data_size, 0, data_size)) {
MS_LOG(EXCEPTION) << "memset newWeightData failed";
return;
}
for (size_t i = 0; i < kernel_num; i++) {
for (size_t j = 0; j < kernel_size; j++) {
tmp_weight_data[i * kernel_size + j] = weight_data[i * kernel_size + j] * trans_scale[i];
}
}
auto ret = memcpy_s(weight_data, data_size, tmp_weight_data, data_size);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memcpy error: " << ret;
}
delete[] tmp_weight_data;
}
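// Rewrites the conv bias in place: with an existing bias, bias[i] = bias[i] * trans_scale[i]
// + trans_bias[i]; otherwise the freshly allocated bias is filled with trans_bias.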
const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag,
const float *trans_scale, const float *trans_bias) const {
MS_ASSERT(bias_data != nullptr);
if (bias_flag) {
auto tmp_bias_data = new(std::nothrow) float[kernel_num];
if (0 != memset_s(tmp_bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float))) {
MS_LOG(EXCEPTION) << "memset bias data failed";
}
for (size_t i = 0; i < kernel_num; i++) {
tmp_bias_data[i] = bias_data[i] * trans_scale[i] + trans_bias[i];
}
auto ret = memcpy_s(bias_data, kernel_num * sizeof(float), tmp_bias_data, kernel_num * sizeof(float));
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memcpy error: " << ret;
}
delete[] tmp_bias_data;
} else {
if (0 != memset_s(bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float))) {
MS_LOG(EXCEPTION) << "memset bias data failed";
}
auto ret = memcpy_s(bias_data, kernel_num * sizeof(float), trans_bias, kernel_num * sizeof(float));
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memcpy error: " << ret;
}
}
}
} // namespace mindspore::opt

View File

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
#include <string>
#include "mindspore/lite/src/gllo/common/optimizer.h"
namespace mindspore::opt {
class ConvTransformFusion : public PatternProcessPass {
public:
explicit ConvTransformFusion(bool multigraph = true, const std::string &name = "conv_transform_fusion")
: PatternProcessPass(name, multigraph) {}
~ConvTransformFusion() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
const void GenTransParam(const CNodePtr &, int, float *, float *) const;
virtual const void InitTransParam(const CNodePtr &, int, float *, float *) const = 0;
const void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const;
const void CalNewWeightTensor(float *, int, int, const float *) const;
const void CalNewBiasTensor(float *, int, bool, const float *, const float *) const;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_

View File

@ -199,7 +199,9 @@ if(BUILD_CONVERTER)
${LITE_DIR}/src/gllo/common/utils.cc
${LITE_DIR}/src/gllo/fusion/conv_biasadd_fusion.cc
${LITE_DIR}/src/gllo/fusion/conv_activation_fusion.cc
${LITE_DIR}/src/gllo/fusion/conv_transform_fusion.cc
${LITE_DIR}/src/gllo/fusion/conv_scale_fusion.cc
${LITE_DIR}/src/gllo/fusion/conv_bn_fusion.cc
)
endif()
### train

View File

@ -79,7 +79,9 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_biasadd_fusion.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_activation_fusion.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_transform_fusion.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_scale_fusion.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_bn_fusion.cc
)
add_subdirectory(parser/caffe)

View File

@ -18,8 +18,10 @@
#include <memory>
#include <string>
#include "utils/log_adapter.h"
#include "src/gllo/fusion/conv_biasadd_fusion.h"
#include "mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h"
#include "mindspore/lite/src/gllo/fusion/conv_activation_fusion.h"
#include "mindspore/lite/src/gllo/fusion/conv_scale_fusion.h"
#include "mindspore/lite/src/gllo/fusion/conv_bn_fusion.h"
using std::string;
namespace mindspore {
@ -34,8 +36,13 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph) {
// return old_graph;
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
auto pass = std::make_shared<opt::ConvBiasaddFusion>();
pm->AddPass(pass);
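// Run the fusions in order: fold BiasAdd, BatchNorm and Scale into conv, then fuse
// ReLU/ReLU6 activations into the conv primitive.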
pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu", schema::PrimitiveType_Activation,
schema::ActivationType_RELU));
pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu6", schema::PrimitiveType_Activation,
schema::ActivationType_RELU6));
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(old_graph);
return new_graph;