forked from mindspore-Ecosystem/mindspore
!7743 [MSLITE][Develop] add swish kernel
Merge pull request !7743 from sunsuodong/add_swish_kernel
This commit is contained in:
commit
0acf9729c9
|
@ -108,6 +108,26 @@ int Tanh(const float *src, int length, float *dst) {
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int Swish(const float *src, int length, float *dst) {
|
||||||
|
int ret = Sigmoid(src, length, dst);
|
||||||
|
if (ret != NNACL_OK) {
|
||||||
|
return NNACL_ERR;
|
||||||
|
}
|
||||||
|
int index = 0;
|
||||||
|
#ifdef ENABLE_NEON
|
||||||
|
for (; index <= length - C4NUM; index += C4NUM) {
|
||||||
|
float32x4_t src_value = vld1q_f32(src + index);
|
||||||
|
float32x4_t sigmoid_value = vld1q_f32(dst + index);
|
||||||
|
float32x4_t result = vmulq_f32(src_value, sigmoid_value);
|
||||||
|
vst1q_f32(dst + index, result);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
for (; index < length; index++) {
|
||||||
|
dst[index] = src[index] * dst[index];
|
||||||
|
}
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int HSwish(const float *src, int length, float *dst) {
|
int HSwish(const float *src, int length, float *dst) {
|
||||||
for (int i = 0; i < length; ++i) {
|
for (int i = 0; i < length; ++i) {
|
||||||
float in = src[i];
|
float in = src[i];
|
||||||
|
|
|
@ -37,6 +37,7 @@ int LRelu(const float *src, int length, float *dst, float alpha);
|
||||||
int Sigmoid(const float *src, int length, float *dst);
|
int Sigmoid(const float *src, int length, float *dst);
|
||||||
int Tanh(const float *src, int length, float *dst);
|
int Tanh(const float *src, int length, float *dst);
|
||||||
int HSigmoid(const float *src, int length, float *dst);
|
int HSigmoid(const float *src, int length, float *dst);
|
||||||
|
int Swish(const float *src, int length, float *dst);
|
||||||
int HSwish(const float *src, int length, float *dst);
|
int HSwish(const float *src, int length, float *dst);
|
||||||
int HardTanh(const float *src, int length, float *dst, float min_val, float max_val);
|
int HardTanh(const float *src, int length, float *dst, float min_val, float max_val);
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
|
|
@ -79,7 +79,8 @@ enum ActivationType : byte {
|
||||||
LINEAR = 15,
|
LINEAR = 15,
|
||||||
HARD_TANH = 16,
|
HARD_TANH = 16,
|
||||||
SIGN = 17,
|
SIGN = 17,
|
||||||
UNKNOW = 18
|
SWISH = 18,
|
||||||
|
UNKNOW = 19
|
||||||
}
|
}
|
||||||
enum ActivationGradType : byte {
|
enum ActivationGradType : byte {
|
||||||
NO_ACTIVATION = 0,
|
NO_ACTIVATION = 0,
|
||||||
|
|
|
@ -53,6 +53,8 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
|
||||||
attr->type = schema::ActivationType_SIGMOID;
|
attr->type = schema::ActivationType_SIGMOID;
|
||||||
} else if (prim.name() == "ReLU6") {
|
} else if (prim.name() == "ReLU6") {
|
||||||
attr->type = schema::ActivationType_RELU6;
|
attr->type = schema::ActivationType_RELU6;
|
||||||
|
} else if (prim.name() == "Swish") {
|
||||||
|
attr->type = schema::ActivationType_SWISH;
|
||||||
} else if (prim.name() == "HSwish") {
|
} else if (prim.name() == "HSwish") {
|
||||||
attr->type = schema::ActivationType_HSWISH;
|
attr->type = schema::ActivationType_HSWISH;
|
||||||
} else if (prim.name() == "HSigmoid") {
|
} else if (prim.name() == "HSigmoid") {
|
||||||
|
|
|
@ -29,6 +29,7 @@ using mindspore::schema::ActivationType_HSWISH;
|
||||||
using mindspore::schema::ActivationType_LEAKY_RELU;
|
using mindspore::schema::ActivationType_LEAKY_RELU;
|
||||||
using mindspore::schema::ActivationType_RELU;
|
using mindspore::schema::ActivationType_RELU;
|
||||||
using mindspore::schema::ActivationType_RELU6;
|
using mindspore::schema::ActivationType_RELU6;
|
||||||
|
using mindspore::schema::ActivationType_SWISH;
|
||||||
using mindspore::schema::PrimitiveType_Activation;
|
using mindspore::schema::PrimitiveType_Activation;
|
||||||
|
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
|
@ -44,32 +45,34 @@ int ActivationCPUKernel::DoActivation(int task_id) {
|
||||||
int stride = UP_DIV(length, thread_count_);
|
int stride = UP_DIV(length, thread_count_);
|
||||||
int count = MSMIN(stride, length - stride * task_id);
|
int count = MSMIN(stride, length - stride * task_id);
|
||||||
|
|
||||||
auto error_code = RET_OK;
|
auto ret = RET_OK;
|
||||||
|
|
||||||
if (type_ == schema::ActivationType_RELU) {
|
if (type_ == schema::ActivationType_RELU) {
|
||||||
error_code = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
ret = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||||
} else if (type_ == schema::ActivationType_RELU6) {
|
} else if (type_ == schema::ActivationType_RELU6) {
|
||||||
error_code = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
ret = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||||
} else if (type_ == schema::ActivationType_LEAKY_RELU) {
|
} else if (type_ == schema::ActivationType_LEAKY_RELU) {
|
||||||
error_code = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_);
|
ret = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_);
|
||||||
} else if (type_ == schema::ActivationType_SIGMOID) {
|
} else if (type_ == schema::ActivationType_SIGMOID) {
|
||||||
error_code = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
ret = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||||
} else if (type_ == schema::ActivationType_TANH) {
|
} else if (type_ == schema::ActivationType_TANH) {
|
||||||
error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
ret = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||||
|
} else if (type_ == schema::ActivationType_SWISH) {
|
||||||
|
ret = Swish(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||||
} else if (type_ == schema::ActivationType_HSWISH) {
|
} else if (type_ == schema::ActivationType_HSWISH) {
|
||||||
error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
ret = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||||
} else if (type_ == schema::ActivationType_HSIGMOID) {
|
} else if (type_ == schema::ActivationType_HSIGMOID) {
|
||||||
error_code = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
ret = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||||
} else if (type_ == schema::ActivationType_HARD_TANH) {
|
} else if (type_ == schema::ActivationType_HARD_TANH) {
|
||||||
error_code = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
|
ret = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "Activation type error";
|
MS_LOG(ERROR) << "Activation type error";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (error_code != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
return RET_ERROR;
|
MS_LOG(ERROR) << "Activation error, ret: " << ret;
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ActivationRun(void *cdata, int task_id) {
|
int ActivationRun(void *cdata, int task_id) {
|
||||||
|
|
|
@ -73,6 +73,17 @@ TEST_F(TestActivationFp32, SigmoidFp32) {
|
||||||
MS_LOG(INFO) << "TestSigmoidFp32 passed";
|
MS_LOG(INFO) << "TestSigmoidFp32 passed";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs Swish() over the eight inputs 0..7 and compares each output with a
// precomputed reference value, allowing an absolute error of 0.00001.
TEST_F(TestActivationFp32, SwishFp32) {
  constexpr int kElementCount = 8;
  float input[kElementCount] = {0, 1, 2, 3, 4, 5, 6, 7};
  float output[kElementCount] = {0};
  Swish(input, kElementCount, output);

  // Reference values, one per input element.
  float expect[kElementCount] = {0, 0.731059, 1.761594, 2.857722, 3.928056, 4.966535, 5.985162, 6.993623};
  for (int idx = 0; idx != kElementCount; ++idx) {
    EXPECT_NEAR(output[idx], expect[idx], 0.00001);
  }
}
|
||||||
|
|
||||||
TEST_F(TestActivationFp32, TanhFp32) {
|
TEST_F(TestActivationFp32, TanhFp32) {
|
||||||
float input[7] = {-3, -2, -1, 0, 1, 2, 3};
|
float input[7] = {-3, -2, -1, 0, 1, 2, 3};
|
||||||
float output[7] = {0};
|
float output[7] = {0};
|
||||||
|
|
|
@ -56,6 +56,9 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||||
} else if (std::strcmp(node_name, "Logistic") == 0) {
|
} else if (std::strcmp(node_name, "Logistic") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteLogisticParser";
|
MS_LOG(DEBUG) << "parse TfliteLogisticParser";
|
||||||
attr->type = schema::ActivationType_SIGMOID;
|
attr->type = schema::ActivationType_SIGMOID;
|
||||||
|
} else if (std::strcmp(node_name, "Swish") == 0) {
|
||||||
|
MS_LOG(DEBUG) << "parse TfliteSwishParser";
|
||||||
|
attr->type = schema::ActivationType_SWISH;
|
||||||
} else if (std::strcmp(node_name, "HardSwish") == 0) {
|
} else if (std::strcmp(node_name, "HardSwish") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
|
MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
|
||||||
attr->type = schema::ActivationType_HSWISH;
|
attr->type = schema::ActivationType_HSWISH;
|
||||||
|
@ -82,6 +85,7 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||||
TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
|
TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
|
||||||
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
|
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
|
||||||
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
|
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
|
||||||
|
TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteSwishParser());
|
||||||
TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser());
|
TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser());
|
||||||
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
|
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
|
||||||
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());
|
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());
|
||||||
|
|
|
@ -53,6 +53,11 @@ class TfliteLogisticParser : public TfliteActivationParser {
|
||||||
TfliteLogisticParser() : TfliteActivationParser() {}
|
TfliteLogisticParser() : TfliteActivationParser() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Parser type registered for the TFLite "Swish" node. It adds no members of
// its own; everything is inherited from TfliteActivationParser.
class TfliteSwishParser : public TfliteActivationParser {
 public:
  // Default construction only runs the TfliteActivationParser default constructor.
  TfliteSwishParser() = default;
};
|
||||||
|
|
||||||
class TfliteHardSwishParser : public TfliteActivationParser {
|
class TfliteHardSwishParser : public TfliteActivationParser {
|
||||||
public:
|
public:
|
||||||
TfliteHardSwishParser() : TfliteActivationParser() {}
|
TfliteHardSwishParser() : TfliteActivationParser() {}
|
||||||
|
|
Loading…
Reference in New Issue