forked from mindspore-Ecosystem/mindspore
!7743 [MSLITE][Develop] add swish kernel
Merge pull request !7743 from sunsuodong/add_swish_kernel
This commit is contained in:
commit
0acf9729c9
|
@ -108,6 +108,26 @@ int Tanh(const float *src, int length, float *dst) {
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int Swish(const float *src, int length, float *dst) {
|
||||
int ret = Sigmoid(src, length, dst);
|
||||
if (ret != NNACL_OK) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
int index = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
for (; index <= length - C4NUM; index += C4NUM) {
|
||||
float32x4_t src_value = vld1q_f32(src + index);
|
||||
float32x4_t sigmoid_value = vld1q_f32(dst + index);
|
||||
float32x4_t result = vmulq_f32(src_value, sigmoid_value);
|
||||
vst1q_f32(dst + index, result);
|
||||
}
|
||||
#endif
|
||||
for (; index < length; index++) {
|
||||
dst[index] = src[index] * dst[index];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int HSwish(const float *src, int length, float *dst) {
|
||||
for (int i = 0; i < length; ++i) {
|
||||
float in = src[i];
|
||||
|
|
|
@ -37,6 +37,7 @@ int LRelu(const float *src, int length, float *dst, float alpha);
|
|||
int Sigmoid(const float *src, int length, float *dst);
|
||||
int Tanh(const float *src, int length, float *dst);
|
||||
int HSigmoid(const float *src, int length, float *dst);
|
||||
int Swish(const float *src, int length, float *dst);
|
||||
int HSwish(const float *src, int length, float *dst);
|
||||
int HardTanh(const float *src, int length, float *dst, float min_val, float max_val);
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -79,7 +79,8 @@ enum ActivationType : byte {
|
|||
LINEAR = 15,
|
||||
HARD_TANH = 16,
|
||||
SIGN = 17,
|
||||
UNKNOW = 18
|
||||
SWISH = 18,
|
||||
UNKNOW = 19
|
||||
}
|
||||
enum ActivationGradType : byte {
|
||||
NO_ACTIVATION = 0,
|
||||
|
|
|
@ -53,6 +53,8 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
|
|||
attr->type = schema::ActivationType_SIGMOID;
|
||||
} else if (prim.name() == "ReLU6") {
|
||||
attr->type = schema::ActivationType_RELU6;
|
||||
} else if (prim.name() == "Swish") {
|
||||
attr->type = schema::ActivationType_SWISH;
|
||||
} else if (prim.name() == "HSwish") {
|
||||
attr->type = schema::ActivationType_HSWISH;
|
||||
} else if (prim.name() == "HSigmoid") {
|
||||
|
|
|
@ -29,6 +29,7 @@ using mindspore::schema::ActivationType_HSWISH;
|
|||
using mindspore::schema::ActivationType_LEAKY_RELU;
|
||||
using mindspore::schema::ActivationType_RELU;
|
||||
using mindspore::schema::ActivationType_RELU6;
|
||||
using mindspore::schema::ActivationType_SWISH;
|
||||
using mindspore::schema::PrimitiveType_Activation;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
@ -44,32 +45,34 @@ int ActivationCPUKernel::DoActivation(int task_id) {
|
|||
int stride = UP_DIV(length, thread_count_);
|
||||
int count = MSMIN(stride, length - stride * task_id);
|
||||
|
||||
auto error_code = RET_OK;
|
||||
auto ret = RET_OK;
|
||||
|
||||
if (type_ == schema::ActivationType_RELU) {
|
||||
error_code = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||
ret = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||
} else if (type_ == schema::ActivationType_RELU6) {
|
||||
error_code = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||
ret = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||
} else if (type_ == schema::ActivationType_LEAKY_RELU) {
|
||||
error_code = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_);
|
||||
ret = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_);
|
||||
} else if (type_ == schema::ActivationType_SIGMOID) {
|
||||
error_code = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||
ret = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||
} else if (type_ == schema::ActivationType_TANH) {
|
||||
error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||
ret = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||
} else if (type_ == schema::ActivationType_SWISH) {
|
||||
ret = Swish(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||
} else if (type_ == schema::ActivationType_HSWISH) {
|
||||
error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||
ret = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||
} else if (type_ == schema::ActivationType_HSIGMOID) {
|
||||
error_code = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||
ret = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||
} else if (type_ == schema::ActivationType_HARD_TANH) {
|
||||
error_code = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
|
||||
ret = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Activation type error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (error_code != RET_OK) {
|
||||
return RET_ERROR;
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Activation error, ret: " << ret;
|
||||
}
|
||||
return RET_OK;
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ActivationRun(void *cdata, int task_id) {
|
||||
|
|
|
@ -73,6 +73,17 @@ TEST_F(TestActivationFp32, SigmoidFp32) {
|
|||
MS_LOG(INFO) << "TestSigmoidFp32 passed";
|
||||
}
|
||||
|
||||
TEST_F(TestActivationFp32, SwishFp32) {
|
||||
float input[8] = {0, 1, 2, 3, 4, 5, 6, 7};
|
||||
float output[8] = {0};
|
||||
Swish(input, 8, output);
|
||||
|
||||
float expect[8] = {0, 0.731059, 1.761594, 2.857722, 3.928056, 4.966535, 5.985162, 6.993623};
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
EXPECT_NEAR(output[i], expect[i], 0.00001);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TestActivationFp32, TanhFp32) {
|
||||
float input[7] = {-3, -2, -1, 0, 1, 2, 3};
|
||||
float output[7] = {0};
|
||||
|
|
|
@ -56,6 +56,9 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
} else if (std::strcmp(node_name, "Logistic") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLogisticParser";
|
||||
attr->type = schema::ActivationType_SIGMOID;
|
||||
} else if (std::strcmp(node_name, "Swish") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSwishParser";
|
||||
attr->type = schema::ActivationType_SWISH;
|
||||
} else if (std::strcmp(node_name, "HardSwish") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
|
||||
attr->type = schema::ActivationType_HSWISH;
|
||||
|
@ -82,6 +85,7 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
|
||||
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
|
||||
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
|
||||
TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteSwishParser());
|
||||
TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser());
|
||||
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
|
||||
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());
|
||||
|
|
|
@ -53,6 +53,11 @@ class TfliteLogisticParser : public TfliteActivationParser {
|
|||
TfliteLogisticParser() : TfliteActivationParser() {}
|
||||
};
|
||||
|
||||
class TfliteSwishParser : public TfliteActivationParser {
|
||||
public:
|
||||
TfliteSwishParser() : TfliteActivationParser() {}
|
||||
};
|
||||
|
||||
class TfliteHardSwishParser : public TfliteActivationParser {
|
||||
public:
|
||||
TfliteHardSwishParser() : TfliteActivationParser() {}
|
||||
|
|
Loading…
Reference in New Issue