forked from mindspore-Ecosystem/mindspore
add new ops named slice
This commit is contained in:
parent
11e670c54b
commit
fbd2ee53fd
|
@ -0,0 +1,81 @@
|
|||
#define INT2 int2
|
||||
#define INT4 int4
|
||||
#define FLT4 float4
|
||||
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
|
||||
__kernel void slice(__read_only image2d_t input, __write_only image2d_t output, INT4 input_shape, INT4 out_shape,
|
||||
INT4 begin, INT2 sharedNoUpdiv) {
|
||||
int X = get_global_id(1); // H
|
||||
int Y = get_global_id(2); // W
|
||||
if (X >= out_shape.y || Y >= out_shape.z) {
|
||||
return;
|
||||
}
|
||||
FLT4 result;
|
||||
if (sharedNoUpdiv.x % 4 == 0) {
|
||||
for (int i = 0; i < out_shape.w; i++) {
|
||||
result = read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (i + begin.w), (X + begin.y)));
|
||||
write_imagef(output, (INT2)((Y)*out_shape.w + i, (X)), result);
|
||||
}
|
||||
} else {
|
||||
int begin_postion = sharedNoUpdiv.y % 4;
|
||||
FLT4 first = read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + begin.w, (X + begin.y)));
|
||||
if (begin_postion == 1) {
|
||||
for (int i = 1; i <= out_shape.w; i++) {
|
||||
FLT4 second =
|
||||
read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y)));
|
||||
result.x = first.y;
|
||||
result.y = first.z;
|
||||
result.z = first.w;
|
||||
result.w = second.x;
|
||||
write_imagef(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result);
|
||||
first.y = second.y;
|
||||
first.z = second.z;
|
||||
first.w = second.w;
|
||||
}
|
||||
} else if (begin_postion == 2) {
|
||||
for (int i = 1; i <= out_shape.w; i++) {
|
||||
FLT4 second =
|
||||
read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y)));
|
||||
result.x = first.z;
|
||||
result.y = first.w;
|
||||
result.z = second.x;
|
||||
result.w = second.y;
|
||||
write_imagef(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result);
|
||||
first.z = second.z;
|
||||
first.w = second.w;
|
||||
}
|
||||
} else {
|
||||
for (int i = 1; i <= out_shape.w; i++) {
|
||||
FLT4 second =
|
||||
read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y)));
|
||||
result.x = first.w;
|
||||
result.y = second.x;
|
||||
result.z = second.y;
|
||||
result.w = second.z;
|
||||
write_imagef(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result);
|
||||
first.w = second.w;
|
||||
}
|
||||
}
|
||||
}
|
||||
// judge the line of size
|
||||
int size = sharedNoUpdiv.y % 4;
|
||||
FLT4 result_fill0;
|
||||
if (size == 1) {
|
||||
result_fill0.x = result.x;
|
||||
result_fill0.y = 0;
|
||||
result_fill0.z = 0;
|
||||
result_fill0.w = 0;
|
||||
write_imagef(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0);
|
||||
} else if (size == 2) {
|
||||
result_fill0.x = result.x;
|
||||
result_fill0.y = result.y;
|
||||
result_fill0.z = 0;
|
||||
result_fill0.w = 0;
|
||||
write_imagef(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0);
|
||||
} else if (size == 3) {
|
||||
result_fill0.x = result.x;
|
||||
result_fill0.y = result.y;
|
||||
result_fill0.z = result.z;
|
||||
result_fill0.w = 0;
|
||||
write_imagef(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0);
|
||||
}
|
||||
}
|
|
@ -137,7 +137,7 @@ kernel::LiteKernel *OpenCLBatchnormKernelCreator(const std::vector<lite::tensor:
|
|||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (0 != ret) {
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: Convolution";
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
|
|
|
@ -214,7 +214,7 @@ kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<lite::tensor::Te
|
|||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (0 != ret) {
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: Convolution";
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
|
|
|
@ -273,7 +273,7 @@ int ConvolutionOpenCLKernel::Run() {
|
|||
}
|
||||
|
||||
if (use_winograd_) {
|
||||
ocl_runtime->RunKernel(kernel_4x4to36, {size_t(TILES_XY), 6, size_t(CI_SLICES)}, {16, 6, 4}, nullptr);
|
||||
ocl_runtime->RunKernel(kernel_4x4to36, {size_t(TILES_XY), 6, size_t(CI_SLICES)}, {8, 6, 4}, nullptr);
|
||||
ocl_runtime->RunKernel(kernel_conv, {size_t(TILES_XY / 2), 36, size_t(CO_SLICES / 2)}, {8, 6, 2}, nullptr);
|
||||
ocl_runtime->RunKernel(kernel_36to4x4, {size_t(TILES_XY), 4, size_t(CO_SLICES)}, {32, 4, 2}, nullptr);
|
||||
} else {
|
||||
|
@ -674,7 +674,7 @@ kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector<lite::tenso
|
|||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (0 != ret) {
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: Convolution";
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/opencl/opencl_runtime.h"
|
||||
#include "src/runtime/kernel/opencl/kernel/slice.h"
|
||||
#include "src/runtime/kernel/opencl/cl/slice.cl.inc"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kGPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::schema::PrimitiveType_Slice;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int SliceOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
|
||||
size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
||||
size_t im_dst_x, im_dst_y;
|
||||
if (in_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
|
||||
im_dst_x = out_tensors_[0]->Width() * CO4;
|
||||
im_dst_y = out_tensors_[0]->Height();
|
||||
} else {
|
||||
im_dst_y = out_tensors_[0]->Height() * CO4;
|
||||
im_dst_x = out_tensors_[0]->Width();
|
||||
}
|
||||
#ifdef ENABLE_FP16
|
||||
size_t img_dtype = CL_HALF_FLOAT;
|
||||
#else
|
||||
size_t img_dtype = CL_FLOAT;
|
||||
#endif
|
||||
img_size->clear();
|
||||
std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
|
||||
*img_size = vec;
|
||||
return RET_OK;
|
||||
}
|
||||
int SliceOpenCLKernel::Init() {
|
||||
std::set<std::string> build_options;
|
||||
std::string source = slice_source;
|
||||
std::string program_name = "slice";
|
||||
std::string kernel_name = "slice";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->LoadSource(program_name, source);
|
||||
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
|
||||
ori_format_ = out_tensors_[0]->GetFormat();
|
||||
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SliceOpenCLKernel::ReSize() { return RET_OK; }
|
||||
|
||||
int SliceGetBiggestDividerWithPriority(int number, int max_divider) {
|
||||
if (number % 8 == 0 && 8 <= max_divider) {
|
||||
return number / 8;
|
||||
} else if (number % 4 == 0 && 4 <= max_divider) {
|
||||
return number / 4;
|
||||
} else if (number % 2 == 0 && 2 <= max_divider) {
|
||||
return number / 2;
|
||||
}
|
||||
for (int i = max_divider; i != 0; i--) {
|
||||
if (number % i == 0) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
void SlcieGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) {
|
||||
const int max_divider = 8;
|
||||
const int max_x = 4, max_y = 8;
|
||||
int x = std::min(SliceGetBiggestDividerWithPriority(global[0], max_divider), max_x);
|
||||
int yz = max_size / x;
|
||||
int y = std::min(std::min(SliceGetBiggestDividerWithPriority(global[1], max_divider), yz), max_y);
|
||||
int z = std::min(yz / y, static_cast<int>(UP_DIV(global[2], 2)));
|
||||
|
||||
local->clear();
|
||||
local->push_back(x);
|
||||
local->push_back(y);
|
||||
local->push_back(z);
|
||||
}
|
||||
int SliceOpenCLKernel::Run() {
|
||||
MS_LOG(DEBUG) << this->name() << " Running!";
|
||||
auto param = reinterpret_cast<SliceParameter *>(this->op_parameter_);
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
auto input_shape = in_tensors_[0]->shape();
|
||||
cl_int4 input_shape_ = {input_shape[0], input_shape[1], input_shape[2], UP_DIV(input_shape[3], C4NUM)};
|
||||
cl_int4 size_ = {param->size_[0], param->size_[1], param->size_[2], UP_DIV(param->size_[3], C4NUM)};
|
||||
cl_int4 begin_ = {param->begin_[0], param->begin_[1], param->begin_[2], param->begin_[3] / 4};
|
||||
cl_int2 sharedNoUpdiv = {param->begin_[3], param->size_[3]};
|
||||
uint32_t OH = param->size_[1];
|
||||
uint32_t OW = param->size_[2];
|
||||
|
||||
const std::vector<size_t> &max_global = ocl_runtime->GetWorkItemSize();
|
||||
std::vector<size_t> local = {1, 1, 1}; // init local
|
||||
std::vector<size_t> global = {1, OH, OW};
|
||||
SlcieGetWorkGroup(global, &local, max_global[0]);
|
||||
int arg_cn = 0;
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->Data()); // input tensor
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->Data()); // out tensor
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape_);
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_cn++, size_);
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_cn++, begin_);
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_cn++, sharedNoUpdiv);
|
||||
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
|
||||
|
||||
return RET_OK;
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
kernel::LiteKernel *OpenCLSliceKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||
OpParameter *opParameter, const lite::Context *ctx,
|
||||
const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto *kernel = new (std::nothrow) SliceOpenCLKernel(opParameter, inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "new SliceOpenCLKernel failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: Convolution";
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Slice, OpenCLSliceKernelCreator);
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_SLICE_H_
|
||||
#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_SLICE_H_
|
||||
|
||||
#include <vector>
|
||||
#include "ir/anf.h"
|
||||
#include "src/runtime/kernel/opencl/opencl_kernel.h"
|
||||
#include "src/runtime/opencl/opencl_runtime.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/fp32/slice.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
class SliceOpenCLKernel : public OpenCLKernel {
|
||||
public:
|
||||
explicit SliceOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs)
|
||||
: OpenCLKernel(parameter, inputs, outputs) {}
|
||||
|
||||
~SliceOpenCLKernel() override{};
|
||||
|
||||
int Init() override;
|
||||
|
||||
int ReSize() override;
|
||||
|
||||
int Run() override;
|
||||
|
||||
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
|
||||
|
||||
private:
|
||||
cl::Kernel kernel_;
|
||||
};
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
#endif
|
|
@ -46,8 +46,10 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::tensor::Tensor *
|
|||
}
|
||||
if (mem_type == OpenCLMemType::IMG) {
|
||||
jv->set_in_tensors({});
|
||||
jv->SetInKernel({});
|
||||
} else {
|
||||
jv->set_out_tensors({});
|
||||
jv->SetOutKernel({});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -120,13 +122,21 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::tensor::Tensor *
|
|||
if (mem_type == OpenCLMemType::IMG) {
|
||||
for (auto &iv : in_kernels[i]) {
|
||||
in_opencl_op->AddOutKernel(iv);
|
||||
reinterpret_cast<OpenCLKernel *>(iv)->SetInKernel({in_convert_op});
|
||||
reinterpret_cast<OpenCLKernel *>(iv)->set_in_tensors({new_tensor});
|
||||
auto kernels = iv->in_kernels();
|
||||
kernels.emplace_back(in_convert_op);
|
||||
iv->SetInKernel(kernels);
|
||||
auto tensors = iv->in_tensors();
|
||||
tensors.emplace_back(new_tensor);
|
||||
iv->set_in_tensors(tensors);
|
||||
}
|
||||
} else {
|
||||
for (auto &iv : in_kernels[i]) {
|
||||
reinterpret_cast<OpenCLKernel *>(iv)->SetOutKernel({in_convert_op});
|
||||
reinterpret_cast<OpenCLKernel *>(iv)->set_out_tensors({new_tensor});
|
||||
auto kernels = iv->out_kernels();
|
||||
kernels.emplace_back(in_convert_op);
|
||||
iv->SetOutKernel(kernels);
|
||||
auto tensors = iv->out_tensors();
|
||||
tensors.emplace_back(new_tensor);
|
||||
iv->set_out_tensors(tensors);
|
||||
in_convert_op->AddInKernel(iv);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -145,6 +145,7 @@ if (SUPPORT_GPU)
|
|||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/batchnorm.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/slice.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/activation.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc
|
||||
|
@ -318,6 +319,7 @@ if (SUPPORT_GPU)
|
|||
${TEST_DIR}/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/concat_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/batchnorm_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/slice_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/softmax_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/arithmetic_tests.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc
|
||||
|
|
|
@ -43,8 +43,8 @@ TEST_F(TestBatchnormOpenCL, Batchnorminput_dim4) {
|
|||
auto allocator = ocl_runtime->GetAllocator();
|
||||
|
||||
MS_LOG(INFO) << "Read tensors from .bin";
|
||||
std::vector<int> input_shape = {1, 256, 256, 48};
|
||||
std::vector<int> output_shape = {1, 256, 256, 48};
|
||||
std::vector<int> input_shape = {1, 256, 256, 16};
|
||||
std::vector<int> output_shape = {1, 256, 256, 16};
|
||||
auto data_type = kNumberTypeFloat32;
|
||||
auto tensor_type = schema::NodeType_ValueNode;
|
||||
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include "utils/log_adapter.h"
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
|
||||
#include "mindspore/lite/src/common/file_utils.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/slice.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestSliceOpenCL : public mindspore::CommonTest {
|
||||
public:
|
||||
TestSliceOpenCL() {}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bound) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
T abs = fabs(output_data[i] - correct_data[i]);
|
||||
ASSERT_LE(abs, err_bound);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TestSliceOpenCL, Sliceinput_dim4) {
|
||||
MS_LOG(INFO) << "begin test";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->Init();
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
|
||||
MS_LOG(INFO) << "Read tensors from .bin";
|
||||
std::vector<int> input_shape = {1, 256, 256, 48};
|
||||
std::vector<int> output_shape = {1, 255, 255, 15};
|
||||
std::vector<int> begin = {0, 1, 1, 7};
|
||||
std::vector<int> size = {1, 255, 255, 15};
|
||||
auto data_type = kNumberTypeFloat32;
|
||||
auto tensor_type = schema::NodeType_ValueNode;
|
||||
|
||||
// get the input from .bin
|
||||
size_t input_size, output_size;
|
||||
std::string input_path = "./test_data/in_data.bin";
|
||||
std::string output_path = "./test_data/out_data.bin";
|
||||
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
|
||||
auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
|
||||
|
||||
MS_LOG(INFO) << "construct tensors";
|
||||
lite::tensor::Tensor *tensor_data =
|
||||
new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type);
|
||||
if (tensor_data == nullptr) {
|
||||
MS_LOG(INFO) << "init tensor failed";
|
||||
return;
|
||||
}
|
||||
auto *output_tensor =
|
||||
new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensor_type);
|
||||
if (output_tensor == nullptr) {
|
||||
delete tensor_data;
|
||||
MS_LOG(INFO) << "init tensor failed";
|
||||
return;
|
||||
}
|
||||
std::vector<lite::tensor::Tensor *> inputs = {tensor_data};
|
||||
std::vector<lite::tensor::Tensor *> outputs = {output_tensor};
|
||||
|
||||
MS_LOG(INFO) << "setting SliceParameter";
|
||||
auto param = new (std::nothrow) SliceParameter();
|
||||
if (param == nullptr) {
|
||||
for (auto tensor : inputs) {
|
||||
delete tensor;
|
||||
}
|
||||
for (auto tensor : outputs) {
|
||||
delete tensor;
|
||||
}
|
||||
MS_LOG(INFO) << "new SliceParameter failed";
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < 4; i++) {
|
||||
param->begin_[i] = begin[i];
|
||||
param->size_[i] = size[i];
|
||||
}
|
||||
|
||||
auto *slice_kernel =
|
||||
new (std::nothrow) kernel::SliceOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
|
||||
if (slice_kernel == nullptr) {
|
||||
for (auto tensor : inputs) {
|
||||
delete tensor;
|
||||
}
|
||||
for (auto tensor : outputs) {
|
||||
delete tensor;
|
||||
}
|
||||
delete param;
|
||||
MS_LOG(INFO) << "new kernel::slice_kernel failed";
|
||||
return;
|
||||
}
|
||||
slice_kernel->Init();
|
||||
|
||||
// to do allocate memory for inputs and outputs
|
||||
for (auto &input_tensor : inputs) {
|
||||
input_tensor->MallocData(allocator);
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "initialize sub_graph";
|
||||
std::vector<kernel::LiteKernel *> kernels{slice_kernel};
|
||||
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
|
||||
if (sub_graph == nullptr) {
|
||||
for (auto tensor : inputs) {
|
||||
delete tensor;
|
||||
}
|
||||
for (auto tensor : outputs) {
|
||||
delete tensor;
|
||||
}
|
||||
delete param;
|
||||
delete slice_kernel;
|
||||
MS_LOG(INFO) << "new kernel::SubGraphOpenCLKernel failed";
|
||||
return;
|
||||
}
|
||||
sub_graph->Init();
|
||||
|
||||
MS_LOG(INFO) << "init tensors";
|
||||
memcpy(inputs[0]->Data(), input_data, input_size);
|
||||
|
||||
std::cout << "==================output data================" << std::endl;
|
||||
sub_graph->Run();
|
||||
|
||||
auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->Data());
|
||||
CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001);
|
||||
for (auto tensor : inputs) {
|
||||
delete tensor;
|
||||
}
|
||||
for (auto tensor : outputs) {
|
||||
delete tensor;
|
||||
}
|
||||
delete slice_kernel;
|
||||
delete sub_graph;
|
||||
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||
}
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue