forked from mindspore-Ecosystem/mindspore
add shape_size infer for layernorm-fusion
This commit is contained in:
parent a93031498c
commit 0721040ea9
@@ -86,7 +86,7 @@ ml_video_edit_video_segment_gauss_adaptis_part2.pb;2
#encoder_0111.pb is the same model as ml_tts_encoder.pb.
#encoder_0111.pb;4;1:1,44:1:1
encoder_201228.pb;3;1,22:1:1;;input_dependent
ml_video_edit_oneclick_adaptis.pb;3:2,1,3
ml_video_edit_oneclick_adaptis.pb;3
tacotron_encoder_stf.pb;5;1,62:1,62:1,62:1,62:1;;input_dependent
female_model_step2_int16_noiseout.pb;66
ml_female_model_step6_noiseout.pb;66
@@ -87,7 +87,7 @@ ml_tts_vocoder.pb;66 53
hiai_transformer_encoder.pb;15 4
decoder_step_nocumsum_v5.pb;13;1,512:1,512:1,512:1,512:1,512:1,127,320:1,1429,2:1,127:1:1,127:1,512:1,80:1,127 1.2
hiai_nlu_model_multi.pb;6;1,32:1,32:1,32:1,74:1,11:1,6
hiai_nlu_model_single.pb;3;1,32:1,32:1,32 540
hiai_nlu_model_single.pb;3;1,32:1,32:1,32 4.4
fsr_270_mindspore.pb 6.0
fsr_360_mindspore.pb 6.5
fsr_720_mindspore.pb 2.0
@@ -136,7 +136,7 @@ function Run_Benchmark() {
model_info=`echo ${line_info}|awk -F ' ' '{print $1}'`
spec_acc_limit=`echo ${line_info}|awk -F ' ' '{print $2}'`
model_name=`echo ${model_info}|awk -F ';' '{print $1}'`
input_config=`echo ${model_info} | awk -F ';' '{print $2}'`
input_num=`echo ${model_info} | awk -F ';' '{print $2}'`
input_shapes=`echo ${model_info} | awk -F ';' '{print $3}'`
spec_threads=`echo ${model_info} | awk -F ';' '{print $4}'`
extra_info=`echo ${model_info} | awk -F ';' '{print $5}'`
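For example, the benchmark entry hiai_nlu_model_single.pb;3;1,32:1,32:1,32 540 from the list above splits on the space into model_info=hiai_nlu_model_single.pb;3;1,32:1,32:1,32 and spec_acc_limit=540, and model_info then splits on ';' into model_name=hiai_nlu_model_single.pb, input_num=3, input_shapes=1,32:1,32:1,32, with spec_threads and extra_info left empty.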
@@ -173,24 +173,13 @@ function Run_Benchmark() {
input_files=""
output_file=""
data_path=$3"/input_output/"
if [[ ${input_config} == "" || ${input_config} == 1 ]]; then
if [[ ${input_num} == "" || ${input_num} == 1 ]]; then
    input_files=${data_path}'input/'${model_name}'.ms.bin'
else
    input_num=`echo ${input_config} | awk -F ':' '{print $1}'`
    input_seq=`echo ${input_config} | awk -F ':' '{print $2}'`
    if [[ ${input_seq} == "" ]]; then
        for i in $(seq 1 $input_num)
        do
            input_files=${input_files}${data_path}'input/'${model_name}'.ms.bin_'$i','
        done
    else
        for i in $(seq 1 $input_num)
        do
            cur_input_num=${input_seq%%,*}
            input_seq=${input_seq#*,}
            input_files=${input_files}${data_path}'input/oldinput_'${model_name}'.ms.bin_'$cur_input_num','
        done
    fi
    for i in $(seq 1 $input_num)
    do
        input_files=${input_files}${data_path}'input/'${model_name}'.ms.bin_'$i','
    done
fi
output_file=${data_path}'output/'${model_name}'.ms.out'
# adjust threads
@@ -30,6 +30,7 @@
#include "backend/optimizer/common/helper.h"
#include "tools/converter/quant_param_holder.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"

namespace mindspore {
namespace opt {
@@ -591,6 +592,31 @@ bool IsParamNode(const BaseRef &n) {
  return tensor->data_c() != nullptr;
}

STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr &cnode, size_t index) {
  CHECK_NULL_RETURN(tensor_info);
  CHECK_NULL_RETURN(cnode);
  AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, index);
  if (abstract == nullptr) {
    MS_LOG(WARNING) << "Abstract of CNode: " << cnode->fullname_with_scope() << " is nullptr, infershape is delayed.";
    return RET_ERROR;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
    MS_LOG(DEBUG) << "Abstract of parameter should be abstract tensor";
    return RET_ERROR;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
  if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) {  // input node not complete infershape
    MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed";
    return RET_ERROR;
  }
  *tensor_info = utils::cast<tensor::TensorPtr>(abstract_tensor->GetValueTrack());
  if (*tensor_info == nullptr) {
    MS_LOG(ERROR) << "tensor::Tensor of abstract is nullptr";
    return RET_ERROR;
  }
  return RET_OK;
}

bool IsParamOrValueNodeWithData(const BaseRef &n) {
  if (utils::isa<ValueNode>(n)) {
    auto value_node = utils::cast<ValueNodePtr>(n);
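A minimal usage sketch of the helper added above, mirroring how NormFusion::ShapeSizeInfer calls it later in this commit (cnode here stands for any CNodePtr in scope; it is not a name from the diff):

// Sketch only: read the rank of a CNode's first input, treating RET_ERROR as "shape not known yet".
tensor::TensorPtr tensor_info;
if (GetTensorInfoFromAbstract(&tensor_info, cnode, 1) == RET_OK) {
  auto rank = static_cast<int>(tensor_info->shape().size());  // rank of input 1
  MS_LOG(DEBUG) << "rank of first input: " << rank;
}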
@@ -113,6 +113,8 @@ CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &inp

STATUS FetchShapeFromAbstract(const abstract::AbstractBasePtr &abstract, ShapeVector *shape);

STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr &cnode, size_t index);

template <const PrimitivePtr *prim = nullptr>
inline bool IsSpecifiedNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
@@ -23,6 +23,7 @@
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "nnacl/op_base.h"
#include "src/ops/ops_utils.h"

namespace mindspore {
namespace opt {
@@ -115,10 +116,11 @@ CNodePtr NormFusion::CreateNormNode(const FuncGraphPtr &func_graph, const EquivP
  return new_node;
}

bool NormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes,
                                    const std::vector<int> &params_shape, schema::PrimitiveType *type,
                                    int *begin_norm_axis, int *begin_params_axis) const {
  MS_ASSERT(input_cnode != nullptr);
bool NormFusion::GetNormTypeAndAxis(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode,
                                    const std::vector<int> &mean_axes, const std::vector<int> &params_shape,
                                    schema::PrimitiveType *type, int *begin_norm_axis, int *begin_params_axis) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(input_cnode != nullptr);
  MS_ASSERT(type != nullptr);
  MS_ASSERT(begin_norm_axis != nullptr);
  MS_ASSERT(begin_params_axis != nullptr);
@@ -137,19 +139,27 @@ bool NormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vect
    return false;
  }
  auto shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
  int shape_size = static_cast<int>(shape.size());
  if (shape.empty()) {
    auto shape_size_map = ShapeSizeInfer(func_graph);
    if (shape_size_map.find(input_cnode->fullname_with_scope()) != shape_size_map.end()) {
      shape_size = shape_size_map[input_cnode->fullname_with_scope()];
    }
  }

  for (size_t i = 1; i < mean_axes.size(); ++i) {
    if (mean_axes[i] != mean_axes[i - 1] + 1) {
      MS_LOG(DEBUG) << "mean axes is not continuous";
      return false;
    }
  }
  if (shape.size() == 4 && mean_axes.size() == 2 && mean_axes[0] == 1 && mean_axes[1] == 2) {
  if (shape_size == 4 && mean_axes.size() == 2 && mean_axes[0] == 1 && mean_axes[1] == 2) {
    if (params_shape.size() == 1 && params_shape.back() == shape.back()) {
      *type = schema::PrimitiveType_InstanceNorm;
      return true;
    }
  }
  if (mean_axes.back() >= 0 && mean_axes.back() + 1 != static_cast<int>(shape.size())) {
  if (mean_axes.back() >= 0 && mean_axes.back() + 1 != shape_size) {
    MS_LOG(DEBUG) << "mean node is not reduce to last axis.";
    return false;
  }
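This hunk is the core of the commit: when the input's abstract still carries an empty shape (shape inference has not reached that node yet), GetNormTypeAndAxis now falls back to the rank computed by ShapeSizeInfer and uses shape_size rather than shape.size() in the axis checks, so the LayerNorm/InstanceNorm decision no longer requires a fully inferred shape.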
@@ -157,7 +167,7 @@ bool NormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vect
  // there is no need to check params_shape
  *begin_norm_axis = mean_axes.front();
  if (*begin_norm_axis >= 0) {
    *begin_params_axis = static_cast<int>(shape.size()) - static_cast<int>(params_shape.size());
    *begin_params_axis = shape_size - static_cast<int>(params_shape.size());
    if (*begin_params_axis < 0) {
      MS_LOG(DEBUG) << "LayerNorm begin_params_axis illegal, not fuse";
      return false;
@@ -170,8 +180,8 @@ bool NormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vect
  return true;
}

bool NormFusion::CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon, int *begin_norm_axis,
                              int *begin_params_axis) const {
bool NormFusion::CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, schema::PrimitiveType *type,
                              float *epsilon, int *begin_norm_axis, int *begin_params_axis) const {
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(epsilon != nullptr);
  MS_ASSERT(type != nullptr);
@@ -243,10 +253,176 @@ bool NormFusion::CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type
  } else {
    return false;
  }
  if (!GetNormTypeAndAxis(input_cnode, mean1_axes, gamma_shape, type, begin_norm_axis, begin_params_axis)) {
    return false;
  }
  return true;
  return GetNormTypeAndAxis(func_graph, input_cnode, mean1_axes, gamma_shape, type, begin_norm_axis, begin_params_axis);
}

namespace {
int CommonShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 0);
  return in_shape_size.at(0);
}

int ExpandDimsShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 0);
  return in_shape_size.at(0) + 1;
}

int StridedSliceShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 0);
  auto new_axis_mask = primitive.value.AsStridedSlice()->new_axis_mask;
  auto add_dims = 0;
  while (new_axis_mask != 0) {
    new_axis_mask = (new_axis_mask - 1) & new_axis_mask;
    add_dims++;
  }
  return in_shape_size.at(0) + add_dims;
}
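As a worked example of the mask handling above: new_axis_mask = 5 (binary 101) marks two positions where StridedSlice inserts a new axis; the loop clears one set bit per iteration, so add_dims ends up as 2 and the inferred output rank is the input rank plus 2.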
int MatMulShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 1);
  return in_shape_size[0];
}

int ReShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 1);
  return in_shape_size[1];
}

int StackSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 1);
  return std::accumulate(in_shape_size.begin(), in_shape_size.end(), 0);
}

int SqueezeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 0);
  auto axis = primitive.value.AsSqueeze()->axis;
  if (axis.empty()) {
    return 0;
  }
  return in_shape_size.at(0) - axis.size();
}

int OneHotSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 0);
  return in_shape_size.at(0) + 1;
}

int FillShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 1);
  return in_shape_size.at(1);
}

int ShapeOpSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) { return 1; }

int BroadcastShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 1);
  int result = 0;
  for (auto shape_size : in_shape_size) {
    result = std::max(result, shape_size);
  }
  return result;
}
}  // namespace
std::map<string, int> NormFusion::ShapeSizeInfer(const FuncGraphPtr &func_graph) const {
  MS_ASSERT(func_graph != nullptr);
  std::map<string, int> node_shape_size;
  std::map<string, std::vector<int>> node_shape;
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto origin_primc = GetValueNode<PrimitiveCPtr>(cnode->input(0));
    auto prim_t = lite::GetPrimitiveT(cnode->input(0));
    if (prim_t == nullptr) {
      continue;
    }
    auto prim_type = prim_t->value.type;
    auto shape_size_infer_iter = shape_size_infer_registry_.find(prim_type);
    if (shape_size_infer_iter == shape_size_infer_registry_.end()) {
      continue;
    }

    // specific op infer shape
    if (prim_type == schema::PrimitiveType_Shape) {
      tensor::TensorPtr tensor_info;
      auto ret = GetTensorInfoFromAbstract(&tensor_info, cnode, 1);
      if (ret == RET_OK) {
        node_shape[cnode->fullname_with_scope()] = {static_cast<int>(tensor_info->shape().size())};
      } else if (node_shape_size.find(cnode->input(1)->fullname_with_scope()) != node_shape_size.end()) {
        node_shape[cnode->fullname_with_scope()] = {node_shape_size[cnode->input(1)->fullname_with_scope()]};
      }
    } else if (prim_type == schema::PrimitiveType_StridedSlice) {
      node_shape[cnode->fullname_with_scope()] = node_shape[cnode->input(1)->fullname_with_scope()];
    } else if (prim_type == schema::PrimitiveType_Stack) {
      auto shape = node_shape[cnode->input(1)->fullname_with_scope()];
      shape.insert(shape.begin(), cnode->inputs().size() - 1);
      node_shape[cnode->fullname_with_scope()] = shape;
    }

    // Get in node shape size
    std::vector<int> in_shape_sizes;
    for (size_t i = 1; i < cnode->inputs().size(); i++) {
      int in_shape_size = 0;
      if (utils::isa<CNodePtr>(cnode->input(i))) {
        in_shape_size = node_shape_size[cnode->input(i)->fullname_with_scope()];
        if (prim_type == schema::PrimitiveType_Reshape && i == 2 &&
            node_shape.find(cnode->input(i)->fullname_with_scope()) != node_shape.end()) {
          in_shape_size = node_shape[cnode->input(i)->fullname_with_scope()].at(0);
        }
      } else {
        tensor::TensorPtr tensor_info;
        auto ret = GetTensorInfoFromAbstract(&tensor_info, cnode, i);
        if (ret == RET_OK) {
          in_shape_size = tensor_info->shape().size();
          if (prim_type == schema::PrimitiveType_Reshape && i == 2) {
            in_shape_size = tensor_info->shape().at(0);
          }
        }
      }
      in_shape_sizes.emplace_back(in_shape_size);
    }
    // Cal shape size infer function
    auto shape_size_infer_func = shape_size_infer_iter->second;
    auto shape_size = shape_size_infer_iter->second(in_shape_sizes, *prim_t);
    // Update node shape size map
    node_shape_size[cnode->fullname_with_scope()] = shape_size;
  }
  return node_shape_size;
}
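Taken together, ShapeSizeInfer walks the graph in topological order and propagates only tensor ranks (shape sizes) rather than full shapes; Shape, StridedSlice and Stack nodes also record a small node_shape entry so that, for a Reshape, the second input can contribute the length of its shape tensor. The resulting rank map is what GetNormTypeAndAxis consults when the abstract shape is still empty.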
void NormFusion::InitShapeSizeInferFuncMap() {
  if (!shape_size_infer_registry_.empty()) {
    return;
  }
  shape_size_infer_registry_[schema::PrimitiveType_Activation] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_AddFusion] = BroadcastShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_BiasAdd] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Stack] = StackSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Cast] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Concat] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_ExpandDims] = ExpandDimsShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Fill] = FillShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_LayerNormFusion] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_MatMul] = MatMulShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_MulFusion] = BroadcastShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_OneHot] = OneHotSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_ReduceFusion] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Reshape] = ReShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Shape] = ShapeOpSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_SliceFusion] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Softmax] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Squeeze] = SqueezeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_StridedSlice] = StridedSliceShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Transpose] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_TileFusion] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_SquaredDifference] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Rsqrt] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_SubFusion] = BroadcastShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_PadFusion] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_PowFusion] = CommonShapeSizeInfer;
}

const AnfNodePtr NormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
@@ -263,7 +439,7 @@ const AnfNodePtr NormFusion::Process(const FuncGraphPtr &func_graph, const AnfNo
  int begin_norm_axis = 0;
  int begin_params_axis = 0;
  schema::PrimitiveType type = schema::PrimitiveType_NONE;
  if (!CheckPattern(equiv, &type, &epsilon, &begin_norm_axis, &begin_params_axis)) {
  if (!CheckPattern(func_graph, equiv, &type, &epsilon, &begin_norm_axis, &begin_params_axis)) {
    return nullptr;
  }
  auto norm_cnode = CreateNormNode(func_graph, equiv, type, epsilon, begin_norm_axis, begin_params_axis);
@@ -20,6 +20,7 @@
#include <vector>
#include <memory>
#include <string>
#include <map>
#include "schema/inner/model_generated.h"
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
@@ -41,19 +42,23 @@ class NormFusion : public PatternProcessPass {
    gamma_ = std::make_shared<Var>();
    beta_ = std::make_shared<Var>();
    epsilon_ = std::make_shared<Var>();

    InitShapeSizeInferFuncMap();
  }

  ~NormFusion() override = default;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  bool GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes,
                          const std::vector<int> &params_shape, schema::PrimitiveType *type, int *begin_norm_axis,
                          int *begin_params_axis) const;
  bool CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon, int *begin_norm_axis,
                    int *begin_params_axis) const;
  void InitShapeSizeInferFuncMap();
  bool GetNormTypeAndAxis(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode,
                          const std::vector<int> &mean_axes, const std::vector<int> &params_shape,
                          schema::PrimitiveType *type, int *begin_norm_axis, int *begin_params_axis) const;
  bool CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon,
                    int *begin_norm_axis, int *begin_params_axis) const;
  CNodePtr CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const schema::PrimitiveType type,
                          float epsilon, int begin_norm_axis, int begin_params_axis) const;
  std::map<string, int> ShapeSizeInfer(const FuncGraphPtr &func_graph) const;

 protected:
  VarPtr input_ = nullptr;
@@ -64,6 +69,8 @@ class NormFusion : public PatternProcessPass {
  VarPtr gamma_ = nullptr;
  VarPtr beta_ = nullptr;
  VarPtr epsilon_ = nullptr;
  std::map<schema::PrimitiveType, std::function<int(std::vector<int>, const schema::PrimitiveT &)>>
    shape_size_infer_registry_;
};

/// fuse tf layer_norm or instance_norm into one operator