!7743 [MSLITE][Develop] add swish kernel

Merge pull request !7743 from sunsuodong/add_swish_kernel
This commit is contained in:
mindspore-ci-bot 2020-10-26 14:08:57 +08:00 committed by Gitee
commit 0acf9729c9
8 changed files with 60 additions and 13 deletions

View File

@ -108,6 +108,26 @@ int Tanh(const float *src, int length, float *dst) {
return NNACL_OK;
}
int Swish(const float *src, int length, float *dst) {
int ret = Sigmoid(src, length, dst);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
int index = 0;
#ifdef ENABLE_NEON
for (; index <= length - C4NUM; index += C4NUM) {
float32x4_t src_value = vld1q_f32(src + index);
float32x4_t sigmoid_value = vld1q_f32(dst + index);
float32x4_t result = vmulq_f32(src_value, sigmoid_value);
vst1q_f32(dst + index, result);
}
#endif
for (; index < length; index++) {
dst[index] = src[index] * dst[index];
}
return NNACL_OK;
}
int HSwish(const float *src, int length, float *dst) {
for (int i = 0; i < length; ++i) {
float in = src[i];

View File

@ -37,6 +37,7 @@ int LRelu(const float *src, int length, float *dst, float alpha);
int Sigmoid(const float *src, int length, float *dst);
int Tanh(const float *src, int length, float *dst);
int HSigmoid(const float *src, int length, float *dst);
int Swish(const float *src, int length, float *dst);
int HSwish(const float *src, int length, float *dst);
int HardTanh(const float *src, int length, float *dst, float min_val, float max_val);
#ifdef __cplusplus

View File

@ -79,7 +79,8 @@ enum ActivationType : byte {
LINEAR = 15,
HARD_TANH = 16,
SIGN = 17,
UNKNOW = 18
SWISH = 18,
UNKNOW = 19
}
enum ActivationGradType : byte {
NO_ACTIVATION = 0,

View File

@ -53,6 +53,8 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
attr->type = schema::ActivationType_SIGMOID;
} else if (prim.name() == "ReLU6") {
attr->type = schema::ActivationType_RELU6;
} else if (prim.name() == "Swish") {
attr->type = schema::ActivationType_SWISH;
} else if (prim.name() == "HSwish") {
attr->type = schema::ActivationType_HSWISH;
} else if (prim.name() == "HSigmoid") {

View File

@ -29,6 +29,7 @@ using mindspore::schema::ActivationType_HSWISH;
using mindspore::schema::ActivationType_LEAKY_RELU;
using mindspore::schema::ActivationType_RELU;
using mindspore::schema::ActivationType_RELU6;
using mindspore::schema::ActivationType_SWISH;
using mindspore::schema::PrimitiveType_Activation;
namespace mindspore::kernel {
@ -44,32 +45,34 @@ int ActivationCPUKernel::DoActivation(int task_id) {
int stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id);
auto error_code = RET_OK;
auto ret = RET_OK;
if (type_ == schema::ActivationType_RELU) {
error_code = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id);
ret = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_RELU6) {
error_code = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id);
ret = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_LEAKY_RELU) {
error_code = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_);
ret = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_);
} else if (type_ == schema::ActivationType_SIGMOID) {
error_code = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
ret = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_TANH) {
error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id);
ret = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_SWISH) {
ret = Swish(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_HSWISH) {
error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id);
ret = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_HSIGMOID) {
error_code = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
ret = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_HARD_TANH) {
error_code = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
ret = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
} else {
MS_LOG(ERROR) << "Activation type error";
return RET_ERROR;
}
if (error_code != RET_OK) {
return RET_ERROR;
if (ret != RET_OK) {
MS_LOG(ERROR) << "Activation error, ret: " << ret;
}
return RET_OK;
return ret;
}
int ActivationRun(void *cdata, int task_id) {

View File

@ -73,6 +73,17 @@ TEST_F(TestActivationFp32, SigmoidFp32) {
MS_LOG(INFO) << "TestSigmoidFp32 passed";
}
TEST_F(TestActivationFp32, SwishFp32) {
float input[8] = {0, 1, 2, 3, 4, 5, 6, 7};
float output[8] = {0};
Swish(input, 8, output);
float expect[8] = {0, 0.731059, 1.761594, 2.857722, 3.928056, 4.966535, 5.985162, 6.993623};
for (int i = 0; i < 8; ++i) {
EXPECT_NEAR(output[i], expect[i], 0.00001);
}
}
TEST_F(TestActivationFp32, TanhFp32) {
float input[7] = {-3, -2, -1, 0, 1, 2, 3};
float output[7] = {0};

View File

@ -56,6 +56,9 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
} else if (std::strcmp(node_name, "Logistic") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogisticParser";
attr->type = schema::ActivationType_SIGMOID;
} else if (std::strcmp(node_name, "Swish") == 0) {
MS_LOG(DEBUG) << "parse TfliteSwishParser";
attr->type = schema::ActivationType_SWISH;
} else if (std::strcmp(node_name, "HardSwish") == 0) {
MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
attr->type = schema::ActivationType_HSWISH;
@ -82,6 +85,7 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteSwishParser());
TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser());
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());

View File

@ -53,6 +53,11 @@ class TfliteLogisticParser : public TfliteActivationParser {
TfliteLogisticParser() : TfliteActivationParser() {}
};
class TfliteSwishParser : public TfliteActivationParser {
public:
TfliteSwishParser() : TfliteActivationParser() {}
};
class TfliteHardSwishParser : public TfliteActivationParser {
public:
TfliteHardSwishParser() : TfliteActivationParser() {}