add new op named power 2020.10.29 morning

This commit is contained in:
Pengyongrong 2020-10-28 18:38:59 -07:00
parent e8b27b323f
commit ef2d59d567
4 changed files with 461 additions and 0 deletions
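
This commit adds an OpenCL Power kernel to the GPU backend. Element-wise, it computes out = pow(scale * x + shift, exponent), where the exponent either comes from a second input tensor of the same shape or, in the broadcast case, from a single scalar attribute. Integer exponents take an exponentiation-by-squaring fast path; non-integer exponents fall back to pow().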

View File: mindspore/lite/src/runtime/kernel/opencl/cl/power.cl

@@ -0,0 +1,77 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
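// Shared bounds check: X indexes the fused N*H axis, Y indexes W, Z indexes the C/4 slice axis;
// the output_shape.y == 0 test also guards the X / output_shape.y division in the kernels.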
#define CHECK_IDX \
int X = get_global_id(0); \
int Y = get_global_id(1); \
int Z = get_global_id(2); \
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w || output_shape.y == 0) { \
return; \
}
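// Integer power via exponentiation by squaring: O(log n) multiplies, e.g. x^5 = (x^2)^2 * x.
// A negative exponent returns the reciprocal of the positive power.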
FLT OptimizedPowerImpl(FLT x, int exponent) {
int exp = abs(exponent);
FLT result = 1.0f;
FLT iterator = x;
while (exp) {
if (exp % 2) {
result *= iterator;
}
iterator *= iterator;
exp = exp / 2;
}
  return exponent >= 0 ? result : 1.0f / result;
}
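// Element-wise variant: each element of input0 is raised to the matching element of input1.
// parameter packs (power, shift, scale, unused); inputs are first transformed as scale * x + shift.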
__kernel void power(__read_only image2d_t input0, __read_only image2d_t input1, __write_only image2d_t output,
int4 output_shape, FLT4 parameter) {
CHECK_IDX;
int n = X / output_shape.y;
int h = X % output_shape.y;
FLT4 result;
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)));
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)));
FLT tmp_result[4];
FLT tmp_result0[4] = {result0.x, result0.y, result0.z, result0.w};
FLT tmp_result1[4] = {result1.x, result1.y, result1.z, result1.w};
for (int i = 0; i < 4; ++i) {
tmp_result0[i] = tmp_result0[i] * parameter.z + parameter.y;
if (floor(tmp_result1[i]) == tmp_result1[i]) {
int exponent = tmp_result1[i];
tmp_result[i] = OptimizedPowerImpl(tmp_result0[i], exponent);
} else {
tmp_result[i] = pow(tmp_result0[i], tmp_result1[i]);
}
}
result.x = tmp_result[0];
result.y = tmp_result[1];
result.z = tmp_result[2];
result.w = tmp_result[3];
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)), result);
}
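// Broadcast variant: one scalar exponent, parameter.x, applies to every element of the single input.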
__kernel void power_broadcast(__read_only image2d_t input, __write_only image2d_t output, int4 output_shape,
FLT4 parameter) {
CHECK_IDX;
int n = X / output_shape.y;
int h = X % output_shape.y;
FLT4 result;
FLT4 result0 = READ_IMAGE(input, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)));
FLT tmp_result0[4] = {result0.x, result0.y, result0.z, result0.w};
FLT tmp_result[4];
  bool is_int_exponent = floor(parameter.x) == parameter.x;
  int exponent = (int)parameter.x;
  for (int i = 0; i < 4; ++i) {
    tmp_result0[i] = tmp_result0[i] * parameter.z + parameter.y;
    if (is_int_exponent) {
      tmp_result[i] = OptimizedPowerImpl(tmp_result0[i], exponent);
    } else {
      tmp_result[i] = pow(tmp_result0[i], parameter.x);
    }
  }
result.x = tmp_result[0];
result.y = tmp_result[1];
result.z = tmp_result[2];
result.w = tmp_result[3];
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)), result);
}
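
For reference, a minimal host-side C++ sketch of the per-element computation, handy for generating expected values in unit tests. The helper name PowerRef is illustrative only and not part of the commit:

#include <cmath>
#include <cstdlib>

// Reference for one element: out = pow(scale * x + shift, exponent), using the same
// square-and-multiply fast path as the kernel when the exponent is integer-valued.
float PowerRef(float x, float exponent, float scale, float shift) {
  float base = scale * x + shift;
  if (std::floor(exponent) == exponent) {
    int e = static_cast<int>(exponent);
    int n = std::abs(e);
    float result = 1.0f;
    float it = base;
    while (n) {
      if (n % 2) result *= it;
      it *= it;
      n /= 2;
    }
    return e >= 0 ? result : 1.0f / result;
  }
  return std::pow(base, exponent);
}

For example, PowerRef(2.0f, 2.0f, 1.0f, 0.0f) returns 4.0f, matching the first expected value of the Int32CI test below.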

View File: mindspore/lite/src/runtime/kernel/opencl/kernel/power.cc

@@ -0,0 +1,160 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/opencl/kernel/power.h"
#include <cstring>
#include <string>
#include <algorithm>
#include <set>
#include "src/kernel_registry.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "src/runtime/kernel/opencl/cl/power.cl.inc"
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Power;
namespace mindspore::kernel {
int PowerOpenCLKernel::Init() {
use_fp16_enable_ = ocl_runtime_->GetFp16Enable();
auto param = reinterpret_cast<PowerParameter *>(this->op_parameter_);
std::string kernel_name = "power";
std::set<std::string> build_options;
std::string source = power_source;
std::string program_name = "power";
broadcast_ = param->broadcast_;
  if (in_tensors_.size() == 2 && in_tensors_[0]->shape().size() != in_tensors_[1]->shape().size()) {
    MS_LOG(ERROR) << "Unsupported: input0 rank " << in_tensors_[0]->shape().size() << " != input1 rank "
                  << in_tensors_[1]->shape().size();
    return RET_ERROR;
  } else if (in_tensors_.size() > 2 || in_tensors_[0]->shape().size() > 4) {
    MS_LOG(ERROR) << "Unsupported: in_tensors_.size() " << in_tensors_.size() << " or input0 rank "
                  << in_tensors_[0]->shape().size();
    return RET_ERROR;
} else if (broadcast_ && in_tensors_.size() == 1) {
power_ = param->power_;
kernel_name += "_broadcast";
}
scale_ = param->scale_;
shift_ = param->shift_;
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return RET_OK;
}
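// Heuristic local work-group size: at most 2 work-items in x and 8 in y; z takes the remaining
// budget, all within the device's maximum work-group size (max_size).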
void PowerGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) {
const int max_divider = 8;
const int max_x = 2, max_y = 8;
int x = std::min(GetMaxDivisorStrategy1(global[0], max_divider), max_x);
int yz = max_size / x;
int y = std::min(std::min(GetMaxDivisorStrategy1(global[1], max_divider), yz), max_y);
int z = std::min(yz / y, static_cast<int>(UP_DIV(global[2], 2)));
local->clear();
local->push_back(x);
local->push_back(y);
local->push_back(z);
}
int PowerOpenCLKernel::InferShapeTo4D() {
if (in_tensors_[0]->shape().size() <= 4) {
if (in_tensors_[0]->shape().size() == 1) {
N_ = in_tensors_[0]->shape()[0];
} else if (in_tensors_[0]->shape().size() == 2) {
N_ = in_tensors_[0]->shape()[0];
C_ = in_tensors_[0]->shape()[1];
} else if (in_tensors_[0]->shape().size() == 3) {
N_ = in_tensors_[0]->shape()[0];
W_ = in_tensors_[0]->shape()[1];
C_ = in_tensors_[0]->shape()[2];
} else {
N_ = in_tensors_[0]->shape()[0];
H_ = in_tensors_[0]->shape()[1];
W_ = in_tensors_[0]->shape()[2];
C_ = in_tensors_[0]->shape()[3];
}
} else {
MS_LOG(ERROR) << "Unsupported inputdim: " << in_tensors_[0]->shape().size();
return RET_ERROR;
}
return RET_OK;
}
int PowerOpenCLKernel::Run() {
  MS_LOG(DEBUG) << this->name() << " Running!";
  if (InferShapeTo4D() != RET_OK) {
    return RET_ERROR;
  }
  cl_int4 output_shape_ = {static_cast<cl_int>(N_), static_cast<cl_int>(H_), static_cast<cl_int>(W_),
                           static_cast<cl_int>(UP_DIV(C_, C4NUM))};
const std::vector<size_t> &max_global = ocl_runtime_->GetWorkItemSize();
std::vector<size_t> local = {1, 1, 1};
uint32_t OH = N_ * H_;
uint32_t OW = W_;
uint32_t OC = UP_DIV(C_, C4NUM);
std::vector<size_t> global = {OH, OW, OC};
PowerGetWorkGroup(global, &local, max_global[0]);
int arg_cn = 0;
  ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c());
  if (!broadcast_) {
    ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->data_c());
  }
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, output_shape_);
if (use_fp16_enable_) {
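    // cl_half4 stores raw half bit patterns, so reinterpret the float16_t bits as uint16_t.
    // The .w lane is never read by the kernel.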
auto x = static_cast<float16_t>(power_);
auto y = static_cast<float16_t>(shift_);
auto z = static_cast<float16_t>(scale_);
cl_half4 parameter = {*(reinterpret_cast<uint16_t *>(&x)), *(reinterpret_cast<uint16_t *>(&y)),
*(reinterpret_cast<uint16_t *>(&z)), 1};
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, parameter);
} else {
cl_float4 parameter = {power_, shift_, scale_, 1};
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, parameter);
}
ocl_runtime_->RunKernel(kernel_, global, local, nullptr);
return RET_OK;
}
kernel::LiteKernel *PowerOpenCLKernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
auto *kernel = new (std::nothrow) PowerOpenCLKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << " new PowerOpenCLKernel failed ";
free(opParameter);
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << " Init kernel failed, name: Power ";
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Power, PowerOpenCLKernelCreator)
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Power, PowerOpenCLKernelCreator)
} // namespace mindspore::kernel

View File: mindspore/lite/src/runtime/kernel/opencl/kernel/power.h

@@ -0,0 +1,55 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_POWER_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_POWER_H_
#include <vector>
#include "nnacl/power.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
namespace mindspore::kernel {
class PowerOpenCLKernel : public OpenCLKernel {
public:
PowerOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs)
: OpenCLKernel(parameter, inputs, outputs) {}
~PowerOpenCLKernel() override = default;
int Init() override;
int Run() override;
private:
int InferShapeTo4D();
cl::Kernel kernel_;
private:
size_t N_{1};
size_t H_{1};
size_t W_{1};
size_t C_{1};
bool broadcast_{false};
bool use_fp16_enable_{false};
  float power_{1.0f};
  float scale_{1.0f};
  float shift_{0.0f};
};
} // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_POWER_H_

View File

@@ -0,0 +1,169 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/power.h"
using mindspore::lite::Tensor;
using mindspore::schema::Format::Format_NHWC;
namespace mindspore {
class TestPowerOpenCLCI : public mindspore::CommonTest {
public:
TestPowerOpenCLCI() {}
};
template <class T>
void CompareData(const T *output_data, const T *correct_data, int size, float err_bound) {
  for (int i = 0; i < size; i++) {
    T diff = fabs(output_data[i] - correct_data[i]);
    ASSERT_LE(diff, err_bound);
  }
}
template <class T>
void TEST_MAIN(const T *input_data1, const T *input_data2, const T *expect_data, const TypeId data_type,
const std::vector<int> &shape_a, const std::vector<int> &shape_b, const std::vector<int> &out_shape,
bool broadcast, const T scale = 1.0, const T shift = 0, const T exponent = 2) {
MS_LOG(INFO) << " begin test ";
auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
auto runtime = runtime_wrapper.GetInstance();
runtime->Init();
if (data_type == kNumberTypeFloat16) {
runtime->SetFp16Enable(true);
}
auto allocator = runtime->GetAllocator();
auto tensor_type = lite::Tensor::CONST_TENSOR;
auto in_tensor1 = Tensor(data_type, shape_a, Format_NHWC, tensor_type);
auto in_tensor2 = Tensor(data_type, shape_b, Format_NHWC, tensor_type);
auto output_tensor = Tensor(data_type, out_shape, Format_NHWC, tensor_type);
MS_LOG(INFO) << " initialize tensors ";
  auto param = reinterpret_cast<PowerParameter *>(malloc(sizeof(PowerParameter)));
  if (param == nullptr) {
    MS_LOG(INFO) << " malloc PowerParameter failed ";
    return;
  }
  param->scale_ = scale;
  param->shift_ = shift;
  param->broadcast_ = broadcast;
  std::vector<lite::Tensor *> inputs;
  std::vector<lite::Tensor *> outputs{&output_tensor};
  if (broadcast) {
    inputs.push_back(&in_tensor1);
    param->power_ = exponent;
  } else {
    inputs.push_back(&in_tensor1);
    inputs.push_back(&in_tensor2);
  }
auto *power_kernel =
new (std::nothrow) kernel::PowerOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (power_kernel == nullptr) {
MS_LOG(INFO) << " new kernel::PowerOpenCLKernel failed ";
delete param;
return;
}
power_kernel->Init();
  // allocate device memory for the input tensors
for (auto &input_tensor : inputs) {
input_tensor->MallocData(allocator);
}
MS_LOG(INFO) << " initialize sub_graph ";
std::vector<kernel::LiteKernel *> kernels{power_kernel};
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
if (sub_graph == nullptr) {
MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed ";
delete param;
delete power_kernel;
return;
}
sub_graph->Init();
MS_LOG(INFO) << " initialize input data ";
  size_t size = sizeof(T);
  for (size_t i = 0; i < out_shape.size(); ++i) {
    size *= out_shape[i];
  }
if (broadcast) {
memcpy(inputs[0]->data_c(), input_data1, size);
} else {
memcpy(inputs[0]->data_c(), input_data1, size);
memcpy(inputs[1]->data_c(), input_data2, size);
}
std::cout << "==================output data================" << std::endl;
sub_graph->Run();
T *output_data_gpu = reinterpret_cast<T *>(output_tensor.data_c());
CompareData(output_data_gpu, expect_data, output_tensor.ElementsNum(), 0.0001);
delete sub_graph;
}
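// All exponents below are integer-valued (hence "Int32" in the test name), exercising the
// square-and-multiply fast path; the tensor data itself is fp32.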
TEST_F(TestPowerOpenCLCI, Int32CI) {
MS_LOG(INFO) << " init tensors ";
std::vector<int> shape_a = {1, 2, 8};
std::vector<int> shape_b = {1, 2, 8};
std::vector<int> output_shape = {1, 2, 8};
auto data_type = kNumberTypeFloat32;
const float input_data1[] = {2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0};
const float input_data2[] = {2, 2, 2, 1, 2, 2, 3, 3, 2, 2, 3, 0, 2, 2, 1, 2};
const float expect_data[] = {4.0, 9.0, 16.0, 5.0, 36.0, 49.0, 512, 729,
100.0, 121.0, 1728.0, 1.0, 196.0, 225.0, 16.0, 289.0};
TEST_MAIN(input_data1, input_data2, expect_data, data_type, shape_a, shape_b, output_shape, false);
}
TEST_F(TestPowerOpenCLCI, Fp32CI) {
MS_LOG(INFO) << " init tensors ";
std::vector<int> shape_a = {2, 8};
std::vector<int> shape_b = {2, 8};
std::vector<int> output_shape = {2, 8};
auto data_type = kNumberTypeFloat32;
const float input_data1[] = {0.78957046, -0.99770847, 1.05838929, 1.60738329, -1.66226552, -2.03170525,
-0.48257631, -0.94244638, 1.47462044, -0.80247114, 0.12354778, -0.36436107,
-2.41973013, -0.40221205, -0.26739485, 0.23298305};
const float input_data2[] = {3, 2, 2, 1, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2};
const float expect_data[] = {0.49223521, 0.99542219, 1.12018788, 1.60738329, 2.76312667, 4.1278262,
0.23287989, 0.88820518, 3.20657016, 0.64395994, 0.01526405, 0.13275899,
5.85509388, 0.16177453, 0.07150001, 0.0542811};
TEST_MAIN(input_data1, input_data2, expect_data, data_type, shape_a, shape_b, output_shape, false);
}
TEST_F(TestPowerOpenCLCI, Fp16CI) {
MS_LOG(INFO) << " init tensors ";
std::vector<int> shape_a = {2, 8};
std::vector<int> shape_b = {2, 8};
std::vector<int> output_shape = {2, 8};
auto data_type = kNumberTypeFloat16;
const float16_t input_data1[] = {0.1531, -0.8003, -0.1848, 0.3833, -1.469, 0.5586, -0.3223, -0.8887,
0.697, -1.007, -0.45, -1.736, -0.462, -0.699, -0.596, 0.7466};
const float16_t input_data2[] = {2.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.0, 2.0, 2.0, 1.0, 1.0};
const float16_t expect_data[] = {0.02344, -0.8003, -0.1848, 0.147, 2.156, 0.312, 0.1039, 0.7896,
0.4856, 1.014, 0.2025, -1.736, 0.2134, 0.489, -0.596, 0.7466};
TEST_MAIN(input_data1, input_data2, expect_data, data_type, shape_a, shape_b, output_shape, false);
}
TEST_F(TestPowerOpenCLCI, broadcast) {
MS_LOG(INFO) << " init tensors ";
std::vector<int> shape_a = {1, 2, 8};
std::vector<int> shape_b = {};
std::vector<int> output_shape = {1, 2, 8};
auto data_type = kNumberTypeFloat32;
float input_data1[] = {2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0};
float expect_data[] = {4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64, 81, 100.0, 121.0, 144, 169, 196.0, 225.0, 256, 289.0};
TEST_MAIN(input_data1, input_data1, expect_data, data_type, shape_a, shape_b, output_shape, true);
}
} // namespace mindspore