[MSLITE] int32 relu

This commit is contained in:
ling 2021-10-11 16:08:39 +08:00
parent fd0f1e3eaa
commit 1974fed359
4 changed files with 71 additions and 10 deletions

View File

@ -39,6 +39,13 @@ int Fp32Relu(const float *src, int length, float *dst) {
return NNACL_OK;
}
int Int32Relu(const int32_t *src, int length, int32_t *dst) {
for (int i = 0; i < length; ++i) {
dst[i] = src[i] > 0 ? src[i] : 0;
}
return NNACL_OK;
}
int Fp32Relu6(const float *src, int length, float *dst) {
int i = 0;

View File

@ -32,6 +32,7 @@ typedef struct ActivationParameter {
extern "C" {
#endif
int Fp32Relu(const float *src, int length, float *dst);
int Int32Relu(const int32_t *src, int length, int32_t *dst);
int Fp32Relu6(const float *src, int length, float *dst);
int LRelu(const float *src, int length, float *dst, float alpha);
int Sigmoid(const float *src, int length, float *dst);

View File

@ -36,14 +36,24 @@ namespace mindspore::kernel {
int ActivationCPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), 1);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
if (type_ != schema::ActivationType_RELU && type_ != schema::ActivationType_RELU6 &&
type_ != schema::ActivationType_LEAKY_RELU && type_ != schema::ActivationType_SIGMOID &&
type_ != schema::ActivationType_TANH && type_ != schema::ActivationType_HSWISH &&
type_ != schema::ActivationType_SWISH && type_ != schema::ActivationType_HSIGMOID &&
type_ != schema::ActivationType_HARD_TANH && type_ != schema::ActivationType_GELU &&
type_ != schema::ActivationType_SOFTPLUS && type_ != schema::ActivationType_ELU) {
MS_LOG(ERROR) << "Activation fp32 not support type: " << type_;
return RET_ERROR;
if (in_tensors().front()->data_type() == kNumberTypeInt32) {
if (type_ != schema::ActivationType_RELU) {
MS_LOG(ERROR) << "Activation int32 not support type: " << type_;
return RET_ERROR;
}
}
if (in_tensors().front()->data_type() == kNumberTypeFloat32) {
if (type_ != schema::ActivationType_RELU && type_ != schema::ActivationType_RELU6 &&
type_ != schema::ActivationType_LEAKY_RELU && type_ != schema::ActivationType_SIGMOID &&
type_ != schema::ActivationType_TANH && type_ != schema::ActivationType_HSWISH &&
type_ != schema::ActivationType_SWISH && type_ != schema::ActivationType_HSIGMOID &&
type_ != schema::ActivationType_HARD_TANH && type_ != schema::ActivationType_GELU &&
type_ != schema::ActivationType_SOFTPLUS && type_ != schema::ActivationType_ELU) {
MS_LOG(ERROR) << "Activation fp32 not support type: " << type_;
return RET_ERROR;
}
}
return RET_OK;
}
@ -51,6 +61,44 @@ int ActivationCPUKernel::Prepare() {
int ActivationCPUKernel::ReSize() { return RET_OK; }
int ActivationCPUKernel::DoActivation(int task_id) {
if (in_tensors_.front()->data_type() == kNumberTypeFloat32) {
return DoActivationFp32(task_id);
} else if (in_tensors_.front()->data_type() == kNumberTypeInt32) {
return DoActivationInt32(task_id);
}
return RET_ERROR;
}
int ActivationCPUKernel::DoActivationInt32(int task_id) {
auto input_addr = reinterpret_cast<int32_t *>(in_tensors_.at(0)->data());
auto output_addr = reinterpret_cast<int32_t *>(out_tensors_.at(0)->data());
MS_ASSERT(input_addr != nullptr);
MS_ASSERT(output_addr != nullptr);
auto length = in_tensors_.at(0)->ElementsNum();
int stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id);
if (count <= 0) {
return RET_OK;
}
if (INT_MUL_OVERFLOW(stride, task_id)) {
return RET_ERROR;
}
auto ret = RET_OK;
if (type_ == schema::ActivationType_RELU) {
ret = Int32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else {
MS_LOG(ERROR) << "Int32 Activation type error";
return RET_ERROR;
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "Int32 Activation error, ret: " << ret;
}
return ret;
}
int ActivationCPUKernel::DoActivationFp32(int task_id) {
auto input_addr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
MS_ASSERT(input_addr != nullptr);
@ -93,11 +141,11 @@ int ActivationCPUKernel::DoActivation(int task_id) {
} else if (type_ == schema::ActivationType_ELU) {
ret = Elu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_);
} else {
MS_LOG(ERROR) << "Activation type error";
MS_LOG(ERROR) << "Fp32 Activation type error";
return RET_ERROR;
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "Activation error, ret: " << ret;
MS_LOG(ERROR) << "Fp32 Activation error, ret: " << ret;
}
return ret;
}
@ -124,4 +172,5 @@ int ActivationCPUKernel::Run() {
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Activation, LiteKernelCreator<ActivationCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Activation, LiteKernelCreator<ActivationCPUKernel>)
} // namespace mindspore::kernel

View File

@ -39,6 +39,10 @@ class ActivationCPUKernel : public InnerKernel {
int Run() override;
int DoActivation(int task_id);
private:
int DoActivationFp32(int task_id);
int DoActivationInt32(int task_id);
private:
int thread_count_;
int type_;