convert onnx instance norm to layer norm

This commit is contained in:
xuanyue 2020-12-14 19:52:17 +08:00
parent c1dd0b8e3d
commit a9655ab51c
12 changed files with 117 additions and 36 deletions

View File

@ -19,11 +19,12 @@
#include "nnacl/op_base.h"
int LayerNorm(size_t outer_size, size_t inner_size, const float *src_data, const float *gamma_data,
const float *beta_data, bool affine, float epsilon, float *dst_data, size_t task_id, size_t thread_num) {
const float *beta_data, enum ElementwiseMode elementwise_mode, float epsilon, float *dst_data,
size_t task_id, size_t thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
if (affine && (gamma_data == NULL || beta_data == NULL)) {
if (elementwise_mode != 0 && (gamma_data == NULL || beta_data == NULL)) {
return NNACL_NULL_PTR;
}
@ -63,7 +64,7 @@ int LayerNorm(size_t outer_size, size_t inner_size, const float *src_data, const
#ifdef ENABLE_NEON
float32x4_t meanv = vdupq_n_f32(mean);
float32x4_t denov = vdupq_n_f32(deno);
if (affine) {
if (elementwise_mode != 0) {
for (; index < inner_size - C8NUM; index += C8NUM) {
float32x4_t srcv1 = vld1q_f32(src + index);
float32x4_t srcv2 = vld1q_f32(src + index + 4);
@ -71,6 +72,14 @@ int LayerNorm(size_t outer_size, size_t inner_size, const float *src_data, const
float32x4_t outv2 = vsubq_f32(srcv2, meanv);
outv1 = vmulq_f32(outv1, denov);
outv2 = vmulq_f32(outv2, denov);
if (elementwise_mode == 1) {
float32x4_t gammav1 = vdupq_n_f32(gamma_data[j]);
float32x4_t betav1 = vdupq_n_f32(beta_data[j]);
outv1 = vmulq_f32(outv1, gammav1);
outv2 = vmulq_f32(outv2, gammav1);
outv1 = vaddq_f32(outv1, betav1);
outv2 = vaddq_f32(outv2, betav1);
} else {
float32x4_t gammav1 = vld1q_f32(gamma_data + index);
float32x4_t gammav2 = vld1q_f32(gamma_data + index + 4);
float32x4_t betav1 = vld1q_f32(beta_data + index);
@ -79,6 +88,7 @@ int LayerNorm(size_t outer_size, size_t inner_size, const float *src_data, const
outv2 = vmulq_f32(outv2, gammav2);
outv1 = vaddq_f32(outv1, betav1);
outv2 = vaddq_f32(outv2, betav2);
}
vst1q_f32(dst + index, outv1);
vst1q_f32(dst + index + 4, outv2);
}
@ -97,7 +107,9 @@ int LayerNorm(size_t outer_size, size_t inner_size, const float *src_data, const
#endif
for (; index < inner_size; index++) {
dst[index] = (src[index] - mean) * deno;
if (affine) {
if (elementwise_mode == 1) {
dst[index] = dst[index] * gamma_data[j] + beta_data[j];
} else if (elementwise_mode == 2) {
dst[index] = dst[index] * gamma_data[index] + beta_data[index];
}
}

View File

@ -24,7 +24,8 @@ extern "C" {
#endif
int LayerNorm(size_t outer_size, size_t inner_size, const float *src_data, const float *gamma_data,
const float *beta_data, bool affine, float epsilon, float *dst_data, size_t task_id, size_t thread_num);
const float *beta_data, enum ElementwiseMode elementwise_mode, float epsilon, float *dst_data,
size_t task_id, size_t thread_num);
#ifdef __cplusplus
}
#endif

View File

@ -22,12 +22,13 @@
*
* */
int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data,
bool affine, int outer_size, int inner_size, LayerNormQuantArg *quant, float epsilon) {
enum ElementwiseMode elementwise_mode, int outer_size, int inner_size, LayerNormQuantArg *quant,
float epsilon) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
if (affine && (gamma_data == NULL || beta_data == NULL)) {
if (elementwise_mode != 0 && (gamma_data == NULL || beta_data == NULL)) {
return NNACL_NULL_PTR;
}
@ -47,7 +48,9 @@ int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *
for (int i = 0; i < inner_size; i++) {
float fp32_src = (src[i] - quant->in_zp_) * quant->in_scale_;
float fp32_dst = (fp32_src - mean) * deno;
if (affine) {
if (elementwise_mode == 1) {
fp32_dst = fp32_dst * gamma_data[out_index] + beta_data[out_index];
} else if (elementwise_mode == 2) {
fp32_dst = fp32_dst * gamma_data[i] + beta_data[i];
}
int32_t int32_dst = (int32_t)round(fp32_dst * 1.0 / quant->out_scale_ + quant->out_zp_);

View File

@ -25,7 +25,8 @@ extern "C" {
#endif
int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data,
bool affine, int outer_size, int inner_size, LayerNormQuantArg *quant, float epsilon);
enum ElementwiseMode elementwise_mode, int outer_size, int inner_size, LayerNormQuantArg *quant,
float epsilon);
#ifdef __cplusplus
}

View File

@ -19,11 +19,12 @@
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
enum ElementwiseMode { ELEMENTWISE_NOT = 0, ELEMENTWISE_PER_CHANNEL = 1, ELEMENTWISE_PER_NUM = 2 };
typedef struct LayerNormParameter {
// Primitive parameter
OpParameter op_parameter_;
float epsilon_;
bool elementwise_affine_;
enum ElementwiseMode elementwise_mode_;
// shape correlative
int *normalized_shape_;
int normalized_dims_;

View File

@ -89,22 +89,33 @@ int LayerNorm::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite:
MS_LOG(INFO) << "input tensor amount error";
return RET_INPUT_TENSOR_ERROR;
}
auto input_shape = input->shape();
auto normalized_shape = GetNormalizedShape();
if (normalized_shape.size() > input_shape.size() || normalized_shape.size() == 0) {
MS_LOG(INFO) << "normalized_shape attr invalid";
return RET_PARAM_INVALID;
}
size_t first_index = input_shape.size() - normalized_shape.size();
for (size_t i = first_index; i < input_shape.size(); ++i) {
if (input_shape.at(i) != normalized_shape.at(i - first_index)) {
MS_LOG(INFO) << "normalized_shape attr invalid";
return RET_PARAM_INVALID;
}
}
if (!infer_flag()) {
return RET_INFER_INVALID;
}
auto input_shape = input->shape();
normlized_shape_ = GetNormalizedShape();
elementwise_mode_ = GetElementwiseAffine() ? 2 : 0;
if (normlized_shape_.size() > input_shape.size()) {
MS_LOG(INFO) << "normalized_shape attr invalid";
return RET_PARAM_INVALID;
}
if (normlized_shape_.empty()) {
// instance norm -> layernorm
if (input->format() == schema::Format_NCHW) {
normlized_shape_.insert(normlized_shape_.begin(), input_shape.begin() + 2, input_shape.end());
elementwise_mode_ = 1;
} else {
MS_LOG(INFO) << "normalized_shape attr invalid";
return RET_PARAM_INVALID;
}
}
size_t first_index = input_shape.size() - normlized_shape_.size();
for (size_t i = first_index; i < input_shape.size(); ++i) {
if (input_shape.at(i) != normlized_shape_.at(i - first_index)) {
MS_LOG(INFO) << "normalized_shape attr invalid";
return RET_PARAM_INVALID;
}
}
output->set_shape(input_shape);
return RET_OK;

View File

@ -42,6 +42,12 @@ class LayerNorm : public PrimitiveC {
std::vector<int> GetNormalizedShape() const;
float GetEpsilon() const;
bool GetElementwiseAffine() const;
std::vector<int> normlized_shape() const { return normlized_shape_; }
int elementwise_mode() const { return elementwise_mode_; }
protected:
std::vector<int> normlized_shape_;
int elementwise_mode_ = 0;
};
} // namespace lite
} // namespace mindspore

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "src/ops/populate/layer_norm_populate.h"
#include "nnacl/layer_norm_parameter.h"
#include <cstdint>
#include "src/ops/layer_norm.h"
@ -31,7 +32,7 @@ OpParameter *PopulateLayerNormParameter(const mindspore::lite::PrimitiveC *primi
memset(layer_norm_parameter, 0, sizeof(LayerNormParameter));
layer_norm_parameter->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::LayerNorm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
auto normalized_shape = param->GetNormalizedShape();
auto normalized_shape = param->normlized_shape();
layer_norm_parameter->normalized_dims_ = normalized_shape.size();
if (normalized_shape.size() > SIZE_MAX / sizeof(int)) {
MS_LOG(ERROR) << "normalized_shape size too big";
@ -48,7 +49,7 @@ OpParameter *PopulateLayerNormParameter(const mindspore::lite::PrimitiveC *primi
layer_norm_parameter->normalized_shape_[i] = normalized_shape[i];
}
layer_norm_parameter->epsilon_ = param->GetEpsilon();
layer_norm_parameter->elementwise_affine_ = param->GetElementwiseAffine();
layer_norm_parameter->elementwise_mode_ = static_cast<ElementwiseMode>(param->elementwise_mode());
return reinterpret_cast<OpParameter *>(layer_norm_parameter);
}

View File

@ -0,0 +1,28 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_POPULATE_STRIDED_LAYER_NORM_POPULATE_H_
#define MINDSPORE_LITE_SRC_OPS_POPULATE_STRIDED_LAYER_NORM_POPULATE_H_
#include "src/ops/arithmetic.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateLayerNormParameter(const mindspore::lite::PrimitiveC *primitive);
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_OPS_POPULATE_STRIDED_LAYER_NORM_POPULATE_H_

View File

@ -18,6 +18,7 @@
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/ops/populate/layer_norm_populate.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@ -34,6 +35,13 @@ int LayerNormCPUKernel::Init() {
}
int LayerNormCPUKernel::ReSize() {
if (op_parameter_ != nullptr) {
free(op_parameter_);
op_parameter_ = nullptr;
}
op_parameter_ = PopulateLayerNormParameter(primitive_);
op_parameter_->thread_num_ = context_->thread_num_;
param_ = reinterpret_cast<LayerNormParameter *>(op_parameter_);
auto shape = in_tensors_.front()->shape();
outer_size_ = 1;
inner_size_ = 1;
@ -48,7 +56,7 @@ int LayerNormCPUKernel::ReSize() {
}
int LayerNormCPUKernel::DoLayerNorm(int thread_id) {
int ret = LayerNorm(outer_size_, inner_size_, src_data_, gamma_data_, beta_data_, param_->elementwise_affine_,
int ret = LayerNorm(outer_size_, inner_size_, src_data_, gamma_data_, beta_data_, param_->elementwise_mode_,
param_->epsilon_, dst_data_, thread_id, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoLayerNorm error error_code[" << ret << "]";
@ -69,7 +77,7 @@ int LayerNormRun(void *cdata, int task_id) {
int LayerNormCPUKernel::Run() {
src_data_ = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
if (param_->elementwise_affine_) {
if (param_->elementwise_mode_ != 0) {
gamma_data_ = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
beta_data_ = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
}

View File

@ -15,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/int8/layer_norm_int8.h"
#include "src/runtime/runtime_api.h"
#include "src/ops/populate/layer_norm_populate.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
@ -23,11 +24,11 @@ using mindspore::schema::PrimitiveType_LayerNorm;
namespace mindspore::kernel {
LayerNormInt8CPUKernel::~LayerNormInt8CPUKernel() {
if (param_->elementwise_affine_ && gamma_ptr_ != nullptr) {
if (param_->elementwise_mode_ != 0 && gamma_ptr_ != nullptr) {
free(gamma_ptr_);
gamma_ptr_ = nullptr;
}
if (param_->elementwise_affine_ && beta_ptr_ != nullptr) {
if (param_->elementwise_mode_ != 0 && beta_ptr_ != nullptr) {
free(beta_ptr_);
beta_ptr_ = nullptr;
}
@ -43,7 +44,7 @@ int LayerNormInt8CPUKernel::SetQuantArgs() {
quant_param_.out_zp_ = output->quant_params().front().zeroPoint;
quant_param_.out_scale_ = output->quant_params().front().scale;
if (param_->elementwise_affine_) {
if (param_->elementwise_mode_ != 0) {
lite::Tensor *gamma_tensor = in_tensors_.at(1);
lite::Tensor *beta_tensor = in_tensors_.at(2);
@ -84,6 +85,13 @@ int LayerNormInt8CPUKernel::Init() {
}
int LayerNormInt8CPUKernel::ReSize() {
if (op_parameter_ != nullptr) {
free(op_parameter_);
op_parameter_ = nullptr;
}
op_parameter_ = PopulateLayerNormParameter(primitive_);
op_parameter_->thread_num_ = context_->thread_num_;
param_ = reinterpret_cast<LayerNormParameter *>(op_parameter_);
auto shape = in_tensors_.front()->shape();
outer_size_ = 1;
inner_size_ = 1;
@ -116,8 +124,8 @@ int LayerNormInt8CPUKernel::DoExecute(int task_id) {
const int8_t *thread_src = src_ptr_ + task_id * param_->thread_outsize_ * inner_size_;
int8_t *thread_dst = dst_ptr_ + task_id * param_->thread_outsize_ * inner_size_;
LayerNormInt8(thread_src, gamma_ptr_, beta_ptr_, thread_dst, param_->elementwise_affine_, current_out_size,
inner_size_, &quant_param_, param_->epsilon_);
LayerNormInt8(thread_src, gamma_ptr_, beta_ptr_, thread_dst, param_->elementwise_mode_, current_out_size, inner_size_,
&quant_param_, param_->epsilon_);
return RET_OK;
}

View File

@ -22,7 +22,7 @@ namespace lite {
lite::PrimitiveC *OnnxInstanceNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx InstanceNormParser";
auto attr = std::make_unique<schema::InstanceNormT>();
auto attr = std::make_unique<schema::LayerNormT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
@ -39,7 +39,8 @@ lite::PrimitiveC *OnnxInstanceNormParser::ParseLitePrimitive(const onnx::GraphPr
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_InstanceNorm;
attr->elementwiseAffine = true;
primitive->value.type = schema::PrimitiveType_LayerNorm;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}