From 0e07b0d9e8219f2770732a68cbafb92773f15f51 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Sun, 27 Sep 2020 15:41:22 +0800 Subject: [PATCH] scale add relu and relu6 for fusion and calculating --- mindspore/lite/nnacl/fp32/scale.c | 144 ++++++++++++++++ mindspore/lite/nnacl/fp32/scale.h | 4 + mindspore/lite/nnacl/scale.h | 1 + mindspore/lite/schema/ops.fbs | 1 + mindspore/lite/src/ops/scale.cc | 8 +- mindspore/lite/src/ops/scale.h | 2 + mindspore/lite/src/populate_parameter.cc | 1 + .../lite/src/runtime/kernel/arm/fp32/scale.cc | 93 +++++----- .../lite/src/runtime/kernel/arm/fp32/scale.h | 2 +- .../kernel/arm/fp32/scale_fp32_tests.cc | 160 ++++++++++++++++++ .../fusion/mul_add_fusion_pass.cc | 21 ++- 11 files changed, 387 insertions(+), 50 deletions(-) create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/scale_fp32_tests.cc diff --git a/mindspore/lite/nnacl/fp32/scale.c b/mindspore/lite/nnacl/fp32/scale.c index 0806999f801..99e17f0332b 100644 --- a/mindspore/lite/nnacl/fp32/scale.c +++ b/mindspore/lite/nnacl/fp32/scale.c @@ -77,3 +77,147 @@ void DoScale(float *in_data, float *out_data, float *scale, float *offset, int t scale_param->inner_size_); } } + +void ScaleInnerRelu(float *in_data, float *out_data, float *scale, float *offset, int outer_start, int outer_end, + int axis_size, int inner_size) { +#ifdef ENABLE_ARM64 + float32x4_t zeros = {0, 0, 0, 0}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_ARM64 + for (; in_index < inner_size - 4; in_index += 4) { + int in_offset = axis_offset + in_index; + float32x4_t data = vld1q_f32(in_data + in_offset); + float32x4_t scale_4 = vdupq_n_f32(scale[i]); + float32x4_t offset_4 = vdupq_n_f32(offset[i]); + float32x4_t tmp = vfmaq_f32(offset_4, data, scale_4); + float32x4_t result = vmaxq_f32(tmp, zeros); + vst1q_f32(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + float tmp = in_data[in_offset] * scale[i] + offset[i]; + out_data[in_offset] = tmp > 0.0f ? tmp : 0.0f; + } + } + } +} + +void ScaleAxisRelu(float *in_data, float *out_data, float *scale, float *offset, int outer_start, int outer_end, + int axis_size) { +#ifdef ENABLE_ARM64 + float32x4_t zeros = {0, 0, 0, 0}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_ARM64 + for (; index < axis_size - 4; index += 4) { + int in_offset = out_offset + index; + float32x4_t data = vld1q_f32(in_data + in_offset); + float32x4_t scale_4 = vld1q_f32(scale + index); + float32x4_t offset_4 = vld1q_f32(offset + index); + float32x4_t tmp = vfmaq_f32(offset_4, data, scale_4); + float32x4_t result = vmaxq_f32(tmp, zeros); + vst1q_f32(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + float tmp = in_data[in_offset] * scale[index] + offset[index]; + out_data[in_offset] = tmp > 0.0f ? tmp : 0.0f; + } + } +} + +void DoScaleRelu(float *in_data, float *out_data, float *scale, float *offset, int task_id, + ScaleParameter *scale_param) { + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + ScaleAxisRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + ScaleInnerRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} + +void ScaleInnerRelu6(float *in_data, float *out_data, float *scale, float *offset, int outer_start, int outer_end, + int axis_size, int inner_size) { +#ifdef ENABLE_ARM64 + float32x4_t zeros = {0, 0, 0, 0}; + float32x4_t bounds = {6, 6, 6, 6}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_ARM64 + for (; in_index < inner_size - 4; in_index += 4) { + int in_offset = axis_offset + in_index; + float32x4_t data = vld1q_f32(in_data + in_offset); + float32x4_t scale_4 = vdupq_n_f32(scale[i]); + float32x4_t offset_4 = vdupq_n_f32(offset[i]); + float32x4_t tmp = vfmaq_f32(offset_4, data, scale_4); + float32x4_t result = vminq_f32(vmaxq_f32(tmp, zeros), bounds); + vst1q_f32(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + float tmp = in_data[in_offset] * scale[i] + offset[i]; + out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); + } + } + } +} + +void ScaleAxisRelu6(float *in_data, float *out_data, float *scale, float *offset, int outer_start, int outer_end, + int axis_size) { +#ifdef ENABLE_ARM64 + float32x4_t zeros = {0, 0, 0, 0}; + float32x4_t bounds = {6, 6, 6, 6}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_ARM64 + for (; index < axis_size - 4; index += 4) { + int in_offset = out_offset + index; + float32x4_t data = vld1q_f32(in_data + in_offset); + float32x4_t scale_4 = vld1q_f32(scale + index); + float32x4_t offset_4 = vld1q_f32(offset + index); + float32x4_t tmp = vfmaq_f32(offset_4, data, scale_4); + float32x4_t result = vminq_f32(vmaxq_f32(tmp, zeros), bounds); + vst1q_f32(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + float tmp = in_data[in_offset] * scale[index] + offset[index]; + out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); + } + } +} + +void DoScaleRelu6(float *in_data, float *out_data, float *scale, float *offset, int task_id, + ScaleParameter *scale_param) { + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + ScaleAxisRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + ScaleInnerRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} diff --git a/mindspore/lite/nnacl/fp32/scale.h b/mindspore/lite/nnacl/fp32/scale.h index 63ba1dd200b..f1474e421c7 100644 --- a/mindspore/lite/nnacl/fp32/scale.h +++ b/mindspore/lite/nnacl/fp32/scale.h @@ -23,6 +23,10 @@ extern "C" { #endif void DoScale(float *in_data, float *out_data, float *scale, float *offset, int task_id, ScaleParameter *scale_param); +void DoScaleRelu(float *in_data, float *out_data, float *scale, float *offset, int task_id, + ScaleParameter *scale_param); +void DoScaleRelu6(float *in_data, float *out_data, float *scale, float *offset, int task_id, + ScaleParameter *scale_param); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/scale.h b/mindspore/lite/nnacl/scale.h index 346e6c36d77..fb9c881a93f 100644 --- a/mindspore/lite/nnacl/scale.h +++ b/mindspore/lite/nnacl/scale.h @@ -33,6 +33,7 @@ typedef struct ScaleParameter { int scale_zp_; int offset_zp_; int output_zp_; + int activation_type_; } ScaleParameter; #endif // MINDSPORE_LITE_NNACL_SCALE_H_ diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 75fbb7f8443..874c8f399f2 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -416,6 +416,7 @@ table BNGradInput { } table Scale { axis: int; + activationType: ActivationType = 0; } table Eltwise { diff --git a/mindspore/lite/src/ops/scale.cc b/mindspore/lite/src/ops/scale.cc index 53ed368783f..41256f7a779 100644 --- a/mindspore/lite/src/ops/scale.cc +++ b/mindspore/lite/src/ops/scale.cc @@ -20,12 +20,16 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE int Scale::GetAxis() const { return this->primitive_->value.AsScale()->axis; } - void Scale::SetAxis(int axis) { this->primitive_->value.AsScale()->axis = axis; } +int Scale::GetActivationType() const { return this->primitive_->value.AsScale()->activationType; } +void Scale::SetActivationType(int activation_type) { + this->primitive_->value.AsScale()->activationType = (schema::ActivationType)activation_type; +} #else int Scale::GetAxis() const { return this->primitive_->value_as_Scale()->axis(); } +int Scale::GetActivationType() const { return this->primitive_->value_as_Scale()->activationType(); } int Scale::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); MS_ASSERT(nullptr != fbb); @@ -34,7 +38,7 @@ int Scale::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: MS_LOG(ERROR) << "value_as_Scale return nullptr"; return RET_ERROR; } - auto val_offset = schema::CreateScale(*fbb, attr->axis()); + auto val_offset = schema::CreateScale(*fbb, attr->axis(), attr->activationType()); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Scale, val_offset.o); fbb->Finish(prim_offset); return RET_OK; diff --git a/mindspore/lite/src/ops/scale.h b/mindspore/lite/src/ops/scale.h index 6ebe6f8b8ae..f7a7ca49857 100644 --- a/mindspore/lite/src/ops/scale.h +++ b/mindspore/lite/src/ops/scale.h @@ -32,6 +32,7 @@ class Scale : public PrimitiveC { Scale() = default; explicit Scale(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} void SetAxis(int axis); + void SetActivationType(int activation_type); #else Scale() = default; @@ -39,6 +40,7 @@ class Scale : public PrimitiveC { int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetAxis() const; + int GetActivationType() const; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 70ab9b4310e..2297436ca91 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -943,6 +943,7 @@ OpParameter *PopulateScaleParameter(const mindspore::lite::PrimitiveC *primitive scale_param->op_parameter_.type_ = primitive->Type(); auto param = reinterpret_cast(const_cast(primitive)); scale_param->axis_ = param->GetAxis(); + scale_param->activation_type_ = param->GetActivationType(); return reinterpret_cast(scale_param); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc index 58161c8a478..3663a7ab638 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc @@ -35,52 +35,56 @@ ScaleCPUKernel::~ScaleCPUKernel() { scale_ = nullptr; } } - if (offset_ != nullptr) { - free(offset_); - offset_ = nullptr; + if (scale_param_->const_offset_) { + if (offset_ != nullptr) { + free(offset_); + offset_ = nullptr; + } } } int ScaleCPUKernel::InitScaleOffset() { auto scale_tensor = in_tensors_.at(1); - float *scale_ptr = reinterpret_cast(in_tensors_.at(1)->data_c()); - if (scale_ptr != nullptr) { + if (reinterpret_cast(scale_tensor->data_c()) != nullptr) { scale_param_->const_scale_ = true; - if (scale_ != nullptr) { - free(scale_); - scale_ = nullptr; - } scale_ = reinterpret_cast(malloc(scale_tensor->ElementsNum() * sizeof(float))); if (scale_ == nullptr) { MS_LOG(ERROR) << "Malloc buffer failed."; return RET_ERROR; } - memcpy(scale_, scale_ptr, scale_tensor->ElementsNum() * sizeof(float)); + memcpy(scale_, scale_tensor->data_c(), scale_tensor->ElementsNum() * sizeof(float)); } else { scale_param_->const_scale_ = false; scale_ = nullptr; } - if (offset_ != nullptr) { - free(offset_); + if (in_tensors_.size() == 2) { + scale_param_->const_offset_ = true; + offset_ = reinterpret_cast(malloc(scale_tensor->ElementsNum() * sizeof(float))); + if (offset_ == nullptr) { + MS_LOG(ERROR) << "Malloc data failed"; + return RET_ERROR; + } + memset(offset_, 0, scale_tensor->ElementsNum() * sizeof(float)); + } else if (in_tensors_.size() == 3 && reinterpret_cast(in_tensors_.at(2)->data_c()) != nullptr) { + scale_param_->const_offset_ = true; + auto offset_tensor = in_tensors_.at(2); + MS_ASSERT(scale_tensor->ElementsNum() == offset_tensor->ElementsNum()); + offset_ = reinterpret_cast(malloc(offset_tensor->ElementsNum() * sizeof(float))); + if (offset_ == nullptr) { + MS_LOG(ERROR) << "Malloc data failed"; + return RET_ERROR; + } + memcpy(offset_, offset_tensor->data_c(), offset_tensor->ElementsNum() * sizeof(float)); + } else { + scale_param_->const_offset_ = false; offset_ = nullptr; } - offset_ = reinterpret_cast(malloc(scale_param_->axis_size_ * sizeof(float))); - if (offset_ == nullptr) { - MS_LOG(ERROR) << "Malloc buffer failed."; - return RET_ERROR; - } - memset(offset_, 0, scale_param_->axis_size_ * sizeof(float)); - if (in_tensors_.size() == 3) { - auto offset_tensor = in_tensors_.at(2); - if (offset_tensor->data_c() != nullptr) { - memcpy(offset_, offset_tensor->data_c(), offset_tensor->ElementsNum() * sizeof(float)); - } - } + return RET_OK; } -int ScaleCPUKernel::InitParameter() { +int ScaleCPUKernel::CalculateParameter() { auto in_tensor = in_tensors_.at(0); auto in_shape = in_tensor->shape(); auto scale_tensor = in_tensors_.at(1); @@ -118,32 +122,44 @@ int ScaleCPUKernel::Init() { MS_LOG(ERROR) << "inputs to Scale operator should be 2 or 3, but " << in_tensors_.size() << " is given."; return RET_ERROR; } + auto ret = InitScaleOffset(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Scale fp32 InitScaleOffset failed."; + return RET_ERROR; + } if (!InferShapeDone()) { return RET_OK; } - ReSize(); return RET_OK; } int ScaleCPUKernel::ReSize() { - auto ret = InitParameter(); + auto ret = CalculateParameter(); if (ret != RET_OK) { MS_LOG(ERROR) << "Scale fp32 InitParameter failed."; return RET_ERROR; } - ret = InitScaleOffset(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Scale fp32 InitScaleOffset failed."; - return RET_ERROR; - } return RET_OK; } int ScaleCPUKernel::Scale(int task_id) { - DoScale(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_); + switch (scale_param_->activation_type_) { + case schema::ActivationType_RELU6: + DoScaleRelu6(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_); + break; + case schema::ActivationType_RELU: + DoScaleRelu(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_); + break; + case schema::ActivationType_NO_ACTIVATION: + DoScale(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_); + break; + default: + MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_; + return RET_ERROR; + } return RET_OK; } @@ -164,14 +180,15 @@ int ScaleCPUKernel::Run() { return ret; } auto in_tensor = in_tensors_.front(); - input_ptr_ = reinterpret_cast(in_tensor->MutableData()); - if (scale_ == nullptr) { + input_ptr_ = reinterpret_cast(in_tensor->data_c()); + if (!scale_param_->const_scale_) { auto scale_tensor = in_tensors_[1]; - scale_ = reinterpret_cast(scale_tensor->MutableData()); + scale_ = reinterpret_cast(scale_tensor->data_c()); } - if (offset_ == nullptr) { + if (!scale_param_->const_offset_) { + MS_ASSERT(in_tensors_.size() == 3); auto offset_tensor = in_tensors_.at(2); - memcpy(offset_, offset_tensor->data_c(), offset_tensor->ElementsNum() * sizeof(float)); + offset_ = reinterpret_cast(offset_tensor->data_c()); } auto out_tensor = out_tensors_.front(); output_ptr_ = reinterpret_cast(out_tensor->MutableData()); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h index a31e61d14b9..7fc15319213 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h @@ -36,7 +36,7 @@ class ScaleCPUKernel : public LiteKernel { int Init() override; int ReSize() override; int Run() override; - int InitParameter(); + int CalculateParameter(); int InitScaleOffset(); int Scale(int task_id); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/scale_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/scale_fp32_tests.cc new file mode 100644 index 00000000000..d1763353e60 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/scale_fp32_tests.cc @@ -0,0 +1,160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "mindspore/lite/src/lite_kernel.h" +#include "mindspore/lite/src/tensor.h" +#include "common/common_test.h" +#include "nnacl/pad_parameter.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/schema/ops_generated.h" +#include "nnacl/fp32/scale.h" + +using mindspore::schema::ActivationType; +using mindspore::schema::ActivationType_NO_ACTIVATION; +using mindspore::schema::ActivationType_RELU; +using mindspore::schema::ActivationType_RELU6; +using mindspore::schema::Format_NHWC; +namespace mindspore { + +class TestScaleFp32 : public mindspore::CommonTest { + public: + TestScaleFp32() = default; + void Prepare(const std::vector &input_shape, const std::vector &scale_shape, + const std::vector &offset_shape, const std::vector &output_shape, float *input_data, + float *scale_data, float *offset_data, float *output_data, int axis, ActivationType act_type, + const int thread_num); + + void TearDown() override; + + public: + float err_tol = 1e-5; + lite::Tensor in_tensor_; + lite::Tensor scale_tensor_; + lite::Tensor offset_tensor_; + lite::Tensor out_tensor_; + ScaleParameter param_; + std::vector inputs_{&in_tensor_, &scale_tensor_, &offset_tensor_}; + std::vector outputs_{&out_tensor_}; + kernel::KernelKey desc_ = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Scale}; + lite::InnerContext ctx_ = lite::InnerContext(); + kernel::KernelCreator creator_ = nullptr; + kernel::LiteKernel *kernel_ = nullptr; +}; + +void TestScaleFp32::TearDown() { + in_tensor_.SetData(nullptr); + scale_tensor_.SetData(nullptr); + offset_tensor_.SetData(nullptr); + out_tensor_.SetData(nullptr); +} + +void TestScaleFp32::Prepare(const std::vector &input_shape, const std::vector &scale_shape, + const std::vector &offset_shape, const std::vector &output_shape, + float *input_data, float *scale_data, float *offset_data, float *output_data, int axis, + ActivationType act_type, const int thread_num) { + in_tensor_.set_data_type(kNumberTypeFloat32); + in_tensor_.SetFormat(Format_NHWC); + in_tensor_.set_shape(input_shape); + scale_tensor_.set_data_type(kNumberTypeFloat32); + scale_tensor_.SetFormat(Format_NHWC); + scale_tensor_.set_shape(scale_shape); + offset_tensor_.set_data_type(kNumberTypeFloat32); + offset_tensor_.SetFormat(Format_NHWC); + offset_tensor_.set_shape(offset_shape); + out_tensor_.set_data_type(kNumberTypeFloat32); + out_tensor_.set_shape(output_shape); + + in_tensor_.SetData(input_data); + scale_tensor_.SetData(scale_data); + offset_tensor_.SetData(offset_data); + out_tensor_.SetData(output_data); + + param_.activation_type_ = act_type; + param_.axis_ = axis; + ctx_ = lite::InnerContext(); + ctx_.thread_num_ = thread_num; + ctx_.Init(); + creator_ = lite::KernelRegistry::GetInstance()->GetCreator(desc_); + ASSERT_NE(creator_, nullptr); + kernel_ = creator_(inputs_, outputs_, reinterpret_cast(¶m_), &ctx_, desc_, nullptr); + ASSERT_NE(kernel_, nullptr); +} + +TEST_F(TestScaleFp32, ScaleNoAct) { + std::vector input_shape{1, 2, 2, 3}; + std::vector scale_shape{3}; + std::vector offset_shape{3}; + std::vector output_shape{1, 2, 2, 3}; + float in_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0}; + float scale_data[3] = {1.0, 2.0, 3.0}; + float offset_data[3] = {1.0, 1.0, 1.0}; + float out_data[12] = {0}; + int axis = -1; + int thread_num = 2; + Prepare(input_shape, scale_shape, offset_shape, output_shape, in_data, scale_data, offset_data, out_data, axis, + ActivationType_NO_ACTIVATION, thread_num); + + auto ret = kernel_->Run(); + EXPECT_EQ(0, ret); + + std::vector expect{1.0, 3.0, 7.0, 4.0, 9.0, 16.0, 7.0, 15.0, 25.0, 10.0, 21.0, 34.0}; + + CompareOutputData(out_data, expect.data(), 12, err_tol); +} + +TEST_F(TestScaleFp32, ScaleRelu) { + std::vector input_shape{1, 2, 2, 3}; + std::vector scale_shape{3}; + std::vector offset_shape{3}; + std::vector output_shape{1, 2, 2, 3}; + float in_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0}; + float scale_data[3] = {1.0, 2.0, 3.0}; + float offset_data[3] = {-5.0, -5.0, -5.0}; + float out_data[12] = {0}; + int axis = -1; + int thread_num = 2; + Prepare(input_shape, scale_shape, offset_shape, output_shape, in_data, scale_data, offset_data, out_data, axis, + ActivationType_RELU, thread_num); + + auto ret = kernel_->Run(); + EXPECT_EQ(0, ret); + + std::vector expect{0.0, 0.0, 1.0, 0.0, 3.0, 10.0, 1.0, 9.0, 19.0, 4.0, 15.0, 28.0}; + + CompareOutputData(out_data, expect.data(), 12, err_tol); +} +TEST_F(TestScaleFp32, ScaleRelu6) { + std::vector input_shape{1, 2, 2, 3}; + std::vector scale_shape{3}; + std::vector offset_shape{3}; + std::vector output_shape{1, 2, 2, 3}; + float in_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0}; + float scale_data[3] = {1.0, 2.0, 3.0}; + float offset_data[3] = {-5.0, -5.0, -5.0}; + float out_data[12] = {0}; + int axis = -1; + int thread_num = 2; + Prepare(input_shape, scale_shape, offset_shape, output_shape, in_data, scale_data, offset_data, out_data, axis, + ActivationType_RELU6, thread_num); + + auto ret = kernel_->Run(); + EXPECT_EQ(0, ret); + + std::vector expect{0.0, 0.0, 1.0, 0.0, 3.0, 6.0, 1.0, 6.0, 6.0, 4.0, 6.0, 6.0}; + + CompareOutputData(out_data, expect.data(), 12, err_tol); +} +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc index f3a406e515b..10f3467c497 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc @@ -144,23 +144,26 @@ STATUS MulAddFusionPass::AddNewScaleNode(MetaGraphT *graph, const std::unique_pt // NHWC int shape_size = graph->allTensors.at(addBiasIndex)->dims.size(); scaleParam->axis = 0 - shape_size; - mulNode->primitive->value.value = scaleParam.release(); mulNode->inputIndex.push_back(addBiasIndex); - if (addNode->primitive->value.AsAdd()->activationType != ActivationType_NO_ACTIVATION) { + auto activationType = addNode->primitive->value.AsAdd()->activationType; + if (activationType == ActivationType_RELU || activationType == ActivationType_RELU6 || + activationType == ActivationType_NO_ACTIVATION) { + // delete addnode + scaleParam->activationType = activationType; + auto status = IsolateOneWayNode(graph, addNode); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed"; + return status; + } + } else { // repace addnode as activation std::unique_ptr activationParam(new ActivationT()); activationParam->type = addNode->primitive->value.AsAdd()->activationType; addNode->primitive->value.type = schema::PrimitiveType_Activation; addNode->primitive->value.value = activationParam.release(); addNode->inputIndex.pop_back(); - return RET_OK; - } - // delete addnode - auto status = IsolateOneWayNode(graph, addNode); - if (status != RET_OK) { - MS_LOG(ERROR) << "IsolateOneWayNode failed"; - return status; } + mulNode->primitive->value.value = scaleParam.release(); return RET_OK; } } // namespace lite