forked from mindspore-Ecosystem/mindspore
!7969 [MS][LITE][Develop] add new activation func named swish
Merge pull request !7969 from pengyongrong/stack
This commit is contained in:
commit
88aaec279c
|
@ -54,3 +54,12 @@ __kernel void Tanh(__read_only image2d_t input, __write_only image2d_t output, c
|
|||
in_c4 = (exp0 - exp1) / (exp0 + exp1);
|
||||
WRITE_IMAGE(output, (int2)(X, Y), in_c4);
|
||||
}
|
||||
|
||||
__kernel void Swish(__read_only image2d_t input, __write_only image2d_t output, const int2 img_shape) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
if (X >= img_shape.x || Y >= img_shape.y) return;
|
||||
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y));
|
||||
in_c4 = in_c4 * ((FLT4)(1.f) / ((FLT4)(1.f) + exp(-in_c4)));
|
||||
WRITE_IMAGE(output, (int2)(X, Y), in_c4);
|
||||
}
|
||||
|
|
|
@ -35,17 +35,16 @@ using mindspore::schema::ActivationType_LEAKY_RELU;
|
|||
using mindspore::schema::ActivationType_RELU;
|
||||
using mindspore::schema::ActivationType_RELU6;
|
||||
using mindspore::schema::ActivationType_SIGMOID;
|
||||
using mindspore::schema::ActivationType_SWISH;
|
||||
using mindspore::schema::ActivationType_TANH;
|
||||
using mindspore::schema::PrimitiveType_Activation;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int ActivationOpenClKernel::Init() {
|
||||
std::map<int, std::string> kernel_names{{ActivationType_LEAKY_RELU, "LeakyRelu"},
|
||||
{ActivationType_RELU, "Relu"},
|
||||
{ActivationType_SIGMOID, "Sigmoid"},
|
||||
{ActivationType_RELU6, "Relu6"},
|
||||
{ActivationType_TANH, "Tanh"}};
|
||||
std::map<int, std::string> kernel_names{
|
||||
{ActivationType_LEAKY_RELU, "LeakyRelu"}, {ActivationType_RELU, "Relu"}, {ActivationType_SIGMOID, "Sigmoid"},
|
||||
{ActivationType_RELU6, "Relu6"}, {ActivationType_TANH, "Tanh"}, {ActivationType_SWISH, "Swish"}};
|
||||
if (kernel_names.count(type_) == 0) {
|
||||
MS_LOG(ERROR) << "schema::ActivationType:" << type_ << "not found";
|
||||
return mindspore::lite::RET_ERROR;
|
||||
|
|
|
@ -28,10 +28,12 @@ using mindspore::kernel::LiteKernel;
|
|||
using mindspore::kernel::SubGraphOpenCLKernel;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::lite::Tensor;
|
||||
using mindspore::schema::ActivationType_LEAKY_RELU;
|
||||
using mindspore::schema::ActivationType_RELU;
|
||||
using mindspore::schema::ActivationType_RELU6;
|
||||
using mindspore::schema::ActivationType_SIGMOID;
|
||||
using mindspore::schema::ActivationType_SWISH;
|
||||
using mindspore::schema::ActivationType_TANH;
|
||||
using mindspore::schema::PrimitiveType_Activation;
|
||||
|
||||
|
@ -619,4 +621,86 @@ TEST_F(TestActivationOpenCLTanh, TanhFp_dim4) {
|
|||
delete output_tensor;
|
||||
delete sub_graph;
|
||||
}
|
||||
|
||||
TEST_F(TestActivationOpenCL, SwishFp_dim4) {
|
||||
size_t input_size;
|
||||
std::string in_file = "/data/local/tmp/test_data/in_swishfp16.bin";
|
||||
std::string out_file = "/data/local/tmp/test_data/out_swishfp16.bin";
|
||||
auto input_data = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(in_file.c_str(), &input_size));
|
||||
MS_LOG(INFO) << "Swish Begin test!";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper();
|
||||
auto runtime = ocl_runtime.GetInstance();
|
||||
runtime->Init();
|
||||
auto data_type = kNumberTypeFloat16;
|
||||
runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
|
||||
bool enable_fp16 = runtime->GetFp16Enable();
|
||||
|
||||
MS_LOG(INFO) << "Init tensors.";
|
||||
std::vector<int> input_shape = {1, 2, 3, 9};
|
||||
schema::Format format = schema::Format_NHWC;
|
||||
auto tensor_type = lite::Tensor::CONST_TENSOR;
|
||||
auto input_tensor = Tensor(data_type, input_shape, format, tensor_type);
|
||||
auto output_tensor = Tensor(data_type, input_shape, format, tensor_type);
|
||||
|
||||
std::vector<lite::Tensor *> inputs{&input_tensor};
|
||||
std::vector<lite::Tensor *> outputs{&output_tensor};
|
||||
auto allocator = runtime->GetAllocator();
|
||||
inputs[0]->MallocData(allocator);
|
||||
MS_LOG(INFO) << "Initialize input data";
|
||||
memcpy(inputs[0]->data_c(), input_data, input_size);
|
||||
if (enable_fp16) {
|
||||
printf_tensor<float16_t>("Swish:FP16--input data--", inputs[0]);
|
||||
} else {
|
||||
printf_tensor<float>("Swish:FP32--input data--", inputs[0]);
|
||||
}
|
||||
|
||||
auto param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "New ActivationParameter fail.";
|
||||
return;
|
||||
}
|
||||
param->type_ = ActivationType_SWISH;
|
||||
auto *kernel =
|
||||
new (std::nothrow) kernel::ActivationOpenClKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "Kernel:Swish create fail.";
|
||||
delete param;
|
||||
return;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
delete param;
|
||||
delete kernel;
|
||||
MS_LOG(ERROR) << "Init Swish fail.";
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Create kernel SubGraphOpenCLKernel.";
|
||||
std::vector<kernel::LiteKernel *> kernels{kernel};
|
||||
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
|
||||
if (sub_graph == nullptr) {
|
||||
delete kernel;
|
||||
delete param;
|
||||
MS_LOG(ERROR) << "Kernel SubGraphOpenCLKernel create fail.";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Initialize sub_graph.";
|
||||
ret = sub_graph->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init sub_graph error.";
|
||||
delete sub_graph;
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Run SubGraphOpenCLKernel.";
|
||||
ret = sub_graph->Run();
|
||||
if (ret != RET_OK) {
|
||||
delete param;
|
||||
delete sub_graph;
|
||||
MS_LOG(ERROR) << "Run SubGraphOpenCLKernel error.";
|
||||
return;
|
||||
}
|
||||
CompareRes<float16_t>(&output_tensor, out_file);
|
||||
delete sub_graph;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue