diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/power.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/power.cl
new file mode 100644
index 00000000000..a4a59efdf91
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/power.cl
@@ -0,0 +1,87 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
+// Declares the work-item coordinates X, Y, Z and returns early when they fall outside
+// output_shape, or when output_shape.y is 0 (the kernels below divide by it).
+#define CHECK_IDX \
+  int X = get_global_id(0); \
+  int Y = get_global_id(1); \
+  int Z = get_global_id(2); \
+  if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w || output_shape.y == 0) { \
+    return; \
+  }
+
+// Raises x to an integer exponent by repeated squaring; a negative exponent
+// returns the reciprocal of x raised to |exponent|.
+FLT OptimizedPowerImpl(FLT x, int exponent) {
+  int exp = abs(exponent);
+  FLT result = 1.0f;
+  FLT iterator = x;
+  while (exp) {
+    if (exp % 2) {
+      result *= iterator;
+    }
+    iterator *= iterator;
+    exp = exp / 2;
+  }
+  return exponent >= 0 ? result : 1 / result;
+}
+
+// Element-wise power: output = (input0 * parameter.z + parameter.y) ^ input1.
+// parameter.x is not read by this kernel; every exponent comes from input1.
+__kernel void power(__read_only image2d_t input0, __read_only image2d_t input1, __write_only image2d_t output,
+                    int4 output_shape, FLT4 parameter) {
+  CHECK_IDX;
+  int n = X / output_shape.y;
+  int h = X % output_shape.y;
+  FLT4 result;
+  FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)));
+  FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)));
+  FLT tmp_result[4];
+  FLT tmp_result0[4] = {result0.x, result0.y, result0.z, result0.w};
+  FLT tmp_result1[4] = {result1.x, result1.y, result1.z, result1.w};
+  for (int i = 0; i < 4; ++i) {
+    tmp_result0[i] = tmp_result0[i] * parameter.z + parameter.y;
+    // Integral exponents use repeated squaring; anything else goes through pow().
+    if (floor(tmp_result1[i]) == tmp_result1[i]) {
+      int exponent = tmp_result1[i];
+      tmp_result[i] = OptimizedPowerImpl(tmp_result0[i], exponent);
+    } else {
+      tmp_result[i] = pow(tmp_result0[i], tmp_result1[i]);
+    }
+  }
+  result.x = tmp_result[0];
+  result.y = tmp_result[1];
+  result.z = tmp_result[2];
+  result.w = tmp_result[3];
+  WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)), result);
+}
+
+// Scalar-exponent power: output = (input * parameter.z + parameter.y) ^ parameter.x.
+__kernel void power_broadcast(__read_only image2d_t input, __write_only image2d_t output, int4 output_shape,
+                              FLT4 parameter) {
+  CHECK_IDX;
+  int n = X / output_shape.y;
+  int h = X % output_shape.y;
+  FLT4 result;
+  FLT4 result0 = READ_IMAGE(input, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)));
+  FLT tmp_result0[4] = {result0.x, result0.y, result0.z, result0.w};
+  FLT tmp_result[4];
+
+  // Repeated squaring truncates the exponent to int, so it is only valid when parameter.x
+  // is integral; a fractional exponent must go through pow().
+  bool is_integer_exponent = floor(parameter.x) == parameter.x;
+  for (int i = 0; i < 4; ++i) {
+    tmp_result0[i] = tmp_result0[i] * parameter.z + parameter.y;
+    if (is_integer_exponent) {
+      int exponent = parameter.x;
+      tmp_result[i] = OptimizedPowerImpl(tmp_result0[i], exponent);
+    } else {
+      tmp_result[i] = pow(tmp_result0[i], parameter.x);
+    }
+  }
+  result.x = tmp_result[0];
+  result.y = tmp_result[1];
+  result.z = tmp_result[2];
+  result.w = tmp_result[3];
+  WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)), result);
+}
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/power.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/power.cc
new file mode 100644
index 00000000000..7e0fe62409d
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/power.cc
@@ -0,0 +1,188 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/opencl/kernel/power.h"
+#include <algorithm>
+#include <set>
+#include <string>
+#include <vector>
+#include "src/kernel_registry.h"
+#include "src/runtime/kernel/opencl/utils.h"
+#include "src/runtime/kernel/opencl/cl/power.cl.inc"
+
+using mindspore::kernel::KERNEL_ARCH::kGPU;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_Power;
+
+namespace mindspore::kernel {
+
+// Validates the inputs, picks the "power" or "power_broadcast" OpenCL kernel and builds it.
+// Returns RET_ERROR when the input configuration is not supported.
+int PowerOpenCLKernel::Init() {
+  use_fp16_enable_ = ocl_runtime_->GetFp16Enable();
+  auto param = reinterpret_cast<PowerParameter *>(this->op_parameter_);
+  std::string kernel_name = "power";
+  std::set<std::string> build_options;
+  std::string source = power_source;
+  std::string program_name = "power";
+  if (param == nullptr || in_tensors_.empty()) {
+    MS_LOG(ERROR) << "Power needs a parameter and at least one input tensor";
+    return RET_ERROR;
+  }
+  broadcast_ = param->broadcast_;
+
+  if (in_tensors_.size() == 2 && in_tensors_[0]->shape().size() != in_tensors_[1]->shape().size()) {
+    MS_LOG(ERROR) << "Unsupported input0->shape.size " << in_tensors_[0]->shape().size()
+                  << "!=" << in_tensors_[1]->shape().size();
+    return RET_ERROR;
+  } else if (in_tensors_.size() > 2 || in_tensors_[0]->shape().size() > 4) {
+    MS_LOG(ERROR) << "Unsupported in_tensors_->shape.size " << in_tensors_.size() << " or "
+                  << "in_tensors_[0]->shape().size(): " << in_tensors_[0]->shape().size();
+    return RET_ERROR;
+  } else if (in_tensors_.size() == 1) {
+    // Run() binds a second input whenever broadcast_ is false, so a single input is
+    // only usable with the scalar-exponent kernel.
+    if (!broadcast_) {
+      MS_LOG(ERROR) << "Power with a single input requires the broadcast flag";
+      return RET_ERROR;
+    }
+    power_ = param->power_;
+    kernel_name += "_broadcast";
+  } else {
+    // Two inputs always build the element-wise kernel; keep Run()'s argument binding in step.
+    broadcast_ = false;
+  }
+  scale_ = param->scale_;
+  shift_ = param->shift_;
+  ocl_runtime_->LoadSource(program_name, source);
+  ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
+  MS_LOG(DEBUG) << kernel_name << " Init Done!";
+  return RET_OK;
+}
+
+// Picks a local work-group size: x is capped at 2, y at 8, z by what is left of max_size.
+// Each dimension is clamped to at least 1 so a zero never reaches a division or the launch.
+void PowerGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) {
+  const int max_divider = 8;
+  const int max_x = 2, max_y = 8;
+  int x = std::max(std::min(GetMaxDivisorStrategy1(global[0], max_divider), max_x), 1);
+  int yz = max_size / x;
+  int y = std::max(std::min(std::min(GetMaxDivisorStrategy1(global[1], max_divider), yz), max_y), 1);
+  int z = std::max(std::min(yz / y, static_cast<int>(UP_DIV(global[2], 2))), 1);
+
+  local->clear();
+  local->push_back(x);
+  local->push_back(y);
+  local->push_back(z);
+}
+
+// Maps the first input's shape (rank 0 to 4) onto N_/H_/W_/C_; dimensions a rank does not
+// provide are set to 1. Returns RET_ERROR for ranks above 4.
+int PowerOpenCLKernel::InferShapeTo4D() {
+  const auto shape = in_tensors_[0]->shape();
+  const size_t rank = shape.size();
+  if (rank > 4) {
+    MS_LOG(ERROR) << "Unsupported inputdim: " << rank;
+    return RET_ERROR;
+  }
+  // Reset first so a value written by an earlier call cannot survive a lower-rank shape.
+  N_ = H_ = W_ = C_ = 1;
+  if (rank == 1) {
+    N_ = shape[0];
+  } else if (rank == 2) {
+    N_ = shape[0];
+    C_ = shape[1];
+  } else if (rank == 3) {
+    N_ = shape[0];
+    W_ = shape[1];
+    C_ = shape[2];
+  } else if (rank == 4) {
+    N_ = shape[0];
+    H_ = shape[1];
+    W_ = shape[2];
+    C_ = shape[3];
+  }
+  return RET_OK;
+}
+
+// Binds the kernel arguments (inputs, output, packed output shape, power/shift/scale) and
+// enqueues the kernel selected in Init().
+int PowerOpenCLKernel::Run() {
+  MS_LOG(DEBUG) << this->name() << " Running! ";
+  if (InferShapeTo4D() != RET_OK) {
+    return RET_ERROR;
+  }
+  cl_int4 output_shape_ = {static_cast<cl_int>(N_), static_cast<cl_int>(H_), static_cast<cl_int>(W_),
+                           static_cast<cl_int>(UP_DIV(C_, C4NUM))};
+  const std::vector<size_t> &max_global = ocl_runtime_->GetWorkItemSize();
+  std::vector<size_t> local = {1, 1, 1};
+  uint32_t OH = N_ * H_;
+  uint32_t OW = W_;
+  uint32_t OC = UP_DIV(C_, C4NUM);
+  std::vector<size_t> global = {OH, OW, OC};
+  PowerGetWorkGroup(global, &local, max_global[0]);
+  int arg_cn = 0;
+  if (broadcast_) {
+    ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c());
+  } else {
+    ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c());
+    ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->data_c());
+  }
+  ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c());
+  ocl_runtime_->SetKernelArg(kernel_, arg_cn++, output_shape_);
+
+  // The kernels read parameter as {x: exponent, y: shift, z: scale}; w is not read.
+  if (use_fp16_enable_) {
+    auto x = static_cast<float16_t>(power_);
+    auto y = static_cast<float16_t>(shift_);
+    auto z = static_cast<float16_t>(scale_);
+    cl_half4 parameter = {*(reinterpret_cast<cl_half *>(&x)), *(reinterpret_cast<cl_half *>(&y)),
+                          *(reinterpret_cast<cl_half *>(&z)), 1};
+    ocl_runtime_->SetKernelArg(kernel_, arg_cn++, parameter);
+  } else {
+    cl_float4 parameter = {power_, shift_, scale_, 1};
+    ocl_runtime_->SetKernelArg(kernel_, arg_cn++, parameter);
+  }
+
+  ocl_runtime_->RunKernel(kernel_, global, local, nullptr);
+  return RET_OK;
+}
+
+// Registry factory: allocates a PowerOpenCLKernel and initialises it.
+// Returns nullptr when allocation or Init() fails.
+kernel::LiteKernel *PowerOpenCLKernelCreator(const std::vector<lite::Tensor *> &inputs,
+                                             const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
+                                             const lite::InnerContext *ctx, const kernel::KernelKey &desc,
+                                             const mindspore::lite::PrimitiveC *primitive) {
+  auto *kernel = new (std::nothrow) PowerOpenCLKernel(opParameter, inputs, outputs);
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << " new PowerOpenCLKernel failed ";
+    free(opParameter);
+    return nullptr;
+  }
+  auto ret = kernel->Init();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << " Init kernel failed, name: Power ";
+    delete kernel;
+    return nullptr;
+  }
+  return kernel;
+}
+REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Power, PowerOpenCLKernelCreator)
+REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Power, PowerOpenCLKernelCreator)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/power.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/power.h
new file mode 100644
index 00000000000..7927b924c54
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/power.h
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_POWER_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_POWER_H_
+
+#include <vector>
+#include "nnacl/power.h"
+#include "src/runtime/kernel/opencl/opencl_kernel.h"
+
+namespace mindspore::kernel {
+
+// OpenCL kernel wrapper for the Power operator: output = (input * scale + shift) ^ exponent.
+class PowerOpenCLKernel : public OpenCLKernel {
+ public:
+  PowerOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                    const std::vector<lite::Tensor *> &outputs)
+      : OpenCLKernel(parameter, inputs, outputs) {}
+
+  ~PowerOpenCLKernel() override = default;
+
+  int Init() override;
+
+  int Run() override;
+
+ private:
+  int InferShapeTo4D();
+  cl::Kernel kernel_;
+
+ private:
+  size_t N_{1};
+  size_t H_{1};
+  size_t W_{1};
+  size_t C_{1};
+  bool broadcast_{false};
+  bool use_fp16_enable_{false};
+  // Defaults describe the identity transform x^1 with scale 1 and shift 0.
+  float power_{1.0};
+  float scale_{1.0};
+  float shift_{0.0};
+};
+
+}  // namespace mindspore::kernel
+#endif
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/power_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/power_tests.cc
new file mode 100644
index 00000000000..2308c0ebc84
--- /dev/null
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/power_tests.cc
@@ -0,0 +1,179 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iostream>
+#include <memory>
+#include "src/common/log_adapter.h"
+#include "common/common_test.h"
+#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/kernel/power.h"
+
+using mindspore::lite::Tensor;
+using mindspore::schema::Format::Format_NHWC;
+namespace mindspore {
+class TestPowerOpenCLCI : public mindspore::CommonTest {
+ public:
+  TestPowerOpenCLCI() {}
+};
+// Asserts that every element of output_data is within err_bound of correct_data.
+template <typename T>
+void CompareData(const T *output_data, const T *correct_data, int size, float err_bound) {
+  for (int i = 0; i < size; i++) {
+    T abs = fabs(output_data[i] - correct_data[i]);
+    ASSERT_LE(abs, err_bound);
+  }
+}
+// Runs the Power OpenCL kernel on the given inputs and compares the result with expect_data.
+// With broadcast set only input_data1 is used and exponent is the scalar power.
+template <typename T>
+void TEST_MAIN(const T *input_data1, const T *input_data2, const T *expect_data, const TypeId data_type,
+               const std::vector<int> &shape_a, const std::vector<int> &shape_b, const std::vector<int> &out_shape,
+               bool broadcast, const T scale = 1.0, const T shift = 0, const T exponent = 2) {
+  MS_LOG(INFO) << " begin test ";
+  auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
+  auto runtime = runtime_wrapper.GetInstance();
+  runtime->Init();
+  if (data_type == kNumberTypeFloat16) {
+    runtime->SetFp16Enable(true);
+  }
+  auto allocator = runtime->GetAllocator();
+  auto tensor_type = lite::Tensor::CONST_TENSOR;
+
+  auto in_tensor1 = Tensor(data_type, shape_a, Format_NHWC, tensor_type);
+  auto in_tensor2 = Tensor(data_type, shape_b, Format_NHWC, tensor_type);
+  auto output_tensor = Tensor(data_type, out_shape, Format_NHWC, tensor_type);
+
+  MS_LOG(INFO) << " initialize tensors ";
+  // calloc zero-fills the parameter so no field is read uninitialised.
+  auto param = reinterpret_cast<PowerParameter *>(calloc(1, sizeof(PowerParameter)));
+  if (param == nullptr) {
+    MS_LOG(INFO) << " new PowerParameter failed ";
+    return;
+  }
+  param->scale_ = scale;
+  param->shift_ = shift;
+  param->broadcast_ = broadcast;
+  std::vector<lite::Tensor *> inputs;
+  std::vector<lite::Tensor *> outputs{&output_tensor};
+  if (broadcast) {
+    inputs.push_back(&in_tensor1);
+    param->power_ = exponent;
+  } else {
+    inputs.push_back(&in_tensor1);
+    inputs.push_back(&in_tensor2);
+  }
+  auto *power_kernel =
+    new (std::nothrow) kernel::PowerOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
+  if (power_kernel == nullptr) {
+    MS_LOG(INFO) << " new kernel::PowerOpenCLKernel failed ";
+    // param came from calloc, so it must be released with free(), not delete.
+    free(param);
+    return;
+  }
+  power_kernel->Init();
+  // Allocate memory for the input tensors.
+  for (auto &input_tensor : inputs) {
+    input_tensor->MallocData(allocator);
+  }
+
+  MS_LOG(INFO) << " initialize sub_graph ";
+  std::vector<kernel::LiteKernel *> kernels{power_kernel};
+  auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
+  if (sub_graph == nullptr) {
+    MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed ";
+    // NOTE(review): param is presumably owned by power_kernel from here on (the kernel creator
+    // never frees it after a successful construction) - confirm against LiteKernel's destructor.
+    delete power_kernel;
+    return;
+  }
+  sub_graph->Init();
+  MS_LOG(INFO) << " initialize input data ";
+  size_t size = 1 * sizeof(T);
+  for (size_t i = 0; i < out_shape.size(); ++i) {
+    size *= out_shape[i];
+  }
+  if (broadcast) {
+    memcpy(inputs[0]->data_c(), input_data1, size);
+  } else {
+    memcpy(inputs[0]->data_c(), input_data1, size);
+    memcpy(inputs[1]->data_c(), input_data2, size);
+  }
+  std::cout << "==================output data================" << std::endl;
+  sub_graph->Run();
+  T *output_data_gpu = reinterpret_cast<T *>(output_tensor.data_c());
+  CompareData(output_data_gpu, expect_data, output_tensor.ElementsNum(), 0.0001);
+  delete sub_graph;
+}
+
+// Despite the test name the tensors are float32; the exponents are integer-valued.
+TEST_F(TestPowerOpenCLCI, Int32CI) {
+  MS_LOG(INFO) << " init tensors ";
+  std::vector<int> shape_a = {1, 2, 8};
+  std::vector<int> shape_b = {1, 2, 8};
+  std::vector<int> output_shape = {1, 2, 8};
+  auto data_type = kNumberTypeFloat32;
+  const float input_data1[] = {2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0};
+  const float input_data2[] = {2, 2, 2, 1, 2, 2, 3, 3, 2, 2, 3, 0, 2, 2, 1, 2};
+  const float expect_data[] = {4.0, 9.0, 16.0, 5.0, 36.0, 49.0, 512, 729,
+                               100.0, 121.0, 1728.0, 1.0, 196.0, 225.0, 16.0, 289.0};
+  TEST_MAIN(input_data1, input_data2, expect_data, data_type, shape_a, shape_b, output_shape, false);
+}
+
+// Element-wise power on float32 data with signed fractional bases and small integer exponents.
+TEST_F(TestPowerOpenCLCI, Fp32CI) {
+  MS_LOG(INFO) << " init tensors ";
+  std::vector<int> shape_a = {2, 8};
+  std::vector<int> shape_b = {2, 8};
+  std::vector<int> output_shape = {2, 8};
+  auto data_type = kNumberTypeFloat32;
+  const float input_data1[] = {0.78957046, -0.99770847, 1.05838929, 1.60738329, -1.66226552, -2.03170525,
+                               -0.48257631, -0.94244638, 1.47462044, -0.80247114, 0.12354778, -0.36436107,
+                               -2.41973013, -0.40221205, -0.26739485, 0.23298305};
+  const float input_data2[] = {3, 2, 2, 1, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2};
+  const float expect_data[] = {0.49223521, 0.99542219, 1.12018788, 1.60738329, 2.76312667, 4.1278262,
+                               0.23287989, 0.88820518, 3.20657016, 0.64395994, 0.01526405, 0.13275899,
+                               5.85509388, 0.16177453, 0.07150001, 0.0542811};
+  TEST_MAIN(input_data1, input_data2, expect_data, data_type, shape_a, shape_b, output_shape, false);
+}
+
+// Element-wise power on float16 data.
+TEST_F(TestPowerOpenCLCI, Fp16CI) {
+  MS_LOG(INFO) << " init tensors ";
+  std::vector<int> shape_a = {2, 8};
+  std::vector<int> shape_b = {2, 8};
+  std::vector<int> output_shape = {2, 8};
+  auto data_type = kNumberTypeFloat16;
+  const float16_t input_data1[] = {0.1531, -0.8003, -0.1848, 0.3833, -1.469, 0.5586, -0.3223, -0.8887,
+                                   0.697, -1.007, -0.45, -1.736, -0.462, -0.699, -0.596, 0.7466};
+  const float16_t input_data2[] = {2.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.0, 2.0, 2.0, 1.0, 1.0};
+  const float16_t expect_data[] = {0.02344, -0.8003, -0.1848, 0.147, 2.156, 0.312, 0.1039, 0.7896,
+                                   0.4856, 1.014, 0.2025, -1.736, 0.2134, 0.489, -0.596, 0.7466};
+  TEST_MAIN(input_data1, input_data2, expect_data, data_type, shape_a, shape_b, output_shape, false);
+}
+
+// Scalar-exponent (broadcast) power using TEST_MAIN's default exponent of 2.
+TEST_F(TestPowerOpenCLCI, broadcast) {
+  MS_LOG(INFO) << " init tensors ";
+  std::vector<int> shape_a = {1, 2, 8};
+  std::vector<int> shape_b = {};
+  std::vector<int> output_shape = {1, 2, 8};
+  auto data_type = kNumberTypeFloat32;
+  float input_data1[] = {2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0};
+  float expect_data[] = {4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64, 81, 100.0, 121.0, 144, 169, 196.0, 225.0, 256, 289.0};
+  TEST_MAIN(input_data1, input_data1, expect_data, data_type, shape_a, shape_b, output_shape, true);
+}
+
+}  // namespace mindspore