caffeprelu rename to prelu

This commit is contained in:
chenjianping 2020-08-28 16:28:49 +08:00
parent 29070d60a1
commit 34f21226a8
24 changed files with 71 additions and 243 deletions

View File

@ -80,7 +80,7 @@ union PrimitiveType {
Pad,
Maximum,
Minimum,
CaffePReLU,
PReLU,
LeakyReLU,
ArgMax,
ArgMin,
@ -126,7 +126,6 @@ union PrimitiveType {
Broadcast,
BroadcastTo,
Lrn,
Prelu,
ZerosLike,
TopK,
SpaceToDepth,

View File

@ -540,7 +540,7 @@ table MatMul {
transposeB : bool = false;
}
table CaffePReLU {
table PReLU {
channelShared : bool = false;
slope: [float];
}
@ -650,10 +650,6 @@ table Reduce {
mode: ReduceMode;
}
table Prelu {
slope: [float];
}
table Transpose {
perm: [int];
conjugate: bool = false;

View File

@ -14,20 +14,20 @@
* limitations under the License.
*/
#include "src/ops/caffe_p_relu.h"
#include "src/ops/p_relu.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool CaffePReLU::GetChannelShared() const { return this->primitive_->value.AsCaffePReLU()->channelShared; }
bool PReLU::GetChannelShared() const { return this->primitive_->value.AsPReLU()->channelShared; }
void CaffePReLU::SetChannelShared(bool channel_shared) {
this->primitive_->value.AsCaffePReLU()->channelShared = channel_shared;
void PReLU::SetChannelShared(bool channel_shared) {
this->primitive_->value.AsPReLU()->channelShared = channel_shared;
}
#else
bool CaffePReLU::GetChannelShared() const { return this->primitive_->value_as_CaffePReLU()->channelShared(); }
bool PReLU::GetChannelShared() const { return this->primitive_->value_as_PReLU()->channelShared(); }
#endif
} // namespace lite

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_
#define LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_
#ifndef LITE_MINDSPORE_LITE_C_OPS_P_RELU_H_
#define LITE_MINDSPORE_LITE_C_OPS_P_RELU_H_
#include <vector>
#include <set>
@ -26,21 +26,21 @@
namespace mindspore {
namespace lite {
class CaffePReLU : public Activation {
class PReLU : public Activation {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(CaffePReLU, Activation);
CaffePReLU() = default;
explicit CaffePReLU(schema::PrimitiveT *primitive) : Activation(primitive) {}
MS_DECLARE_PARENT(PReLU, Activation);
PReLU() = default;
explicit PReLU(schema::PrimitiveT *primitive) : Activation(primitive) {}
void SetChannelShared(bool channel_shared);
#else
explicit CaffePReLU(schema::Primitive *primitive) : Activation(primitive) {}
explicit PReLU(schema::Primitive *primitive) : Activation(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_CaffePReLU();
auto attr = primitive->value_as_PReLU();
MS_ASSERT(attr != nullptr);
auto slope = std::make_unique<std::vector<float>>();
@ -48,8 +48,8 @@ class CaffePReLU : public Activation {
slope->push_back(attr->slope()->data()[i]);
}
auto val_offset = schema::CreateCaffePReLUDirect(fbb, attr->channelShared(), slope.release());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_CaffePReLU, val_offset.o);
auto val_offset = schema::CreatePReLUDirect(fbb, attr->channelShared(), slope.release());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_PReLU, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
@ -70,4 +70,4 @@ class CaffePReLU : public Activation {
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_
#endif // LITE_MINDSPORE_LITE_C_OPS_P_RELU_H_

View File

@ -1,35 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/prelu.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<float> Prelu::GetSlope() const { return this->primitive_->value.AsPrelu()->slope; }
void Prelu::SetSlope(const std::vector<float> &slope) { this->primitive_->value.AsPrelu()->slope = slope; }
#else
std::vector<float> Prelu::GetSlope() const {
auto fb_vector = this->primitive_->value_as_Prelu()->slope();
return std::vector<float>(fb_vector->begin(), fb_vector->end());
}
#endif
} // namespace lite
} // namespace mindspore

View File

@ -1,72 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_PRELU_H_
#define LITE_MINDSPORE_LITE_C_OPS_PRELU_H_
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include "ir/dtype/type_id.h"
#include "src/ops/activation.h"
namespace mindspore {
namespace lite {
class Prelu : public Activation {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Prelu, PrimitiveC);
Prelu() = default;
explicit Prelu(schema::PrimitiveT *primitive) : Activation(primitive) {}
void SetSlope(const std::vector<float> &slope);
#else
explicit Prelu(schema::Primitive *primitive) : Activation(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Prelu();
MS_ASSERT(attr != nullptr);
auto slope = std::make_unique<std::vector<float>>();
for (int i = 0; i < static_cast<int>(attr->slope()->size()); i++) {
slope->push_back(attr->slope()->data()[i]);
}
auto val_offset = schema::CreatePreluDirect(fbb, slope.release());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Prelu, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
delete[] buf_bak;
fbb.Clear();
return prim;
}
#endif
std::vector<float> GetSlope() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_PRELU_H_

View File

@ -72,8 +72,8 @@
#include "src/ops/gather_nd.h"
#include "src/ops/local_response_normalization.h"
#include "src/ops/pad.h"
#include "src/ops/prelu.h"
#include "src/ops/caffe_p_relu.h"
#include "src/ops/p_relu.h"
#include "src/ops/leaky_relu.h"
#include "src/ops/reverse_sequence.h"
#include "src/ops/dedepthwise_conv2d.h"
#include "src/ops/depthwise_conv2d.h"
@ -346,10 +346,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT
return new Minimum(primitive);
case schema::PrimitiveType_StridedSlice:
return new StridedSlice(primitive);
case schema::PrimitiveType_Prelu:
return new Prelu(primitive);
case schema::PrimitiveType_CaffePReLU:
return new CaffePReLU(primitive);
case schema::PrimitiveType_LeakyReLU:
return new (std::nothrow) LeakyReLU(primitive);
case schema::PrimitiveType_PReLU:
return new (std::nothrow) PReLU(primitive);
case schema::PrimitiveType_Round:
return new Round(primitive);
case schema::PrimitiveType_Reverse:
@ -554,10 +554,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
return new Minimum(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_StridedSlice:
return new StridedSlice(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Prelu:
return new Prelu(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_CaffePReLU:
return new CaffePReLU(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_LeakyReLU:
return new (std::nothrow) LeakyReLU(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_PReLU:
return new (std::nothrow) PReLU(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Round:
return new Round(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Reverse:

View File

@ -75,8 +75,8 @@
#include "src/ops/gather_nd.h"
#include "src/ops/local_response_normalization.h"
#include "src/ops/pad.h"
#include "src/ops/prelu.h"
#include "src/ops/caffe_p_relu.h"
#include "src/ops/leaky_relu.h"
#include "src/ops/p_relu.h"
#include "src/ops/reverse_sequence.h"
#include "src/ops/dedepthwise_conv2d.h"
#include "src/ops/depthwise_conv2d.h"
@ -233,7 +233,7 @@ OpParameter *PopulateExpandDimsParameter(const mindspore::lite::PrimitiveC *prim
}
OpParameter *PopulatePReLUParameter(const mindspore::lite::PrimitiveC *primitive) {
auto param = dynamic_cast<const mindspore::lite::CaffePReLU *>(primitive);
auto param = dynamic_cast<const mindspore::lite::PReLU *>(primitive);
PReluParameter *prelu_param = reinterpret_cast<PReluParameter *>(malloc(sizeof(PReluParameter)));
if (prelu_param == nullptr) {
MS_LOG(ERROR) << "malloc PReluParameter failed.";
@ -246,7 +246,7 @@ OpParameter *PopulatePReLUParameter(const mindspore::lite::PrimitiveC *primitive
}
OpParameter *PopulateLeakyReluParameter(const mindspore::lite::PrimitiveC *primitive) {
auto param = dynamic_cast<const mindspore::lite::Prelu *>(primitive);
auto param = dynamic_cast<const mindspore::lite::LeakyReLU *>(primitive);
LeakyReluParameter *leaky_relu_param = reinterpret_cast<LeakyReluParameter *>(malloc(sizeof(LeakyReluParameter)));
if (leaky_relu_param == nullptr) {
MS_LOG(ERROR) << "malloc LeakyReluParameter failed.";
@ -254,17 +254,14 @@ OpParameter *PopulateLeakyReluParameter(const mindspore::lite::PrimitiveC *primi
}
memset(leaky_relu_param, 0, sizeof(LeakyReluParameter));
leaky_relu_param->op_parameter_.type_ = primitive->Type();
auto temp = param->GetSlope();
leaky_relu_param->slope_ = reinterpret_cast<float *>(malloc(temp.size() * sizeof(float)));
leaky_relu_param->slope_ = reinterpret_cast<float *>(malloc(sizeof(float)));
if (leaky_relu_param->slope_ == nullptr) {
MS_LOG(ERROR) << "malloc relu slope fail!";
free(leaky_relu_param);
return nullptr;
}
for (size_t i = 0; i < temp.size(); i++) {
leaky_relu_param->slope_[i] = temp[i];
}
leaky_relu_param->slope_num_ = temp.size();
leaky_relu_param->slope_[0] = param->GetNegativeSlope();
leaky_relu_param->slope_num_ = 1;
return reinterpret_cast<OpParameter *>(leaky_relu_param);
}
@ -1598,8 +1595,8 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_ScatterND] = PopulateScatterNDParameter;
populate_parameter_funcs_[schema::PrimitiveType_Squeeze] = PopulateSqueezeParameter;
populate_parameter_funcs_[schema::PrimitiveType_Split] = PopulateSplitParameter;
populate_parameter_funcs_[schema::PrimitiveType_CaffePReLU] = PopulatePReLUParameter;
populate_parameter_funcs_[schema::PrimitiveType_Prelu] = PopulateLeakyReluParameter;
populate_parameter_funcs_[schema::PrimitiveType_PReLU] = PopulatePReLUParameter;
populate_parameter_funcs_[schema::PrimitiveType_LeakyReLU] = PopulateLeakyReluParameter;
populate_parameter_funcs_[schema::PrimitiveType_PriorBox] = PopulatePriorBoxParameter;
populate_parameter_funcs_[schema::PrimitiveType_QuantDTypeCast] = PopulateQuantDTypeCastParameter;
populate_parameter_funcs_[schema::PrimitiveType_Lstm] = PopulateLstmParameter;

View File

@ -29,7 +29,7 @@ using mindspore::schema::PrimitiveType_LeakyReLU;
namespace mindspore::kernel {
int LeakyReluBaseCPUKernel::Init() { return RET_OK; }
kernel::LiteKernel *CpuPreluInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
kernel::LiteKernel *CpuLeakyReluInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc,
@ -41,7 +41,7 @@ kernel::LiteKernel *CpuPreluInt8KernelCreator(const std::vector<lite::tensor::Te
MS_ASSERT(desc.type == schema::PrimitiveType_LeakyRelu);
auto *kernel = new (std::nothrow) LeakyReluInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PreluCPUKernel fail!";
MS_LOG(ERROR) << "new LeakyReluInt8CPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
@ -54,5 +54,5 @@ kernel::LiteKernel *CpuPreluInt8KernelCreator(const std::vector<lite::tensor::Te
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LeakyReLU, CpuPreluInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LeakyReLU, CpuLeakyReluInt8KernelCreator)
} // namespace mindspore::kernel

View File

@ -38,7 +38,10 @@ int GatherCPUKernel::Init() {
}
GatherCPUKernel::~GatherCPUKernel() {
context_->allocator->Free(indices_data_);
if (indices_data_ != nullptr) {
free(indices_data_);
indices_data_ = nullptr;
}
}
int GatherCPUKernel::ReSize() { return RET_OK; }
@ -102,7 +105,7 @@ int GatherCPUKernel::Run() {
}
auto indices_tensor = in_tensors_.at(1);
indices_data_ = reinterpret_cast<int *>(context_->allocator->Malloc(indices_tensor->Size()));
indices_data_ = reinterpret_cast<int *>(malloc(indices_tensor->Size()));
if (indices_data_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
return RET_ERROR;

View File

@ -36,7 +36,7 @@ class GatherCPUKernel : public LiteKernel {
int DoGather(int task_id);
private:
int *indices_data_;
int *indices_data_ = nullptr;
};
} // namespace mindspore::kernel

View File

@ -26,7 +26,6 @@ using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_LeakyReLU;
using mindspore::schema::PrimitiveType_Prelu;
namespace mindspore::kernel {
namespace {
@ -100,5 +99,4 @@ kernel::LiteKernel *CpuLeakyReluFp32KernelCreator(const std::vector<lite::tensor
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LeakyReLU, CpuLeakyReluFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Prelu, CpuLeakyReluFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -24,7 +24,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_CaffePReLU;
using mindspore::schema::PrimitiveType_PReLU;
namespace mindspore::kernel {
namespace {
@ -155,7 +155,7 @@ kernel::LiteKernel *CpuPReluFp32KernelCreator(const std::vector<lite::tensor::Te
MS_LOG(ERROR) << "input param is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Prelu);
auto *kernel = new (std::nothrow) PReluCPUKernel(param, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PReluCPUKernel fail!";
@ -171,5 +171,5 @@ kernel::LiteKernel *CpuPReluFp32KernelCreator(const std::vector<lite::tensor::Te
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_CaffePReLU, CpuPReluFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_PReLU, CpuPReluFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -25,9 +25,20 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Prelu;
namespace mindspore::kernel {
namespace {
int LeakyReluInt8Run(void *cdata, int task_id) {
if (cdata == nullptr) {
MS_LOG(ERROR) << "input cdata is nullptr!";
return RET_ERROR;
}
auto relu = reinterpret_cast<LeakyReluInt8CPUKernel *>(cdata);
relu->DoExecute(task_id);
return RET_OK;
}
} // namespace
int LeakyReluInt8CPUKernel::Init() {
LeakyReluBaseCPUKernel::Init();
LeakyReluParameter *param = reinterpret_cast<LeakyReluParameter *>(op_parameter_);
@ -82,17 +93,12 @@ int LeakyReluInt8CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return ret;
}
ret = ParallelLaunch(THREAD_POOL_DEFAULT, PreluInt8Run, this, op_parameter_->thread_num_);
ret = ParallelLaunch(THREAD_POOL_DEFAULT, LeakyReluInt8Run, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "RunPreluParam failed. errorcode: ";
}
return RET_OK;
}
int PreluInt8Run(void *cdata, int task_id) {
auto prelu = reinterpret_cast<LeakyReluInt8CPUKernel *>(cdata);
prelu->DoExecute(task_id);
return RET_OK;
}
int LeakyReluInt8CPUKernel::DoExecute(int task_id) {
auto input_tensor = in_tensors_.at(kInputIndex);

View File

@ -41,7 +41,6 @@ class LeakyReluInt8CPUKernel : public LeakyReluBaseCPUKernel {
private:
LeakyReluQuantArg quant_prelu_parm_;
};
int PreluInt8Run(void *cdata, int task_id);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_PRELU_INT8_H_

View File

@ -29,7 +29,7 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Prelu;
using mindspore::schema::PrimitiveType_PReLU;
namespace mindspore::kernel {
@ -154,5 +154,5 @@ kernel::LiteKernel *OpenCLPReluKernelCreator(const std::vector<lite::tensor::Ten
return kernel;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Prelu, OpenCLPReluKernelCreator)
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_PReLU, OpenCLPReluKernelCreator)
} // namespace mindspore::kernel

View File

@ -65,14 +65,14 @@ TEST_F(TestPreluInt8, prelu_1) {
outputs_tensor[0] = output0_tensor;
LeakyReluQuantArg op_param;
op_param.op_parameter_.type_ = schema::PrimitiveType_Prelu;
op_param.op_parameter_.type_ = schema::PrimitiveType_LeakyReLU;
op_param.slope_ = reinterpret_cast<float *>(malloc(sizeof(float)));
op_param.slope_[0] = 0.25;
lite::Context *ctx = new lite::Context;
ctx->thread_num_ = 2;
op_param.axis_ = 0.25;
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Prelu};
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_LeakyReLU};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
kernel::LiteKernel *kernel =

View File

@ -119,15 +119,6 @@ TEST_F(TestTfliteParserPrelu, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Prelu) << "wrong Op Type";
}
TEST_F(TestTfliteParserPrelu, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value;
std::vector<float> slope(20, 0);
ASSERT_EQ(val.AsPrelu()->slope, slope);
ASSERT_EQ(val.type, schema::PrimitiveType_Prelu);
}
class TestTfliteParserLeakyRelu : public TestTfliteParser {

View File

@ -29,7 +29,7 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_Pooling, schema::PrimitiveType_Resize,
schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm,
schema::PrimitiveType_CaffePReLU};
schema::PrimitiveType_PReLU};
static const std::vector<schema::PrimitiveType> fp32FullOpList = {
schema::PrimitiveType_Concat, schema::PrimitiveType_Add,

View File

@ -34,7 +34,7 @@ STATUS CaffePReluParser::Parse(const caffe::LayerParameter &proto,
return RET_NULL_PTR;
}
std::unique_ptr<schema::CaffePReLUT> attr = std::make_unique<schema::CaffePReLUT>();
std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
@ -60,7 +60,7 @@ STATUS CaffePReluParser::Parse(const caffe::LayerParameter &proto,
weightVec->push_back(slope);
op->name = proto.name();
op->primitive->value.type = schema::PrimitiveType_CaffePReLU;
op->primitive->value.type = schema::PrimitiveType_PReLU;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -73,7 +73,7 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
MS_LOG(ERROR) << "input num should be 2";
return RET_ERROR;
}
std::unique_ptr<schema::CaffePReLUT> attr = std::make_unique<schema::CaffePReLUT>();
std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>();
std::vector<onnx::TensorProto> params;
const auto &input_name = onnx_node.input(1);
for (const auto &it : onnx_graph.initializer()) {
@ -102,7 +102,7 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
}
}
op->primitive->value.type = schema::PrimitiveType_CaffePReLU;
op->primitive->value.type = schema::PrimitiveType_PReLU;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -84,52 +84,11 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
return RET_OK;
}
STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TflitePreluParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::PreluT> attr = std::make_unique<schema::PreluT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) {
MS_LOG(ERROR) << "get pRelu -> slope failed";
return RET_ERROR;
}
op->primitive->value.type = schema::PrimitiveType_Prelu;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
return RET_OK;
}
TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser());
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser());
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());
} // namespace lite
} // namespace mindspore

View File

@ -68,18 +68,6 @@ class TfliteLeakyReluParser : public TfliteActivationParser {
TfliteLeakyReluParser() : TfliteActivationParser() {}
};
class TflitePreluParser : public TfliteNodeParser {
public:
TflitePreluParser() : TfliteNodeParser("Prelu") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op,
std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -107,7 +107,6 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
{tflite::BuiltinOperator_DEPTH_TO_SPACE, "DepthToSpace"},
{tflite::BuiltinOperator_SPACE_TO_BATCH_ND, "SpaceToBatchND"},
{tflite::BuiltinOperator_SPACE_TO_DEPTH, "SpaceToDepth"},
{tflite::BuiltinOperator_PRELU, "Prelu"},
{tflite::BuiltinOperator_ROUND, "Round"},
{tflite::BuiltinOperator_WHERE, "Where"},
{tflite::BuiltinOperator_SPARSE_TO_DENSE, "SparseToDense"},