forked from mindspore-Ecosystem/mindspore
!33802 [MS][LITE] GroupNorm fusion
Merge pull request !33802 from Haim/export_haim
This commit is contained in:
commit
9bc1d22bdd
|
@ -0,0 +1,145 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp32/group_norm_fp32.h"
|
||||
#include <math.h>
|
||||
#include "nnacl/group_norm_parameter.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run_var, int completed_group,
|
||||
int cur_groups, const GroupNormParameter *param);
|
||||
|
||||
// Normalizes one SIMD vector of `unit_input` starting at element `u` and
// writes the scaled/offset result to `unit_output`:
//   out = (in - mean) / v_sqrt * scale + offset
// where `v_sqrt` is the pre-broadcast sqrt(var + epsilon) vector.
#define SimdFusedGroupNormFp32DoWork(block_size, block_num, mean, v_sqrt, scale, offset, unit_input, unit_output, u) \
  do {                                                                                                              \
    MS_FLOAT_32xN(block_num) input = MS_LD_F32(block_size, unit_input + u);                                         \
    MS_FLOAT_32xN(block_num) norm_val = MS_DIV_F32(block_size, MS_SUB_F32(block_size, input, mean), v_sqrt);        \
    MS_FLOAT_32xN(block_num) output = MS_ADD_F32(block_size, MS_MUL_F32(block_size, norm_val, scale), offset);      \
    MS_ST_F32(block_size, unit_output + u, output);                                                                 \
  } while (0)
|
||||
|
||||
// 32 bits, block_size : (512/256/128/32), block_num : (16/8/4/1)
|
||||
// 32 bits, block_size : (512/256/128/32), block_num : (16/8/4/1)
// Broadcasts the scalar mean/stddev/scale/offset to SIMD vectors once, then
// processes `param->unit_` elements block_num at a time, advancing `u`.
// The scalar tail (u < param->unit_ after the loop) is handled by the caller.
// Relies on `param` being in scope at the expansion site.
#define SimdFusedGroupNormFp32CoreCalc(block_size, block_num, unit_input, s, m, o, var_sqrt, param, unit_output, u) \
  do {                                                                                                             \
    MS_FLOAT_32xN(block_num) mean = MS_MOVN_F32(block_size, m);                                                    \
    MS_FLOAT_32xN(block_num) v_sqrt = MS_MOVN_F32(block_size, var_sqrt);                                           \
    MS_FLOAT_32xN(block_num) scale = MS_MOVN_F32(block_size, s);                                                   \
    MS_FLOAT_32xN(block_num) offset = MS_MOVN_F32(block_size, o);                                                  \
    for (int block_max_size = param->unit_ - block_num + 1; u < block_max_size; u += block_num) {                  \
      SimdFusedGroupNormFp32DoWork(block_size, block_num, mean, v_sqrt, scale, offset, unit_input, unit_output, u);\
    }                                                                                                              \
  } while (0)
|
||||
|
||||
int GroupNormFp32(const float *input, const float *scale, const float *offset, float *mean, float *variance,
|
||||
const GroupNormParameter *param, int task_id, float *output) {
|
||||
if (param->op_parameter_.thread_num_ == 0) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
const int frame_elem_num = param->unit_ * param->channel_;
|
||||
const int groups_per_thread = UP_DIV(param->num_groups_, param->op_parameter_.thread_num_);
|
||||
const int completed_group = task_id * groups_per_thread;
|
||||
const int cur_group = MSMIN(groups_per_thread, param->num_groups_ - completed_group);
|
||||
const int num_of_ch_per_group = param->channel_ / param->num_groups_;
|
||||
int cur_offset = completed_group * num_of_ch_per_group * param->unit_;
|
||||
|
||||
for (int b = 0; b < param->batch_; b++) {
|
||||
const float *b_in = input + b * frame_elem_num;
|
||||
float *b_out = output + b * frame_elem_num;
|
||||
int b_offset = cur_offset;
|
||||
GroupNormFp32MeanVar(b_in, mean, variance, completed_group, cur_group, param);
|
||||
for (int g = 0; g < cur_group; g++) {
|
||||
int grp_idx = g + completed_group;
|
||||
int c_offset = grp_idx * num_of_ch_per_group;
|
||||
float m = mean[grp_idx];
|
||||
float v = variance[grp_idx];
|
||||
float variance_sqrt = sqrtf(v + param->epsilon_);
|
||||
if (variance_sqrt == 0) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
for (int c = 0; c < num_of_ch_per_group; c++) {
|
||||
const float *unit_input = b_in + b_offset;
|
||||
float *unit_output = b_out + b_offset;
|
||||
float s = scale[c_offset + c];
|
||||
float o = offset[c_offset + c];
|
||||
int u = 0;
|
||||
MS_SIMD_RUN_NO_SCALAR(SimdFusedGroupNormFp32CoreCalc, unit_input, s, m, o, variance_sqrt, param, unit_output,
|
||||
u);
|
||||
for (; u < param->unit_; u++) {
|
||||
float norm_val = (unit_input[u] - m) / variance_sqrt;
|
||||
unit_output[u] = norm_val * s + o;
|
||||
}
|
||||
b_offset += param->unit_;
|
||||
}
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// Accumulates the horizontal sum of `in[i..]` into the scalar `sum`,
// block_num elements at a time, advancing `i`. The scalar tail is left for
// the caller. Relies on `param` being in scope at the expansion site.
#define SimdReduceSum(block_size, block_num, in, i, sum)                                          \
  do {                                                                                            \
    for (int block_max_size = param->unit_ - block_num + 1; i < block_max_size; i += block_num) { \
      MS_FLOAT_32xN(block_num) input = MS_LD_F32(block_size, in + i);                             \
      sum += MS_GET_SUM_F32(block_size, input);                                                   \
    }                                                                                             \
  } while (0)
|
||||
|
||||
// Accumulates the sum of squared deviations from the scalar mean `m` into
// `sum`, block_num elements at a time, advancing `i`. Squared terms are kept
// in a vector accumulator and reduced once at the end. The scalar tail is
// left for the caller. Relies on `param` being in scope at the expansion site.
#define SimdReduceVar(block_size, block_num, in, m, i, sum)                                       \
  do {                                                                                            \
    MS_FLOAT_32xN(block_num) mean = MS_MOVN_F32(block_size, m);                                   \
    MS_FLOAT_32xN(block_num) tmp = MS_MOVN_F32(block_size, 0);                                    \
    for (int block_max_size = param->unit_ - block_num + 1; i < block_max_size; i += block_num) { \
      MS_FLOAT_32xN(block_num) input = MS_SUB_F32(block_size, MS_LD_F32(block_size, in + i), mean); \
      tmp = MS_ADD_F32(block_size, tmp, MS_MUL_F32(block_size, input, input));                    \
    }                                                                                             \
    sum += MS_GET_SUM_F32(block_size, tmp);                                                       \
  } while (0)
|
||||
|
||||
static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run_var, int completed_group,
|
||||
int cur_groups, const GroupNormParameter *param) {
|
||||
const int num_of_ch_per_group = param->channel_ / param->num_groups_;
|
||||
const float N = (float)(param->unit_ * num_of_ch_per_group);
|
||||
|
||||
// calc mean
|
||||
for (int g = 0; g < cur_groups; g++) {
|
||||
int g_idx = g + completed_group;
|
||||
float sum = 0;
|
||||
for (int c = 0; c < num_of_ch_per_group; c++) {
|
||||
const float *in = input + (num_of_ch_per_group * g_idx + c) * param->unit_;
|
||||
int i = 0;
|
||||
MS_SIMD_RUN_NO_SCALAR(SimdReduceSum, in, i, sum);
|
||||
for (; i < param->unit_; i++) {
|
||||
sum += in[i];
|
||||
}
|
||||
}
|
||||
run_mean[g_idx] = sum / N;
|
||||
}
|
||||
|
||||
// calc variance
|
||||
for (int g = 0; g < cur_groups; g++) {
|
||||
int g_idx = g + completed_group;
|
||||
float var = 0;
|
||||
run_var[g_idx] = 0;
|
||||
for (int c = 0; c < num_of_ch_per_group; c++) {
|
||||
const float *in = input + (num_of_ch_per_group * g_idx + c) * param->unit_;
|
||||
int i = 0;
|
||||
MS_SIMD_RUN_NO_SCALAR(SimdReduceVar, in, run_mean[g_idx], i, var);
|
||||
for (; i < param->unit_; i++) {
|
||||
var += (in[i] - run_mean[g_idx]) * (in[i] - run_mean[g_idx]);
|
||||
}
|
||||
}
|
||||
run_var[g_idx] = var / N;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GROUP_NORM_FP32_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GROUP_NORM_FP32_H_

#include "nnacl/op_base.h"
#include "nnacl/tensor_c.h"
#include "nnacl/group_norm_parameter.h"

#ifdef __cplusplus
extern "C" {
#endif

// Fused fp32 GroupNorm forward pass (NCHW layout).
// `mean` and `variance` are caller-owned scratch buffers of num_groups_
// floats; `task_id` selects the slice of groups this thread normalizes.
// Returns NNACL_OK on success, NNACL_ERR on invalid parameters.
int GroupNormFp32(const float *input, const float *scale, const float *offset, float *mean, float *variance,
                  const GroupNormParameter *param, int task_id, float *output);

#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GROUP_NORM_FP32_H_
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_GROUP_NORM_PARAMETER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_GROUP_NORM_PARAMETER_H_

#include "nnacl/op_base.h"
#include "nnacl/int8/quantize.h"
// Runtime parameter block for the GroupNorm(Fusion) operator.
typedef struct GroupNormParameter {
  // Primitive parameter
  OpParameter op_parameter_;
  float epsilon_;   // numerical-stability term added to the variance
  int num_groups_;  // number of channel groups; expected to divide channel_
  int channel_;     // C of the NCHW input
  int unit_;        // per-channel spatial size (H * W)
  int batch_;       // N of the NCHW input
  bool affine_;     // whether learnable scale/offset are applied
  void *mean_;      // scratch buffer (num_groups_ floats), allocated by the kernel's prepare
  void *variance_;  // scratch buffer (num_groups_ floats), allocated by the kernel's prepare
} GroupNormParameter;

// Quantization arguments for an int8 GroupNorm variant.
typedef struct GroupNormQuantArg {
  int32_t in_zp_;     // input zero point
  int32_t out_zp_;    // output zero point
  double in_scale_;   // input quantization scale
  double out_scale_;  // output quantization scale
} GroupNormQuantArg;

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_GROUP_NORM_PARAMETER_H_
|
|
@ -251,5 +251,6 @@ REG_INFER(Sin, PrimType_Sin, CommonInferShape)
|
|||
REG_INFER(SmoothL1Loss, PrimType_SmoothL1Loss, CommonInferShape)
|
||||
REG_INFER(SmoothL1LossGrad, PrimType_SmoothL1LossGrad, CommonInferShape)
|
||||
REG_INFER(Sqrt, PrimType_Sqrt, CommonInferShape)
|
||||
REG_INFER(SqrtGrad, PrimType_SqrtGrad, CommonInferShape)
|
||||
REG_INFER(Square, PrimType_Square, CommonInferShape)
|
||||
REG_INFER(ZerosLike, PrimType_ZerosLike, CommonInferShape)
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/infer/group_norm_infer.h"
|
||||
#include "nnacl/infer/infer_register.h"
|
||||
|
||||
int GroupNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||
OpParameter *parameter) {
|
||||
int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
|
||||
if (check_ret != NNACL_OK) {
|
||||
return check_ret;
|
||||
}
|
||||
|
||||
const TensorC *input = inputs[0];
|
||||
TensorC *output0 = outputs[0];
|
||||
SetDataTypeFormat(output0, input);
|
||||
if (!InferFlag(inputs, inputs_size)) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
SetShapeTensor(output0, input);
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
REG_INFER(GroupNorm, PrimType_GroupNormFusion, GroupNormInferShape)
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_GROUP_NORM_INFER_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_GROUP_NORM_INFER_H_
|
||||
|
||||
#include "nnacl/infer/common_infer.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
int GroupNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||
OpParameter *parameter);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_GROUP_NORM_INFER_H_
|
|
@ -94,6 +94,7 @@ static inline float32x4_t vrecp(float32x4_t v) {
|
|||
|
||||
#ifdef ENABLE_ARM64
|
||||
#define MS_GET_MAX128_F32 vmaxvq_f32
|
||||
static inline float MS_GET_SUM128_F32(MS_FLOAT32X4 src) { return vaddvq_f32(src); }
|
||||
#else
|
||||
static inline float MS_GET_MAX128_F32(MS_FLOAT32X4 src) {
|
||||
float result = MS_F32X4_GETI(src, 0);
|
||||
|
@ -102,7 +103,6 @@ static inline float MS_GET_MAX128_F32(MS_FLOAT32X4 src) {
|
|||
}
|
||||
return result;
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline float MS_GET_SUM128_F32(MS_FLOAT32X4 src) {
|
||||
float result = MS_F32X4_GETI(src, 0);
|
||||
|
@ -111,6 +111,7 @@ static inline float MS_GET_SUM128_F32(MS_FLOAT32X4 src) {
|
|||
}
|
||||
return result;
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline int32x4_t MS_DIV128_EPI32(int32x4_t src1, int32x4_t src2) {
|
||||
int32x4_t result;
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/kernel/group_norm.h"
#include <string.h>
#include "nnacl/fp32/group_norm_fp32.h"
#include "nnacl/group_norm_parameter.h"
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
#include "nnacl/tensor_c.h"
||||
|
||||
static int groupnorm_resize(struct KernelBase *self, TensorC *in[], size_t insize, TensorC *out[], size_t outsize);
|
||||
static int groupnorm_prepare(struct KernelBase *self);
|
||||
static int groupnorm_release(struct KernelBase *self);
|
||||
static int groupnorm_compute(struct KernelBase *self);
|
||||
typedef struct GroupNormStru {
|
||||
KernelBase base;
|
||||
} GroupNormStru;
|
||||
|
||||
// Re-derives the shape-dependent parameters (batch/channel/unit) from the
// first input tensor, then re-runs prepare() to size the scratch buffers.
// Requires NCHW input; unit_ becomes the spatial size H * W.
static int groupnorm_resize(struct KernelBase *self, TensorC *in[], size_t insize, TensorC *out[], size_t outsize) {
  GroupNormStru *groupnorm = (GroupNormStru *)self;
  GroupNormParameter *param = (GroupNormParameter *)groupnorm->base.param;

  // Drop buffers from a previous resize before prepare() re-allocates them.
  groupnorm_release(self);

  TensorC *in0 = in[0];
  if (in0 == NULL) {
    return NNACL_ERR;
  }

  if (in0->shape_size_ < C1NUM) {
    return NNACL_ERR;
  }
  if (in0->format_ != NCHW) {
    return NNACL_ERR;
  }

  param->unit_ = GetHeight(in0) * GetWidth(in0);
  param->batch_ = GetBatch(in0);
  param->channel_ = GetChannel(in0);
  return groupnorm_prepare(self);
}
|
||||
|
||||
static int groupnorm_prepare(struct KernelBase *self) {
|
||||
GroupNormStru *groupnorm = (GroupNormStru *)self;
|
||||
GroupNormParameter *param = (GroupNormParameter *)groupnorm->base.param;
|
||||
|
||||
if ((param->num_groups_ < 0) || (param->channel_ % param->num_groups_)) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
size_t mean_var_elem_num = param->num_groups_;
|
||||
param->mean_ = malloc(mean_var_elem_num * sizeof(float));
|
||||
param->variance_ = malloc(mean_var_elem_num * sizeof(float));
|
||||
if (param->mean_ == NULL || param->variance_ == NULL) {
|
||||
groupnorm_release(self);
|
||||
return NNACL_ERR;
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
static int groupnorm_release(struct KernelBase *self) {
|
||||
GroupNormStru *groupnorm = (GroupNormStru *)self;
|
||||
GroupNormParameter *param = (GroupNormParameter *)groupnorm->base.param;
|
||||
|
||||
if (param->mean_ != NULL) {
|
||||
free(param->mean_);
|
||||
param->mean_ = NULL;
|
||||
}
|
||||
if (param->variance_ != NULL) {
|
||||
free(param->variance_);
|
||||
param->variance_ = NULL;
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// Parallel worker body: runs the fp32 GroupNorm kernel for one task id.
// `param` is the GroupNormStru (KernelBase) pointer forwarded by the parallel
// launcher; lhs_scale/rhs_scale are part of the launcher signature and unused.
static int groupnorm_do_compute(void *param, int task_id, float lhs_scale, float rhs_scale) {
  if (param == NULL) {
    return NNACL_ERR;
  }

  GroupNormStru *groupnorm_stru = (GroupNormStru *)param;
  GroupNormParameter *groupnorm_param = (GroupNormParameter *)groupnorm_stru->base.param;
  // in[0] = x, in[1] = scale (gamma), in[2] = offset (beta).
  // NOTE(review): the data_ pointers are assumed non-NULL here — the caller
  // binds them before launching; confirm no path reaches this with NULL data.
  int ret = GroupNormFp32(groupnorm_stru->base.in[0]->data_, groupnorm_stru->base.in[C1NUM]->data_,
                          groupnorm_stru->base.in[C2NUM]->data_, groupnorm_param->mean_, groupnorm_param->variance_,
                          groupnorm_param, task_id, groupnorm_stru->base.out[0]->data_);

  return ret;
}
|
||||
|
||||
// Launches groupnorm_do_compute across the thread pool, one task per thread;
// each task normalizes its own contiguous slice of groups.
static int groupnorm_compute(struct KernelBase *self) {
  return self->env->parallelLaunch(self->env->threadPool, groupnorm_do_compute, self, self->param->thread_num_);
}
|
||||
|
||||
// Factory for the nnacl GroupNorm kernel: allocates the kernel struct, wires
// the tensors/environment and the prepare/resize/release/compute callbacks.
// Returns NULL on allocation failure or NULL parameter; caller owns the result
// and must call release() then free() on it.
KernelBase *CreateGroupNorm(OpParameter *param, TensorC **in, size_t insize, TensorC **out, size_t outsize) {
  if (param == NULL) {
    return NULL;
  }
  GroupNormStru *groupnorm = (GroupNormStru *)malloc(sizeof(GroupNormStru));
  if (groupnorm == NULL) {
    return NULL;
  }
  // Zero the whole struct so any KernelBase fields not assigned below start
  // in a defined state (the original left them uninitialized).
  memset(groupnorm, 0, sizeof(GroupNormStru));
  groupnorm->base.param = param;
  groupnorm->base.in = in;
  groupnorm->base.insize = insize;
  groupnorm->base.out = out;
  groupnorm->base.outsize = outsize;
  groupnorm->base.env = GetExecEnv();
  groupnorm->base.prepare = groupnorm_prepare;
  groupnorm->base.resize = groupnorm_resize;
  groupnorm->base.release = groupnorm_release;
  groupnorm->base.compute = groupnorm_compute;

  return (KernelBase *)groupnorm;
}

REG_KERNEL_CREATOR(PrimType_GroupNormFusion, Format_NCHW, kNumberTypeFloat32, CreateGroupNorm);
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_KERNEL_GROUP_NORM_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_KERNEL_GROUP_NORM_H_

#include "nnacl/op_base.h"
#include "nnacl/tensor_c.h"
#include "nnacl/group_norm_parameter.h"
#include "nnacl/kernel.h"

#ifdef __cplusplus
extern "C" {
#endif

// Factory for the nnacl GroupNorm kernel (fp32, NCHW). Returns NULL on
// failure; the caller owns the returned kernel and must release/free it.
KernelBase *CreateGroupNorm(OpParameter *param, TensorC *in[], size_t insize, TensorC *out[], size_t outsize);

#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_KERNEL_GROUP_NORM_H_
|
|
@ -498,8 +498,9 @@ enum PrimType {
|
|||
PrimType_NLLLossGrad = 208,
|
||||
PrimType_FormatTranspose = 209,
|
||||
PrimType_GatherD = 210,
|
||||
PrimType_GroupNormFusion = 211,
|
||||
PrimType_MIN = PrimType_NONE,
|
||||
PrimType_MAX = PrimType_GatherD + 1,
|
||||
PrimType_MAX = PrimType_GroupNormFusion + 1,
|
||||
|
||||
// inner operators.
|
||||
PrimType_Inner_ToFormat = 10000,
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/fusion/groupnorm_fusion.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_BASE_IMPL(GroupNormFusion, PrimitiveC, BaseOperator);
// Stores epsilon (numerical-stability term) as the kEpsilon attribute.
void GroupNormFusion::set_epsilon(const float epsilon) { (void)this->AddAttr(kEpsilon, api::MakeValue(epsilon)); }

// Stores the number of channel groups as the kNumGroups attribute.
void GroupNormFusion::set_num_groups(const int64_t num_groups) {
  (void)this->AddAttr(kNumGroups, api::MakeValue(num_groups));
}

// Stores whether the op has learnable scale/offset as the kAffine attribute.
void GroupNormFusion::set_affine(const bool affine) { (void)this->AddAttr(kAffine, api::MakeValue(affine)); }

// NOTE(review): the getters below pass the raw attribute pointer to GetValue
// without a null check; they assume the attribute was set (via Init or the
// schema defaults) before being read — confirm callers guarantee this.
float GroupNormFusion::get_epsilon() const {
  auto value_ptr = this->GetAttr(kEpsilon);
  return GetValue<float>(value_ptr);
}

int64_t GroupNormFusion::get_num_groups() const {
  auto value_ptr = this->GetAttr(kNumGroups);
  return GetValue<int64_t>(value_ptr);
}

bool GroupNormFusion::get_affine() const {
  auto value_ptr = this->GetAttr(kAffine);
  return GetValue<bool>(value_ptr);
}

// Sets all three attributes in one call.
void GroupNormFusion::Init(const int64_t num_groups, const float epsilon, bool affine) {
  this->set_epsilon(epsilon);
  this->set_num_groups(num_groups);
  this->set_affine(affine);
}
|
||||
REGISTER_PRIMITIVE_C(kNameGroupNormFusion, GroupNormFusion);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_FUSION_GROUPNORM_FUSION_H_
#define MINDSPORE_CORE_OPS_FUSION_GROUPNORM_FUSION_H_
#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameGroupNormFusion = "GroupNormFusion";
/// \brief GroupNormFusion defined GroupNormFusion operator prototype of lite.
class MIND_API GroupNormFusion : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(GroupNormFusion);
  /// \brief Constructor.
  GroupNormFusion() : BaseOperator(kNameGroupNormFusion) { InitIOName({"x"}, {"y"}); }
  /// \brief Method to init the op's attributes.
  ///
  /// \param[in] num_groups Define group number.
  /// \param[in] eps Define epsilon.
  /// \param[in] affine Define whether this ops has learnable parameters.
  void Init(const int64_t num_groups, const float eps = 1e-5, bool affine = true);

  /// \brief Method to set epsilon attribute.
  ///
  /// \param[in] epsilon Define epsilon for numerical stability.
  void set_epsilon(const float epsilon);

  /// \brief Method to set num_groups attribute.
  ///
  /// \param[in] num_groups Define number of groups to separate the channels into.
  void set_num_groups(const int64_t num_groups);

  /// \brief Method to set affine attribute.
  ///
  /// \param[in] affine Define whether this ops has learnable parameters.
  void set_affine(const bool affine);

  /// \brief Method to get epsilon attribute.
  ///
  /// \return epsilon attribute.
  float get_epsilon() const;

  /// \brief Method to get num_groups attribute.
  ///
  /// \return num_groups attribute.
  int64_t get_num_groups() const;

  /// \brief Method to get affine attribute.
  ///
  /// \return affine attribute.
  bool get_affine() const;
};
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_FUSION_GROUPNORM_FUSION_H_
|
|
@ -264,6 +264,8 @@ constexpr auto kBegin = "begin";
|
|||
constexpr auto kSrcFormat = "src_format";
|
||||
constexpr auto kDstFormat = "dst_format";
|
||||
constexpr auto kLambd = "lambd";
|
||||
constexpr auto kAffine = "affine";
|
||||
constexpr auto kNumGroups = "num_groups";
|
||||
|
||||
enum Index : size_t {
|
||||
kInputIndex0 = 0,
|
||||
|
|
|
@ -228,6 +228,7 @@ union PrimitiveType {
|
|||
NLLLossGrad,
|
||||
FormatTranspose,
|
||||
GatherD,
|
||||
GroupNormFusion,
|
||||
}
|
||||
|
||||
table Abs {
|
||||
|
@ -1276,3 +1277,9 @@ table FormatTranspose {
|
|||
|
||||
table GatherD {
|
||||
}
|
||||
|
||||
table GroupNormFusion {
|
||||
num_groups: long;
|
||||
epsilon: float = 1e-5;
|
||||
affine: bool = true;
|
||||
}
|
||||
|
|
|
@ -228,6 +228,7 @@ OP_TYPE(NLLLoss)
|
|||
OP_TYPE(NLLLossGrad)
|
||||
OP_TYPE(FormatTranspose)
|
||||
OP_TYPE(GatherD)
|
||||
OP_TYPE(GroupNormFusion)
|
||||
OP_TYPE_DEF_END(PrimitiveType)
|
||||
|
||||
OP_SCHEMA_DEF(Abs)
|
||||
|
@ -1276,3 +1277,9 @@ OP_SCHEMA_DEF_END(FormatTranspose)
|
|||
|
||||
OP_SCHEMA_DEF(GatherD)
|
||||
OP_SCHEMA_DEF_END(GatherD)
|
||||
|
||||
OP_SCHEMA_DEF(GroupNormFusion)
|
||||
OP_ATTR(num_groups, long)
|
||||
OP_ATTR_WITH_VALUE(epsilon, float, 1e-5)
|
||||
OP_ATTR_WITH_VALUE(affine, bool, true)
|
||||
OP_SCHEMA_DEF_END(GroupNormFusion)
|
||||
|
|
|
@ -228,6 +228,7 @@
|
|||
#include "ops/fusion/sub_fusion.h"
|
||||
#include "ops/fusion/tile_fusion.h"
|
||||
#include "ops/fusion/topk_fusion.h"
|
||||
#include "ops/fusion/groupnorm_fusion.h"
|
||||
#include "ops/gru.h"
|
||||
#include "ops/non_zero.h"
|
||||
#include "ops/invert_permutation.h"
|
||||
|
@ -483,6 +484,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(NLLLoss)
|
|||
FUNC_MSOP2SCHEMAOP_DECLARE(NLLLossGrad)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(FormatTranspose)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(GatherD)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(GroupNormFusion)
|
||||
#endif
|
||||
} // namespace mindspore::lite::ops
|
||||
#else
|
||||
|
|
|
@ -269,6 +269,7 @@ REG_MINDSPORE_OPERATOR(NLLLoss)
|
|||
REG_MINDSPORE_OPERATOR(NLLLossGrad)
|
||||
REG_MINDSPORE_OPERATOR(FormatTranspose)
|
||||
REG_MINDSPORE_OPERATOR(GatherD)
|
||||
REG_MINDSPORE_OPERATOR(GroupNormFusion)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "nnacl/group_norm_parameter.h"
|
||||
using mindspore::schema::PrimitiveType_GroupNormFusion;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
// Builds a zero-initialized GroupNormParameter from the flatbuffer primitive,
// copying num_groups / epsilon / affine from the schema attributes.
// Returns nullptr if the primitive payload is missing or allocation fails;
// the caller takes ownership of the returned OpParameter.
OpParameter *PopulateIGroupNormParameter(const void *prim) {
  auto primitive = static_cast<const schema::Primitive *>(prim);
  MS_ASSERT(primitive != nullptr);
  auto value = primitive->value_as_GroupNormFusion();
  if (value == nullptr) {
    MS_LOG(ERROR) << "value is nullptr";
    return nullptr;
  }

  auto *gn_param = reinterpret_cast<GroupNormParameter *>(malloc(sizeof(GroupNormParameter)));
  if (gn_param == nullptr) {
    MS_LOG(ERROR) << "malloc GroupNormParameter failed.";
    return nullptr;
  }
  memset(gn_param, 0, sizeof(GroupNormParameter));

  gn_param->op_parameter_.type_ = primitive->value_type();
  gn_param->num_groups_ = value->num_groups();
  gn_param->epsilon_ = value->epsilon();
  gn_param->affine_ = value->affine();
  return reinterpret_cast<OpParameter *>(gn_param);
}
|
||||
|
||||
REG_POPULATE(PrimitiveType_GroupNormFusion, PopulateIGroupNormParameter, SCHEMA_CUR)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,95 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/cpu/fp32/groupnorm_fp32.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/common/tensor_util.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_NULL_PTR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_GroupNormFusion;
|
||||
namespace {} // namespace
|
||||
namespace mindspore::kernel {
|
||||
// Wraps the lite::Tensor inputs/outputs into plain TensorC structs so the
// nnacl C kernel can consume them. Expects exactly 3 inputs (x, gamma, beta)
// and 1 output; on a size mismatch the wrappers are simply not allocated.
// NOTE(review): assumes the in_/out_ members are null-initialized in the
// class declaration so the destructor can free them unconditionally — confirm.
GroupnormCPUKernel::GroupnormCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                       const std::vector<lite::Tensor *> &outputs, const InnerContext *ctx)
    : LiteKernel(parameter, inputs, outputs, ctx) {
  if (in_tensors_.size() != DIMENSION_3D) {
    return;
  }
  if (out_tensors_.size() != 1) {
    return;
  }

  // Copy each input tensor's metadata into a heap-allocated TensorC wrapper.
  for (size_t i = 0; i < in_tensors_.size(); i++) {
    in_[i] = reinterpret_cast<TensorC *>(malloc(sizeof(TensorC)));
    if (in_[i] != nullptr) {
      Tensor2TensorC(in_tensors_.at(i), in_[i]);
    }
  }
  out_[0] = reinterpret_cast<TensorC *>(malloc(sizeof(TensorC)));
  if (out_[0] != nullptr) {
    Tensor2TensorC(out_tensors_.at(0), out_[0]);
  }
}
|
||||
|
||||
// Dtor: releases the nnacl kernel (if any) and the TensorC descriptors
// allocated in the constructor.
GroupnormCPUKernel::~GroupnormCPUKernel() {
  if (kernel_ != nullptr) {
    kernel_->release(kernel_);
    free(kernel_);
    kernel_ = nullptr;
  }
  // Fix: the TensorC descriptors are allocated unconditionally in the ctor,
  // so they must be freed even when kernel_ was never created (the original
  // early-return on kernel_ == nullptr leaked them).  Loop over the fixed
  // array size, not in_tensors_.size(), so an oversized input vector cannot
  // index past in_[DIMENSION_3D]; free(nullptr) is a no-op for unused slots.
  for (size_t i = 0; i < DIMENSION_3D; i++) {
    free(in_[i]);
    in_[i] = nullptr;
  }
  free(out_[0]);
  out_[0] = nullptr;
}
|
||||
|
||||
int GroupnormCPUKernel::Prepare() {
|
||||
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D);
|
||||
CHECK_LESS_RETURN(out_tensors_.size(), 1);
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
kernel_ =
|
||||
CreateKernel(op_parameter_, in_, in_tensors().size(), out_, out_tensors_.size(), kNumberTypeFloat32, Format_NCHW);
|
||||
if (kernel_ == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int GroupnormCPUKernel::ReSize() {
|
||||
if (kernel_ == nullptr) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
return kernel_->resize(kernel_, in_, in_tensors().size(), out_, 1);
|
||||
}
|
||||
|
||||
int GroupnormCPUKernel::Run() {
|
||||
if (kernel_ == nullptr) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
kernel_->in[0]->data_ = in_tensors().at(0)->data();
|
||||
kernel_->in[C1NUM]->data_ = in_tensors().at(C1NUM)->data();
|
||||
kernel_->in[C2NUM]->data_ = in_tensors().at(C2NUM)->data();
|
||||
kernel_->out[0]->data_ = out_tensors().front()->data();
|
||||
return kernel_->compute(kernel_);
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GroupNormFusion, LiteKernelCreator<GroupnormCPUKernel>)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GROUPNORM_FP32_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GROUPNORM_FP32_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "include/context.h"
|
||||
#include "nnacl/fp32/group_norm_fp32.h"
|
||||
#include "nnacl/fp32/scale_fp32.h"
|
||||
#include "nnacl/group_norm_parameter.h"
|
||||
#include "nnacl/tensor_c.h"
|
||||
#include "nnacl/kernel.h"
|
||||
|
||||
using mindspore::lite::InnerContext;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// CPU fp32 GroupNorm kernel.  Thin adapter that forwards Prepare/ReSize/Run
// to an nnacl KernelBase created via CreateKernel().
class GroupnormCPUKernel : public LiteKernel {
 public:
  GroupnormCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                     const std::vector<lite::Tensor *> &outputs, const InnerContext *ctx);
  virtual ~GroupnormCPUKernel();

  int Prepare() override;
  int ReSize() override;
  int Run() override;
  int DoExecute(int task_id);

 protected:
  // Underlying nnacl kernel; owned by this object (released in the dtor).
  KernelBase *kernel_ = nullptr;
  // malloc'ed TensorC mirrors of the input/output lite tensors (see ctor).
  TensorC *in_[DIMENSION_3D] = {nullptr};
  TensorC *out_[DIMENSION_1D] = {nullptr};
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GROUPNORM_FP32_H_
|
|
@ -173,7 +173,7 @@ class TrainSession : virtual public lite::LiteSession {
|
|||
SchedCallBack sched_mix_precision_callback_;
|
||||
bool train_mode_ = false;
|
||||
void *tensors_data_ = nullptr;
|
||||
unsigned int tensors_data_size_ = 0;
|
||||
size_t tensors_data_size_ = 0;
|
||||
std::shared_ptr<Allocator> allocator_;
|
||||
};
|
||||
|
||||
|
|
|
@ -90,6 +90,7 @@
|
|||
#include "tools/optimizer/format/to_nhwc_format.h"
|
||||
#include "tools/converter/adapter/acl/acl_pass.h"
|
||||
#include "src/common/log_util.h"
|
||||
#include "tools/optimizer/fusion/groupnorm_fusion.h"
|
||||
|
||||
using std::string;
|
||||
namespace mindspore::lite {
|
||||
|
@ -190,6 +191,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
|
|||
fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>(config->fmk));
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>(config->fmk));
|
||||
fusion_pm->AddPass(std::make_shared<opt::GroupNormFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TfNormFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::OnnxLayerNormFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::OnnxLayerNormFusion2>());
|
||||
|
|
|
@ -80,6 +80,7 @@ set(CODER_OPCODERS_SRC
|
|||
${MICRO_DIR}/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/fp32/full_connection_fp32_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/fp32/affine_fp32_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/fp32/lstm_fp32_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "nnacl/fp32/group_norm_fp32.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "coder/opcoders/file_collector.h"
|
||||
#include "coder/opcoders/parallel.h"
|
||||
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
|
||||
|
||||
using mindspore::schema::PrimitiveType_GroupNormFusion;
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
int GroupNormFP32Coder::Init() {
|
||||
auto gn_parameter = reinterpret_cast<GroupNormParameter *>(OperatorCoder::parameter_);
|
||||
std::vector<int> input_shapes = input_tensor_->shape();
|
||||
if (input_shapes.empty()) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
// only NCHW is supported
|
||||
auto fmt = input_tensor_->format();
|
||||
CHECK_NOT_EQUAL_RETURN(fmt, NCHW);
|
||||
|
||||
auto in_n_dim = input_shapes.size();
|
||||
CHECK_LESS_RETURN(in_n_dim, 1);
|
||||
|
||||
gn_parameter->unit_ = input_tensor_->Height() * input_tensor_->Width();
|
||||
gn_parameter->batch_ = input_tensor_->Batch();
|
||||
gn_parameter->channel_ = input_tensor_->Channel();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// Reserves the per-group mean/variance workspace buffers used by the
// generated GroupNormFp32 call (one float per group each).
int GroupNormFP32Coder::Prepare(CoderContext *const context) {
  auto *param = reinterpret_cast<GroupNormParameter *>(parameter_);
  const int buf_bytes = param->num_groups_ * sizeof(float);
  mean_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, buf_bytes, kWorkspace));
  MS_CHECK_PTR(mean_);
  variance_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, buf_bytes, kWorkspace));
  MS_CHECK_PTR(variance_);
  return RET_OK;
}
|
||||
|
||||
// Emits the C source that calls GroupNormFp32 for this node into `context`.
int GroupNormFP32Coder::DoCode(CoderContext *const context) {
  // attribute
  auto gn_parameter = reinterpret_cast<GroupNormParameter *>(parameter_);
  if (Init() != RET_OK) {
    MS_LOG(ERROR) << "GroupFP32Coder Init error";
    return RET_ERROR;
  }
  // Inputs are: data, scale (gamma) and offset (beta).
  MS_CHECK_TRUE(input_tensors_.size() == DIMENSION_3D, "inputs size is not equal to three");
  Tensor *scale_tensor = input_tensors_.at(kWeightIndex);
  Tensor *offset_tensor = input_tensors_.at(kBiasIndex);
  MS_CHECK_PTR(scale_tensor);
  MS_CHECK_PTR(offset_tensor);
  // Pull the nnacl GroupNorm implementation into the generated project.
  Collect(context,
          {
            "nnacl/fp32/group_norm_fp32.h",
          },
          {
            "group_norm_fp32.c",
          });
  NNaclFp32Serializer code;
  // Serialize the parameter struct first; the emitted call then passes its
  // address ("&gn_parameter" refers to the struct emitted just above).
  code.CodeStruct("gn_parameter", *gn_parameter);
  code.CodeFunction("GroupNormFp32", input_tensor_, scale_tensor, offset_tensor, mean_, variance_, "&gn_parameter",
                    kDefaultTaskId, output_tensor_);
  MS_LOG(INFO) << "GroupNormFp32Code has been called";
  context->AppendCode(code.str());
  return lite::RET_OK;
}
|
||||
|
||||
REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_GroupNormFusion,
|
||||
CPUOpCoderCreator<GroupNormFP32Coder>)
|
||||
} // namespace mindspore::lite::micro::nnacl
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GROUPNORM_FP32_CODER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GROUPNORM_FP32_CODER_H_
|
||||
|
||||
#include <vector>
|
||||
#include "coder/opcoders/op_coder.h"
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
// Micro-coder that emits C source calling the nnacl GroupNormFp32 kernel.
class GroupNormFP32Coder final : public OperatorCoder {
 public:
  GroupNormFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
                     const Model::Node *node, size_t node_index, Target target)
      : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}

  ~GroupNormFP32Coder() override = default;

  // Allocates the mean/variance workspace buffers.
  int Prepare(CoderContext *const context) override;

  // Emits the GroupNormFp32 call into the generated code.
  int DoCode(CoderContext *const context) override;

 private:
  // Copies tensor geometry (unit/batch/channel) into the GroupNormParameter.
  int Init();

  // Workspace buffers (one float per group), owned by the coder allocator.
  float *mean_{nullptr};

  float *variance_{nullptr};
};
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GROUPNORM_FP32_CODER_H_
|
|
@ -161,4 +161,9 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const SpliceWrappe
|
|||
void NNaclFp32Serializer::CodeStruct(const std::string &name, const TransFuncStr trans_func_str) {
|
||||
CodeBaseStruct("TransFuncList", name, trans_func_str.in_func_, nullptr, nullptr, trans_func_str.out_func_);
|
||||
}
|
||||
|
||||
// Serializes a GroupNormParameter aggregate initializer into the generated
// code.  NOTE(review): the argument order is assumed to match the field
// declaration order in nnacl/group_norm_parameter.h — confirm when the
// struct changes.
void NNaclFp32Serializer::CodeStruct(const std::string &name, const GroupNormParameter &gn_param) {
  CodeBaseStruct("GroupNormParameter", name, gn_param.op_parameter_, gn_param.epsilon_, gn_param.num_groups_,
                 gn_param.channel_, gn_param.unit_, gn_param.batch_, gn_param.affine_);
}
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "nnacl/softmax_parameter.h"
|
||||
#include "nnacl/splice_parameter.h"
|
||||
#include "nnacl/lstm_parameter.h"
|
||||
#include "nnacl/group_norm_parameter.h"
|
||||
#include "wrapper/fp32/dequant_int8_to_fp32_wrapper.h"
|
||||
#include "nnacl/fp32/exp_fp32.h"
|
||||
#include "nnacl/fp32/strided_slice_fp32.h"
|
||||
|
@ -61,6 +62,7 @@ class NNaclFp32Serializer : public Serializer {
|
|||
void CodeStruct(const std::string &name, const ArithmeticWrapperInfo &arithmetic_wrapper_info);
|
||||
void CodeStruct(const std::string &name, const SpliceWrapperParam &splice_param);
|
||||
void CodeStruct(const std::string &name, const TransFuncStr trans_func_str);
|
||||
void CodeStruct(const std::string &name, const GroupNormParameter &gn_param);
|
||||
};
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_OPCODERS_SERIALIZERS_NNACL_FP32_ERIALIZER_H_
|
||||
|
|
|
@ -0,0 +1,334 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#define USE_DEPRECATED_API
|
||||
#include "tools/optimizer/fusion/groupnorm_fusion.h"
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "ops/fusion/groupnorm_fusion.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "src/ops/ops_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
// Extracts the integer axis list from either a constant Parameter tensor or a
// ValueNode.  Returns RET_OK on success, RET_NOT_SUPPORT when the parameter
// carries no default value, RET_ERROR for any malformed input.
STATUS GetAxis(const BaseRef &n, std::vector<int> *axes) {
  MS_ASSERT(axes != nullptr);
  if (utils::isa<ParameterPtr>(n)) {
    auto axes_param = utils::cast<ParameterPtr>(n);
    if (!axes_param->has_default() || axes_param->default_param() == nullptr) {
      return lite::RET_NOT_SUPPORT;
    }
    auto axes_value = axes_param->default_param()->cast<tensor::TensorPtr>();
    if (axes_value == nullptr) {
      return lite::RET_ERROR;
    }
    if (axes_value->data_type() != kNumberTypeInt && axes_value->data_type() != kNumberTypeInt32) {
      MS_LOG(ERROR) << "reduce's axes should be integer, now is " << axes_value->data_type();
      return lite::RET_ERROR;
    }
    if (axes_value->data_c() == nullptr) {
      return lite::RET_ERROR;
    }
    // Only scalar or 1-D axis tensors are supported.
    if (axes_value->shape().size() > 1) {
      return lite::RET_ERROR;
    }
    axes->resize(1);
    if (!axes_value->shape().empty()) {
      MS_CHECK_GE(axes_value->shape()[0], 0, lite::RET_ERROR);
      axes->resize(static_cast<size_t>(axes_value->shape()[0]));
    }
    if (memcpy_s(axes->data(), axes->size() * sizeof(int), axes_value->data_c(), axes_value->Size()) != EOK) {
      // Fix: fail explicitly here.  The original fell out of this branch on
      // memcpy_s failure and only reached the generic RET_ERROR at the end of
      // the function by accident (a ParameterPtr is never a ValueNodePtr).
      return lite::RET_ERROR;
    }
    return lite::RET_OK;
  }
  if (utils::isa<ValueNodePtr>(n)) {
    auto axes_value_node = utils::cast<ValueNodePtr>(n);
    *axes = CastToInt(axes_value_node->value());
    return lite::RET_OK;
  }
  return lite::RET_ERROR;
}
|
||||
|
||||
bool IsReduceSumNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr &input_axes,
|
||||
std::vector<int> *axes) {
|
||||
MS_ASSERT(equiv != nullptr && input_prim != nullptr && input_axes != nullptr && axes != nullptr);
|
||||
auto reduce_value = utils::cast<AnfNodePtr>((*equiv)[input_prim]);
|
||||
MS_ASSERT(reduce_value != nullptr);
|
||||
auto mean2_primitive = ops::GetOperator<ops::ReduceFusion>(reduce_value);
|
||||
MS_CHECK_TRUE_RET(mean2_primitive != nullptr, false);
|
||||
auto mean2_primitive_c = mean2_primitive->GetPrim();
|
||||
if (mean2_primitive_c->GetAttr(ops::kMode) == nullptr || mean2_primitive->get_mode() != mindspore::Reduce_Sum) {
|
||||
return false;
|
||||
}
|
||||
if (GetAxis((*equiv)[input_axes], axes) != lite::RET_OK) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsReduceMeanNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr &input_axes,
|
||||
std::vector<int> *axes) {
|
||||
MS_ASSERT(equiv != nullptr && input_prim != nullptr && input_axes != nullptr && axes != nullptr);
|
||||
auto reduce_value = utils::cast<AnfNodePtr>((*equiv)[input_prim]);
|
||||
MS_ASSERT(reduce_value != nullptr);
|
||||
auto mean2_primitive = ops::GetOperator<ops::ReduceFusion>(reduce_value);
|
||||
MS_CHECK_TRUE_RET(mean2_primitive != nullptr, false);
|
||||
auto mean2_primitive_c = mean2_primitive->GetPrim();
|
||||
if (mean2_primitive_c->GetAttr(ops::kMode) == nullptr || mean2_primitive->get_mode() != mindspore::Reduce_Mean) {
|
||||
return false;
|
||||
}
|
||||
if (GetAxis((*equiv)[input_axes], axes) != lite::RET_OK) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool GroupNormFusion::Init() const {
|
||||
input_ = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(input_ != nullptr, false);
|
||||
mean1_ = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(mean1_ != nullptr, false);
|
||||
mean1_axis_ = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(mean1_axis_ != nullptr, false);
|
||||
sum1_ = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(sum1_ != nullptr, false);
|
||||
sum1_axis_ = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(sum1_axis_ != nullptr, false);
|
||||
reshape1_axis_ = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(reshape1_axis_ != nullptr, false);
|
||||
reshape2_axis_ = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(reshape2_axis_ != nullptr, false);
|
||||
gamma_ = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(gamma_ != nullptr, false);
|
||||
beta_ = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(beta_ != nullptr, false);
|
||||
epsilon_ = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(epsilon_ != nullptr, false);
|
||||
real_div_divider_ = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(real_div_divider_ != nullptr, false);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GroupNormFusion::CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *num_groups,
|
||||
float *epsilon, bool *affine) const {
|
||||
MS_ASSERT(equiv != nullptr);
|
||||
MS_ASSERT(epsilon != nullptr);
|
||||
MS_ASSERT(num_groups != nullptr);
|
||||
MS_ASSERT(epsilon != nullptr);
|
||||
MS_ASSERT(affine != nullptr);
|
||||
|
||||
auto input_node_dbg = utils::cast<AnfNodePtr>((*equiv)[input_]);
|
||||
// beta
|
||||
auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
|
||||
MS_ASSERT(beta_node != nullptr);
|
||||
if (!beta_node->isa<Parameter>()) {
|
||||
return false;
|
||||
}
|
||||
auto beta_param = beta_node->cast<ParameterPtr>()->default_param();
|
||||
MS_CHECK_TRUE_RET(beta_param != nullptr, false);
|
||||
auto beta_tensor = beta_param->cast<tensor::TensorPtr>();
|
||||
MS_CHECK_TRUE_RET(beta_tensor != nullptr, false);
|
||||
std::vector<int> beta_shape;
|
||||
std::transform(beta_tensor->shape().begin(), beta_tensor->shape().end(), std::back_inserter(beta_shape),
|
||||
[](int64_t val) { return static_cast<int>(val); });
|
||||
// gamma
|
||||
auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
|
||||
MS_ASSERT(gamma_node != nullptr);
|
||||
if (!gamma_node->isa<Parameter>()) {
|
||||
return false;
|
||||
}
|
||||
auto gamma_param = gamma_node->cast<ParameterPtr>()->default_param();
|
||||
MS_CHECK_TRUE_RET(gamma_param != nullptr, false);
|
||||
auto gamma_tensor = gamma_param->cast<tensor::TensorPtr>();
|
||||
MS_CHECK_TRUE_RET(gamma_tensor != nullptr, false);
|
||||
std::vector<int> gamma_shape;
|
||||
std::transform(gamma_tensor->shape().begin(), gamma_tensor->shape().end(), std::back_inserter(gamma_shape),
|
||||
[](int64_t val) { return static_cast<int>(val); });
|
||||
// epsilon
|
||||
auto epsilon_node = utils::cast<AnfNodePtr>((*equiv)[epsilon_]);
|
||||
MS_ASSERT(epsilon_node != nullptr);
|
||||
if (!epsilon_node->isa<ValueNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto epsilon_value_node = epsilon_node->cast<ValueNodePtr>();
|
||||
MS_CHECK_TRUE_RET(epsilon_value_node != nullptr, false);
|
||||
auto epsilon_value = epsilon_value_node->value();
|
||||
MS_CHECK_TRUE_RET(epsilon_value != nullptr, false);
|
||||
if (!epsilon_value->isa<tensor::Tensor>()) {
|
||||
std::cout << "CheckPattern:epsilon_value_node not tensor" << std::endl;
|
||||
return false;
|
||||
}
|
||||
auto epsilon_tensor = epsilon_value->cast<tensor::TensorPtr>();
|
||||
MS_CHECK_TRUE_RET(epsilon_tensor != nullptr, false);
|
||||
TypeId tensor_type = epsilon_tensor->Dtype()->type_id();
|
||||
if (!(tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
|
||||
std::cout << "CheckPattern:epsilon_value_node not float" << std::endl;
|
||||
|
||||
return false;
|
||||
}
|
||||
auto epsilon_shape = epsilon_tensor->shape();
|
||||
// sum1
|
||||
std::vector<int> sum1_axes;
|
||||
if (!IsReduceSumNode(equiv, sum1_, sum1_axis_, &sum1_axes)) {
|
||||
return false;
|
||||
}
|
||||
// mean1
|
||||
std::vector<int> mean1_axes;
|
||||
if (!IsReduceMeanNode(equiv, mean1_, mean1_axis_, &mean1_axes)) {
|
||||
return false;
|
||||
}
|
||||
auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
|
||||
MS_ASSERT(input_node != nullptr);
|
||||
if (!utils::isa<CNodePtr>(input_node)) {
|
||||
return false;
|
||||
}
|
||||
auto input_cnode = input_node->cast<CNodePtr>();
|
||||
if (mean1_axes != sum1_axes) {
|
||||
return false;
|
||||
}
|
||||
if (gamma_shape != beta_shape) {
|
||||
return false;
|
||||
}
|
||||
if (epsilon_shape.empty() || (epsilon_shape.size() == 1 && epsilon_shape[0] == 1)) {
|
||||
MS_CHECK_TRUE_RET(epsilon_tensor->data_c() != nullptr, false);
|
||||
auto epsilon_data = reinterpret_cast<float *>(epsilon_tensor->data_c());
|
||||
*epsilon = epsilon_data[0];
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
std::vector<int> reshape1_axes;
|
||||
if (GetAxis((*equiv)[reshape1_axis_], &reshape1_axes) != lite::RET_OK) {
|
||||
return false;
|
||||
}
|
||||
if (reshape1_axes.size() != C3NUM) {
|
||||
return false;
|
||||
}
|
||||
*num_groups = reshape1_axes.at(C1NUM);
|
||||
*affine = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Builds the fused GroupNormFusion CNode; its inputs are (input, gamma, beta).
// Returns nullptr on any allocation failure.
CNodePtr GroupNormFusion::CreateGroupNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int num_groups,
                                              float epsilon, bool affine) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(equiv != nullptr);

  // Naming fix: this creates a GroupNormFusion primitive, not a layer norm.
  auto group_norm_primitive = std::make_shared<ops::GroupNormFusion>();
  MS_CHECK_TRUE_RET(group_norm_primitive != nullptr, nullptr);
  // Fix: forward the caller-provided `affine` flag instead of hard-coding
  // true (the parameter was previously ignored).
  group_norm_primitive->Init(num_groups, epsilon, affine);
  auto group_norm_primitive_c = group_norm_primitive->GetPrim();
  MS_CHECK_TRUE_RET(group_norm_primitive_c != nullptr, nullptr);
  PrimitiveCPtr primitive_c = group_norm_primitive_c;

  auto value_node = NewValueNode(primitive_c);
  MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
  std::vector<AnfNodePtr> new_node_inputs = {value_node};
  auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
  MS_ASSERT(input_node != nullptr);
  new_node_inputs.push_back(input_node);
  auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
  MS_ASSERT(gamma_node != nullptr);
  new_node_inputs.push_back(gamma_node);
  auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
  MS_ASSERT(beta_node != nullptr);
  new_node_inputs.push_back(beta_node);
  auto new_node = func_graph->NewCNode(new_node_inputs);
  return new_node;
}
|
||||
|
||||
// Pattern-match callback: replaces the matched sub-graph (rooted at the final
// beta-add node) with a single GroupNormFusion CNode, or returns nullptr to
// leave the graph untouched.
const AnfNodePtr GroupNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
    MS_LOG(ERROR) << "input param is nullptr, do group norm fusion failed.";
    return nullptr;
  }
  if (!utils::isa<CNodePtr>(node)) {
    return nullptr;
  }
  auto root_cnode = node->cast<CNodePtr>();
  // Nodes marked for training must not be fused away.
  if (IsMarkedTrainOp(root_cnode)) {
    return nullptr;
  }
  // Attributes extracted from the matched pattern.
  float epsilon = 0.0f;
  int num_groups = 0;
  bool affine = true;
  if (!CheckPattern(func_graph, equiv, &num_groups, &epsilon, &affine)) {
    return nullptr;
  }
  auto gn_cnode = CreateGroupNormNode(func_graph, equiv, num_groups, epsilon, affine);
  if (gn_cnode == nullptr) {
    MS_LOG(DEBUG) << "create norm cnode failed";
    return nullptr;
  }
  // The fused node inherits the abstract (shape/type) of the replaced root.
  MS_CHECK_TRUE_RET(root_cnode->abstract() != nullptr, nullptr);
  gn_cnode->set_abstract(root_cnode->abstract()->Clone());
  gn_cnode->set_fullname_with_scope("group_norm_" + root_cnode->fullname_with_scope());
  MS_LOG(DEBUG) << "group_norm_ node:" << gn_cnode->fullname_with_scope() << " fusion success";
  return gn_cnode;
}
|
||||
|
||||
// Defines the decomposed GroupNorm sub-graph to match, bottom-up:
//   reshape1 -> mean1, sub1 = reshape1 - mean1 -> square -> sum1 -> realdiv1
//   -> add1(+epsilon) -> sqrt -> realdiv2 = sub1 / sqrt -> reshape2
//   -> mul(gamma) -> add(beta)
// The returned ref is the pattern root (the final beta add).
const BaseRef GroupNormFusion::DefinePattern() const {
  if (!Init()) {
    MS_LOG(ERROR) << "initial member failed.";
    return {};
  }

  // reshape1 regroups the input; per CheckPattern its target shape has three
  // elements (N, num_groups, -1).
  auto is_reshape1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
  MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
  VectorRef reshape_ref1 = VectorRef({is_reshape1, input_, reshape1_axis_});
  // per-group mean (mean1_ is checked to be ReduceMean in CheckPattern).
  VectorRef mean1_ref = VectorRef({mean1_, reshape_ref1, mean1_axis_});
  auto is_sub1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
  MS_CHECK_TRUE_RET(is_sub1 != nullptr, {});
  // centered input: x - mean
  VectorRef sub1_ref = VectorRef({is_sub1, reshape_ref1, mean1_ref});

  // variance path: sum((x - mean)^2) divided by the matched divider
  // (presumably the element count — verified only structurally here).
  auto is_sqare = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSquare>);
  MS_CHECK_TRUE_RET(is_sqare != nullptr, {});
  VectorRef square_ref = VectorRef({is_sqare, sub1_ref});
  VectorRef sum1_ref = VectorRef({sum1_, square_ref, sum1_axis_});
  auto is_realdiv1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRealDiv>);
  MS_CHECK_TRUE_RET(is_realdiv1 != nullptr, {});
  VectorRef realdiv1_ref = VectorRef({is_realdiv1, sum1_ref, real_div_divider_});
  // stddev = sqrt(variance + epsilon)
  auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
  VectorRef add1_ref = VectorRef({is_add1, realdiv1_ref, epsilon_});
  auto is_sqrt = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>);
  MS_CHECK_TRUE_RET(is_sqrt != nullptr, {});
  VectorRef sqrt_ref = VectorRef({is_sqrt, add1_ref});
  // normalized value: (x - mean) / stddev
  auto is_realdiv2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRealDiv>);
  MS_CHECK_TRUE_RET(is_realdiv2 != nullptr, {});
  VectorRef realdiv2_ref = VectorRef({is_realdiv2, sub1_ref, sqrt_ref});

  // reshape back to the original layout, then the affine transform:
  // out = norm * gamma + beta.
  auto is_reshape2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
  MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
  VectorRef reshape_ref2 = VectorRef({is_reshape2, realdiv2_ref, reshape2_axis_});
  auto is_mul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul1 != nullptr, {});
  VectorRef mul1_ref = VectorRef({is_mul1, reshape_ref2, gamma_});
  auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
  VectorRef add2_ref = VectorRef({is_add2, mul1_ref, beta_});
  return add2_ref;
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GROUPNORM_FUSION_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GROUPNORM_FUSION_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
/// fuse the decomposed group-norm sub-graph (reshape/mean/sub/.../mul/add) into one GroupNormFusion operator
|
||||
// Graph pass that matches the decomposed GroupNorm sub-graph and replaces it
// with a single GroupNormFusion CNode.
class GroupNormFusion : public PatternProcessPass {
 public:
  explicit GroupNormFusion(const std::string &name = "GroupNormFusion", bool multigraph = true)
      : PatternProcessPass(name, multigraph) {}

  ~GroupNormFusion() override = default;

 protected:
  // Allocates all pattern Vars below; called from DefinePattern().
  bool Init() const;

 private:
  // Replaces the matched root node with the fused CNode (or returns nullptr).
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
  // Validates the match and extracts num_groups/epsilon/affine.
  bool CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *num_groups, float *epsilon,
                    bool *affine) const;
  CNodePtr CreateGroupNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int num_groups, float epsilon,
                               bool affine) const;
  const BaseRef DefinePattern() const override;

 protected:
  // Pattern variables; mutable because they are (re)assigned from the const
  // DefinePattern()/Init() overrides.
  mutable VarPtr input_ = nullptr;
  mutable VarPtr mean1_ = nullptr;
  mutable VarPtr mean1_axis_ = nullptr;
  mutable VarPtr sum1_ = nullptr;
  mutable VarPtr sum1_axis_ = nullptr;
  mutable VarPtr gamma_ = nullptr;
  mutable VarPtr beta_ = nullptr;
  mutable VarPtr epsilon_ = nullptr;
  mutable VarPtr reshape1_axis_ = nullptr;
  mutable VarPtr reshape2_axis_ = nullptr;
  mutable VarPtr real_div_divider_ = nullptr;
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GROUPNORM_FUSION_H_
|
Loading…
Reference in New Issue