!13521 instance_norm_fusion & onnx_layernorm_fusion

From: @wangzhe128
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
This commit is contained in:
mindspore-ci-bot 2021-03-18 17:16:20 +08:00 committed by Gitee
commit 4d184cdbb1
17 changed files with 367 additions and 161 deletions

View File

@ -41,6 +41,20 @@ inline void Int32ToFloat32(const int32_t *input, float *output, int number) {
}
}
inline void Int64ToFloat32(const int64_t *input, float *output, int number) {
for (int i = 0; i < number; ++i) {
output[i] = (float)input[i];
}
}
#ifdef ENABLE_FP16
inline void Int64ToFp16(const int64_t *input, float16_t *output, int number) {
for (int i = 0; i < number; ++i) {
output[i] = (float16_t)input[i];
}
}
#endif
inline void Fp16ToFloat32(const uint16_t *input, float *output, int number) {
for (int i = 0; i < number; ++i) {
output[i] = ShortToFloat32(input[i]);
@ -82,6 +96,7 @@ inline void BoolToInt32(const bool *input, int32_t *output, int number) {
output[i] = (int32_t)input[i];
}
}
#ifdef __cplusplus
}
#endif

View File

@ -36,7 +36,8 @@ int CastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
}
if (input->data_type_ != kNumberTypeBool && input->data_type_ != kNumberTypeUInt8 &&
input->data_type_ != kNumberTypeInt8 && input->data_type_ != kNumberTypeInt32 &&
input->data_type_ != kNumberTypeFloat32 && input->data_type_ != kNumberTypeFloat16) {
input->data_type_ != kNumberTypeInt64 && input->data_type_ != kNumberTypeFloat32 &&
input->data_type_ != kNumberTypeFloat16) {
return NNACL_INPUT_TENSOR_ERROR;
}

View File

@ -121,6 +121,17 @@ int CastFp16CPUKernel::DoCast(int thread_id) {
MS_LOG(ERROR) << "Unsupported output data type " << output_data_type;
return RET_ERROR;
}
} else if (input_data_type == kNumberTypeInt64) {
switch (output_data_type) {
case kNumberTypeFloat16:
Int64ToFloat32(reinterpret_cast<int64_t *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
default:
MS_LOG(ERROR) << "Unsupported output data type " << output_data_type;
return RET_ERROR;
}
} else {
MS_LOG(ERROR) << "Unsupported input data type " << input_data_type;
return RET_ERROR;
@ -136,4 +147,5 @@ int CastFp16CPUKernel::Run() {
}
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cast, LiteKernelCreator<CastFp16CPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_Cast, LiteKernelCreator<CastFp16CPUKernel>)
} // namespace mindspore::kernel

View File

@ -53,6 +53,37 @@ int CastCPUKernel::ReSize() {
return RET_OK;
}
int CastCPUKernel::CastToFp32(lite::Tensor *input, lite::Tensor *output, int offset, int data_num) {
auto input_data_type = input->data_type();
auto output_data = output->data_c();
switch (input_data_type) {
case kNumberTypeBool:
BoolToFloat32(reinterpret_cast<bool *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
case kNumberTypeUInt8:
Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
case kNumberTypeInt32:
Int32ToFloat32(reinterpret_cast<int32_t *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
case kNumberTypeFloat16:
Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
case kNumberTypeInt64:
Int64ToFloat32(reinterpret_cast<int64_t *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
default:
MS_LOG(ERROR) << "Unsupported input data type " << input_data_type;
return RET_ERROR;
}
return RET_OK;
}
int CastCPUKernel::DoCast(int thread_id) {
auto input = in_tensors_.at(0);
int data_num = MSMIN(stride_, data_num_ - thread_id * stride_);
@ -91,32 +122,17 @@ int CastCPUKernel::DoCast(int thread_id) {
} else if (input_data_type == kNumberTypeBool && output_data_type == kNumberTypeInt32) {
BoolToInt32(reinterpret_cast<bool *>(input->data_c()) + offset, reinterpret_cast<int32_t *>(output_data) + offset,
data_num);
#ifdef ENABLE_FP16
} else if (input_data_type == kNumberTypeInt64 && output_data_type == kNumberTypeFloat16) {
Int64ToFp16(reinterpret_cast<int64_t *>(input->data_c()) + offset,
reinterpret_cast<float16_t *>(output_data) + offset, data_num);
#endif
} else {
MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type;
return RET_ERROR;
}
} else {
switch (input_data_type) {
case kNumberTypeBool:
BoolToFloat32(reinterpret_cast<bool *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
case kNumberTypeUInt8:
Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
case kNumberTypeInt32:
Int32ToFloat32(reinterpret_cast<int32_t *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
case kNumberTypeFloat16:
Fp16ToFloat32(reinterpret_cast<uint16_t *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
default:
MS_LOG(ERROR) << "Unsupported input data type " << input_data_type;
return RET_ERROR;
}
return CastToFp32(input, output, offset, data_num);
}
return RET_OK;
}
@ -132,6 +148,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Cast, LiteKernelCreator<CastC
REG_KERNEL(kCPU, kNumberTypeUInt8, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>)
#ifndef ENABLE_ARM
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cast, LiteKernelCreator<CastCPUKernel>)

View File

@ -38,6 +38,7 @@ class CastCPUKernel : public LiteKernel {
int DoCast(int thread_id);
private:
int CastToFp32(lite::Tensor *input, lite::Tensor *output, int offset, int data_num);
int stride_;
int data_num_;
};

View File

@ -233,7 +233,8 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/layer_norm_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_norm_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/onnx_layer_norm_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/batchmatmul_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/conv_conv_fusion.cc

View File

@ -38,7 +38,7 @@ ml_video_edit_img_segment 1
ml_video_edit_video_segment_gauss_adaptis_part1 2
ml_video_edit_generate_filter.pb 1
ml_video_edit_img_segment_adaptise.pb 0.5 2
ml_video_edit_video_segment_gauss_adaptis_part2.pb 3 2
ml_video_edit_video_segment_gauss_adaptis_part2.pb 10 2
ml_video_edit_person_divison_pic 0.5
ml_video_edit_person_divison_video 13 2
ml_video_edit_imitate_filter.onnx 230

View File

@ -51,6 +51,8 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {schema::PrimitiveT
schema::PrimitiveType_SpaceToBatch,
schema::PrimitiveType_SpaceToBatchND};
static const std::vector<schema::PrimitiveType> nchwOpList = {schema::PrimitiveType_InstanceNorm};
static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = {
schema::PrimitiveType_AvgPoolGrad, schema::PrimitiveType_MaxPoolGrad,
schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DBackpropFilterFusion,
@ -153,6 +155,8 @@ std::vector<schema::PrimitiveType> Getfp32FullOpList() { return fp32FullOpList;
std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; }
std::vector<schema::PrimitiveType> GetNchwOpList() { return nchwOpList; }
std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes() { return extNhwcInsertIndex; }
std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInputList; }

View File

@ -60,6 +60,8 @@ std::vector<schema::PrimitiveType> GetInsertOpList();
std::vector<schema::PrimitiveType> GetNhwcOpList();
std::vector<schema::PrimitiveType> GetNchwOpList();
std::vector<schema::PrimitiveType> GetNhwcAllInputOpList();
std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes();

View File

@ -44,7 +44,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/conv_tuplegetitem_fusion.cc
../optimizer/fusion/constant_folding_fusion.cc
../optimizer/fusion/quant_dtype_cast_fusion.cc
../optimizer/fusion/layer_norm_fusion.cc
../optimizer/fusion/tf_norm_fusion.cc
../optimizer/fusion/onnx_layer_norm_fusion.cc
../optimizer/fusion/batchmatmul_fusion.cc
../optimizer/fusion/sigmoid_mul_fusion.cc
../optimizer/fusion/conv_conv_fusion.cc

View File

@ -27,7 +27,8 @@
#include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/optimizer/fusion/layer_norm_fusion.h"
#include "tools/optimizer/fusion/tf_norm_fusion.h"
#include "tools/optimizer/fusion/onnx_layer_norm_fusion.h"
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
#include "tools/optimizer/fusion/conv_conv_fusion.h"
@ -77,7 +78,8 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti
auto conv_scale_pass = std::make_shared<opt::ConvScaleFusion>();
conv_scale_pass->SetFmkType(config->fmk);
fusion_pm->AddPass(conv_scale_pass);
fusion_pm->AddPass(std::make_shared<opt::LayerNormFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfNormFusion>());
fusion_pm->AddPass(std::make_shared<opt::OnnxLayerNormFusion>());
fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>());

View File

@ -48,7 +48,12 @@ STATUS FormatTransPass::Run(schema::MetaGraphT *graph) {
STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *beforeNodeType,
FormatTransNodeType *afterNodeType) {
if (fmk_type_ == converter::FmkType_TFLITE) { // inference by nhwc
return RET_NO_CHANGE;
if (!IsContain(GetNchwOpList(), GetCNodeTType(node))) {
return RET_NO_CHANGE;
}
*beforeNodeType = kNHWC2NCHW;
*afterNodeType = kNCHW2NHWC;
return RET_OK;
} else if (fmk_type_ == converter::FmkType_CAFFE || fmk_type_ == converter::FmkType_MS ||
fmk_type_ == converter::FmkType_ONNX) {
if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
@ -63,6 +68,11 @@ STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatT
*afterNodeType = kNHWC2NCHW;
return RET_OK;
}
if (IsContain(GetNchwOpList(), GetCNodeTType(node))) {
*beforeNodeType = kNHWC2NCHW;
*afterNodeType = kNCHW2NHWC;
return RET_OK;
}
return RET_NO_CHANGE;
}
MS_LOG(ERROR) << "Unsupported fmk: " << fmk_type_;

View File

@ -1,65 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
namespace mindspore {
namespace opt {
class LayerNormFusion : public PatternProcessPass {
public:
explicit LayerNormFusion(const std::string &name = "layer_norm_fusion", bool multigraph = true)
: PatternProcessPass(name, multigraph) {
input_ = std::make_shared<Var>();
mean1_ = std::make_shared<Var>();
mean1_axes_ = std::make_shared<Var>();
mean2_ = std::make_shared<Var>();
mean2_axes_ = std::make_shared<Var>();
gamma_ = std::make_shared<Var>();
beta_ = std::make_shared<Var>();
epsilon_ = std::make_shared<Var>();
}
~LayerNormFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
bool GetAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes, const std::vector<int> &params_shape,
int *begin_norm_axis, int *begin_params_axis) const;
bool CheckPattern(const EquivPtr &equiv, float *epsilon, int *begin_norm_axis, int *begin_params_axis) const;
CNodePtr CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, float epsilon,
int begin_norm_axis, int begin_params_axis) const;
VarPtr input_ = nullptr;
VarPtr mean1_ = nullptr;
VarPtr mean1_axes_ = nullptr;
VarPtr mean2_ = nullptr;
VarPtr mean2_axes_ = nullptr;
VarPtr gamma_ = nullptr;
VarPtr beta_ = nullptr;
VarPtr epsilon_ = nullptr;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_

View File

@ -0,0 +1,37 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/onnx_layer_norm_fusion.h"
#include <memory>
#include "ops/rsqrt.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace opt {
const BaseRef OnnxLayerNormFusion::DefinePattern() const {
VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
VectorRef sub1_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref});
VectorRef sub2_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref});
VectorRef pow_ref = VectorRef({std::make_shared<CondVar>(IsPowNode), sub2_ref, std::make_shared<Var>()});
VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_});
VectorRef add1_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mean2_ref, epsilon_});
VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSqrtNode), add1_ref});
VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsDivNode), sub1_ref, sqrt_ref});
VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsMulNode), gamma_, div_ref});
VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mul_ref, beta_});
return add2_ref;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,60 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "tools/optimizer/fusion/tf_norm_fusion.h"
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
namespace mindspore {
namespace opt {
class OnnxLayerNormFusion : public TfNormFusion {
public:
explicit OnnxLayerNormFusion(const std::string &name = "onnx_layer_norm_fusion", bool multigraph = true)
: TfNormFusion(name, multigraph) {}
~OnnxLayerNormFusion() override = default;
const BaseRef DefinePattern() const override;
};
inline bool IsPowNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimPowFusion);
}
return false;
}
inline bool IsSqrtNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSqrt);
}
return false;
}
inline bool IsDivNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimDiv) ||
CheckPrimitiveType(utils::cast<AnfNodePtr>(n), std::make_shared<Primitive>("DivFusion"));
}
return false;
}
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_

View File

@ -13,11 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/layer_norm_fusion.h"
#include "tools/optimizer/fusion/tf_norm_fusion.h"
#include <memory>
#include "ops/fusion/layer_norm_fusion.h"
#include "ops/fusion/reduce_fusion.h"
#include "ops/rsqrt.h"
#include "mindspore/core/ops/instance_norm.h"
#include "src/param_value_lite.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
@ -26,41 +27,6 @@
namespace mindspore {
namespace opt {
namespace {
bool IsAddNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimAddFusion);
}
return false;
}
bool IsSquaredDifferenceNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSquaredDifference);
}
return false;
}
bool IsRsqrtNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimRsqrt);
}
return false;
}
bool IsMulNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimMulFusion);
}
return false;
}
bool IsSubNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSubFusion);
}
return false;
}
lite::STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) {
MS_ASSERT(node != nullptr);
if (utils::isa<ParameterPtr>(n)) {
@ -106,7 +72,7 @@ bool IsReduceNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr
}
} // namespace
const BaseRef LayerNormFusion::DefinePattern() const {
const BaseRef TfNormFusion::DefinePattern() const {
VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
auto squared_diffference1 = std::make_shared<CondVar>(IsSquaredDifferenceNode);
VectorRef squared_diffference1_ref = VectorRef({squared_diffference1, input_, mean1_ref});
@ -128,13 +94,26 @@ const BaseRef LayerNormFusion::DefinePattern() const {
return add2_ref;
}
CNodePtr LayerNormFusion::CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, float epsilon,
int begin_norm_axis, int begin_params_axis) const {
CNodePtr TfNormFusion::CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const schema::PrimitiveType type, float epsilon, int begin_norm_axis,
int begin_params_axis) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(equiv != nullptr);
auto layer_norm_primitive = std::make_shared<ops::LayerNormFusion>();
layer_norm_primitive->Init(begin_norm_axis, begin_params_axis, epsilon);
auto value_node = NewValueNode(layer_norm_primitive);
auto norm_primitive = std::make_unique<schema::PrimitiveT>();
norm_primitive->value.type = type;
PrimitiveCPtr primitive = nullptr;
if (type == schema::PrimitiveType_LayerNormFusion) {
auto layer_norm_primitive = std::make_shared<ops::LayerNormFusion>();
layer_norm_primitive->Init(begin_norm_axis, begin_params_axis, epsilon, true);
primitive = layer_norm_primitive;
} else if (type == schema::PrimitiveType_InstanceNorm) {
auto instance_norm_primitive = std::make_shared<ops::InstanceNorm>();
instance_norm_primitive->Init(epsilon);
primitive = instance_norm_primitive;
} else {
return nullptr;
}
auto value_node = NewValueNode(primitive);
std::vector<AnfNodePtr> new_node_inputs = {value_node};
auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
MS_ASSERT(input_node != nullptr);
@ -149,10 +128,11 @@ CNodePtr LayerNormFusion::CreateLayerNormNode(const FuncGraphPtr &func_graph, co
return new_node;
}
bool LayerNormFusion::GetAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes,
const std::vector<int> &params_shape, int *begin_norm_axis,
int *begin_params_axis) const {
bool TfNormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes,
const std::vector<int> &params_shape, schema::PrimitiveType *type,
int *begin_norm_axis, int *begin_params_axis) const {
MS_ASSERT(input_node != nullptr);
MS_ASSERT(type != nullptr);
MS_ASSERT(begin_norm_axis != nullptr);
MS_ASSERT(begin_params_axis != nullptr);
auto abstract = input_cnode->abstract();
@ -170,30 +150,44 @@ bool LayerNormFusion::GetAxis(const CNodePtr &input_cnode, const std::vector<int
return false;
}
auto shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
if (mean_axes.back() + 1 != static_cast<int>(shape.size())) {
MS_LOG(DEBUG) << "mean node is not reduce to last axis";
return false;
}
for (size_t i = 1; i < mean_axes.size(); ++i) {
if (mean_axes[i] != mean_axes[i - 1] + 1) {
MS_LOG(DEBUG) << "mean axes is not continuous";
return false;
}
}
// there is no need to check params_shape
*begin_norm_axis = mean_axes.front();
*begin_params_axis = static_cast<int>(shape.size()) - static_cast<int>(params_shape.size());
if (*begin_params_axis < 0) {
MS_LOG(DEBUG) << "LayerNorm begin_params_axis illegal, not fuse";
if (shape.size() == 4 && mean_axes.size() == 2 && mean_axes[0] == 1 && mean_axes[1] == 2) {
if (params_shape.size() == 1 && params_shape.back() == shape.back()) {
*type = schema::PrimitiveType_InstanceNorm;
return true;
}
}
if (mean_axes.back() >= 0 && mean_axes.back() + 1 != static_cast<int>(shape.size())) {
MS_LOG(DEBUG) << "mean node is not reduce to last axis";
return false;
}
// there is no need to check params_shape
*begin_norm_axis = mean_axes.front();
if (*begin_norm_axis >= 0) {
*begin_params_axis = static_cast<int>(shape.size()) - static_cast<int>(params_shape.size());
if (*begin_params_axis < 0) {
MS_LOG(DEBUG) << "LayerNorm begin_params_axis illegal, not fuse";
return false;
}
} else {
*begin_params_axis = -static_cast<int>(params_shape.size());
}
*type = schema::PrimitiveType_LayerNormFusion;
return true;
}
bool LayerNormFusion::CheckPattern(const EquivPtr &equiv, float *epsilon, int *begin_norm_axis,
int *begin_params_axis) const {
bool TfNormFusion::CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon,
int *begin_norm_axis, int *begin_params_axis) const {
MS_ASSERT(equiv != nullptr);
MS_ASSERT(epsilon != nullptr);
MS_ASSERT(type != nullptr);
MS_ASSERT(begin_norm_axis != nullptr);
MS_ASSERT(begin_params_axis != nullptr);
// beta
@ -243,9 +237,6 @@ bool LayerNormFusion::CheckPattern(const EquivPtr &equiv, float *epsilon, int *b
if (mean1_axes != mean2_axes) {
return false;
}
if (mean1_axes.size() != gamma_shape.size() || mean1_axes.size() != beta_shape.size()) {
return false;
}
if (gamma_shape != beta_shape) {
return false;
}
@ -254,14 +245,14 @@ bool LayerNormFusion::CheckPattern(const EquivPtr &equiv, float *epsilon, int *b
} else {
return false;
}
if (!GetAxis(input_cnode, mean1_axes, gamma_shape, begin_norm_axis, begin_params_axis)) {
if (!GetNormTypeAndAxis(input_cnode, mean1_axes, gamma_shape, type, begin_norm_axis, begin_params_axis)) {
return false;
}
return true;
}
const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
const AnfNodePtr TfNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
MS_ASSERT(equiv != nullptr);
@ -273,14 +264,24 @@ const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const
float epsilon = 0.0f;
int begin_norm_axis = 0;
int begin_params_axis = 0;
if (!CheckPattern(equiv, &epsilon, &begin_norm_axis, &begin_params_axis)) {
schema::PrimitiveType type = schema::PrimitiveType_NONE;
if (!CheckPattern(equiv, &type, &epsilon, &begin_norm_axis, &begin_params_axis)) {
return nullptr;
}
auto layer_norm_cnode = CreateLayerNormNode(func_graph, equiv, epsilon, begin_norm_axis, begin_params_axis);
layer_norm_cnode->set_abstract(add2_cnode->abstract()->Clone());
layer_norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope());
MS_LOG(INFO) << "layernorm node:" << layer_norm_cnode->fullname_with_scope() << " fusion success";
return layer_norm_cnode;
auto norm_cnode = CreateNormNode(func_graph, equiv, type, epsilon, begin_norm_axis, begin_params_axis);
if (norm_cnode == nullptr) {
MS_LOG(DEBUG) << "create norm cnode failed";
return nullptr;
}
norm_cnode->set_abstract(add2_cnode->abstract()->Clone());
if (type == schema::PrimitiveType_LayerNormFusion) {
norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope());
MS_LOG(INFO) << "layer_norm node:" << norm_cnode->fullname_with_scope() << " fusion success";
} else if (type == schema::PrimitiveType_InstanceNorm) {
norm_cnode->set_fullname_with_scope("instance_norm_" + add2_cnode->fullname_with_scope());
MS_LOG(INFO) << "instance_norm node:" << norm_cnode->fullname_with_scope() << " fusion success";
}
return norm_cnode;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,107 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "schema/inner/model_generated.h"
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace opt {
/// fuse layer_norm, instance_norm into one operator
class TfNormFusion : public PatternProcessPass {
public:
explicit TfNormFusion(const std::string &name = "tf_norm_fusion", bool multigraph = true)
: PatternProcessPass(name, multigraph) {
input_ = std::make_shared<Var>();
mean1_ = std::make_shared<Var>();
mean1_axes_ = std::make_shared<Var>();
mean2_ = std::make_shared<Var>();
mean2_axes_ = std::make_shared<Var>();
gamma_ = std::make_shared<Var>();
beta_ = std::make_shared<Var>();
epsilon_ = std::make_shared<Var>();
}
~TfNormFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
bool GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes,
const std::vector<int> &params_shape, schema::PrimitiveType *type, int *begin_norm_axis,
int *begin_params_axis) const;
bool CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon, int *begin_norm_axis,
int *begin_params_axis) const;
CNodePtr CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const schema::PrimitiveType type,
float epsilon, int begin_norm_axis, int begin_params_axis) const;
protected:
VarPtr input_ = nullptr;
VarPtr mean1_ = nullptr;
VarPtr mean1_axes_ = nullptr;
VarPtr mean2_ = nullptr;
VarPtr mean2_axes_ = nullptr;
VarPtr gamma_ = nullptr;
VarPtr beta_ = nullptr;
VarPtr epsilon_ = nullptr;
};
inline bool IsAddNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimAddFusion);
}
return false;
}
inline bool IsSquaredDifferenceNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSquaredDifference);
}
return false;
}
inline bool IsRsqrtNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimRsqrt);
}
return false;
}
inline bool IsMulNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimMulFusion);
}
return false;
}
inline bool IsSubNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSubFusion);
}
return false;
}
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_