diff --git a/build.sh b/build.sh
index 2e67f1a247..ae67ea0c00 100755
--- a/build.sh
+++ b/build.sh
@@ -510,8 +510,12 @@ gene_ocl_program() {
 
 build_opencl() {
     cd ${BASEPATH}
-    git submodule update --init third_party/OpenCL-Headers
-    git submodule update --init third_party/OpenCL-CLHPP
+    if [[ ! -d "third_party/OpenCL-Headers" ]]; then
+        git submodule update --init third_party/OpenCL-Headers
+    fi
+    if [[ ! -d "third_party/OpenCL-CLHPP" ]]; then
+        git submodule update --init third_party/OpenCL-CLHPP
+    fi
     if [[ "${OPENCL_OFFLINE_COMPILE}" == "on" ]]; then
         gene_ocl_program
     else
@@ -524,6 +528,7 @@ build_lite()
     echo "start build mindspore lite project"
 
     if [[ "${ENABLE_GPU}" == "on" ]]; then
+        echo "start build opencl"
         build_opencl
     fi
     if [[ "${LITE_PLATFORM}" == "x86_64" ]]; then
diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt
index 3e1a3515da..887da2c9de 100644
--- a/mindspore/lite/CMakeLists.txt
+++ b/mindspore/lite/CMakeLists.txt
@@ -49,7 +49,6 @@ endif ()
 if (SUPPORT_GPU)
     add_definitions(-DUSE_OPENCL_WRAPPER)
     add_definitions(-DMS_OPENCL_PROFILE=false)
-    add_definitions(-DCL_HPP_TARGET_OPENCL_VERSION=200)
     add_compile_definitions(SUPPORT_GPU)
     if(OFFLINE_COMPILE)
         add_compile_definitions(PROGRAM_WITH_IL)
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc
index 12fdf20291..d731102818 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc
@@ -19,7 +19,7 @@
 #include "src/kernel_registry.h"
 #include "src/runtime/opencl/opencl_runtime.h"
 #include "src/runtime/kernel/opencl/kernel/concat.h"
-#include "src/backend/opencl/cl/fp32/concat.cl.inc"
+#include "src/runtime/kernel/opencl/cl/fp32/concat.cl.inc"
 
 using mindspore::kernel::KERNEL_ARCH::kGPU;
 using mindspore::lite::KernelRegistrar;
@@ -115,12 +115,16 @@ int GetBiggestDividerWithPriority(int number, int max_divider) {
   return 1;
 }
 
-void ConcatGetWorkGroup(const std::vector<size_t> &global, const std::vector<size_t> &local, int max_size) {
+void ConcatGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) {
   int x = std::min(GetBiggestDividerWithPriority(global[0], 8), 4);
   int yz = max_size / x;
   int y = std::min(std::min(GetBiggestDividerWithPriority(global[1], 8), yz), 8);
   int z = std::min(yz / y, DivideRoundUp(global[2], 2));
-  local = {static_cast<size_t>(x), static_cast<size_t>(y), static_cast<size_t>(z)};
+
+  local->clear();
+  local->push_back(x);
+  local->push_back(y);
+  local->push_back(z);
 }
 int ConcatOpenCLKernel::Run() {
   auto param = reinterpret_cast<ConcatParameter *>(this->opParameter);
@@ -144,7 +148,7 @@ int ConcatOpenCLKernel::Run() {
   uint32_t OW = output_shape[2];
   uint32_t OC = output_shape[3];
   global = {OH, OW, OC};  // HWC
-  ConcatGetWorkGroup(global, local, 384);
+  ConcatGetWorkGroup(global, &local, 384);
   std::cout << "local size=:" << std::endl;
   for (int i = 0; i < local.size(); i++) {
     std::cout << local[i] << " ";
@@ -174,7 +178,7 @@ int ConcatOpenCLKernel::Run() {
   uint32_t OW = output_shape[2];
   uint32_t OC = output_shape[3];
   global = {OH, OW, OC};  // HWC
-  ConcatGetWorkGroup(global, local, 384);
+  ConcatGetWorkGroup(global, &local, 384);
   std::cout << "local size=:" << std::endl;
   for (int i = 0; i < local.size(); i++) {
     std::cout << local[i] << " ";
@@ -196,8 +200,9 @@ int ConcatOpenCLKernel::Run() {
   return 0;
 }
 
-kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<tensor::Tensor *> &inputs,
-                                              const std::vector<tensor::Tensor *> &outputs, OpParameter *opParameter,
+kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
+                                              const std::vector<lite::tensor::Tensor *> &outputs,
+                                              OpParameter *opParameter,
                                               const lite::Context *ctx, const kernel::KernelKey &desc) {
   auto *kernel = new ConcatOpenCLKernel(opParameter, inputs, outputs);
   auto ret = kernel->Init();
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h
index 0458b68c36..de0aac1a09 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h
@@ -20,24 +20,21 @@
 #include <vector>
 #include "ir/anf.h"
 #include "src/lite_kernel.h"
-#include "src/backend/arm/opclib/conv_parameter.h"
 #include "src/runtime/opencl/opencl_runtime.h"
-#include "src/backend/arm/opclib/concat.h"
+#include "src/runtime/kernel/arm/base/concat_base.h"
 
 namespace mindspore::kernel {
 
 class ConcatOpenCLKernel : public LiteKernel {
  public:
-  explicit ConcatOpenCLKernel(OpParameter *parameter, const std::vector<tensor::Tensor *> &inputs,
-                              const std::vector<tensor::Tensor *> &outputs)
+  explicit ConcatOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
+                              const std::vector<lite::tensor::Tensor *> &outputs)
       : LiteKernel(parameter, inputs, outputs) {}
 
   ~ConcatOpenCLKernel() override{};
 
   int Init() override;
 
-  // int InferShape() { return {}; };
-  int InferShape() {}
   int ReSize() override;
 
   int Run_axis0();
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h
index 992c96d857..dac32964b5 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h
@@ -39,7 +39,6 @@ class Conv2dTransposeOpenCLKernel : public LiteKernel {
   ~Conv2dTransposeOpenCLKernel() override {};
 
   int Init() override;
-  int InferShape() {}
   int ReSize() override;
   int Run() override;
   void PadWeight();
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h
index a1c4485dd9..90be697103 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h
@@ -41,7 +41,6 @@ class MatMulOpenCLKernel : public LiteKernel {
   ~MatMulOpenCLKernel() override{};
 
   int Init() override;
-  int InferShape() {}
   int ReSize() override;
   int Run() override;
   void PadWeight();
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index 892c96f900..ad0ef96195 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -265,10 +265,10 @@ endif()
 if (SUPPORT_GPU)
     set(TEST_SRC
         ${TEST_SRC}
-        ${TEST_DIR}/ut/stc/runtime/kernel/opencl/matmul_tests.cc
-        ${TEST_DIR}/ut/stc/runtime/kernel/opencl/depthwise_conv2d_tests.cc
-        ${TEST_DIR}/ut/stc/runtime/kernel/opencl/concat_tests.cc
-        ${TEST_DIR}/ut/stc/runtime/kernel/opencl/softmax_cl_tests.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/opencl/matmul_tests.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/opencl/concat_tests.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/opencl/softmax_cl_tests.cc
     )
 endif()
 
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc
similarity index 87%
rename from mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_test.cc
rename to mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc
index c7ea5cd47d..7c039d47ca 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_test.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc
@@ -18,12 +18,10 @@
 #include "utils/log_adapter.h"
 #include "common/common_test.h"
 #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
-#include "mindspore/lite/src/backend/opencl/subgraph_opencl_kernel.h"
-#include "mindspore/lite/src/backend/opencl/kernel/concat.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h"
+
 
-using mindspore::kernel;
-using mindspore::lite;
-using mindspore;
 int DivideRoundUp(int n, int div) {
   int q = n / div;
   return n % div == 0 ? q : q + 1;
@@ -96,7 +94,7 @@ void ConcatComputeByCPU_3input_dim4_axis3(float *input0, float *input1, float *i
 }
 
 namespace mindspore {
-class TestConcatOpenCL : public UT::Common {
+class TestConcatOpenCL : public mindspore::Common {
  public:
   TestConcatOpenCL(){}
 };
@@ -113,30 +111,31 @@ TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) {
   output_shape[3] = DivideRoundUp(output_shape[3], 4) * 4;
   auto data_type = kNumberTypeFloat32;
   auto tensor_type = schema::NodeType_ValueNode;
-  std::vector<tensor::Tensor *> inputs;
+  std::vector<lite::tensor::Tensor *> inputs;
   for (auto &shape : input_shapes) {
-    inputs.push_back(new tensor::Tensor(data_type, shape, schema::Format_NHWC, tensor_type));
+    inputs.push_back(new lite::tensor::Tensor(data_type, shape, schema::Format_NHWC, tensor_type));
   }
-  auto *output_tensor = new tensor::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type);
-  std::vector<tensor::Tensor *> outputs{output_tensor};
+  auto *output_tensor = new lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type);
+  std::vector<lite::tensor::Tensor *> outputs{output_tensor};
   std::cout << "input_shapes size=: " << input_shapes.size() << std::endl;
   MS_LOG(INFO) << "initialize tensors";
   auto param = new ConcatParameter();
   param->axis_ = 3;
-  auto *concat_kernel = new ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
+  auto *concat_kernel = new kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
   concat_kernel->Init();
 
   MS_LOG(INFO) << "initialize sub_graph";
-  std::vector<LiteKernel *> kernels{concat_kernel};
-  auto *sub_graph = new SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
+  std::vector<kernel::LiteKernel *> kernels{concat_kernel};
+  auto *sub_graph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
   sub_graph->Init();
 
   MS_LOG(INFO) << "initialize input data";
   srand(time(NULL));
   for (auto &input_tensor : inputs) {
     auto input_data = reinterpret_cast<float *>(input_tensor->Data());
+    static unsigned int seed = 123;
     for (int i = 0; i < input_tensor->ElementsNum(); ++i) {
-      input_data[i] = static_cast<float>(rand_r() % 10 + 1);
+      input_data[i] = static_cast<float>(rand_r(&seed) % 10 + 1);
     }
     printf("\n");
   }
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
index cc54fd3c07..e7f8aa5209 100755
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
@@ -23,9 +23,6 @@
 #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
 #include "mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h"
 
-using mindspore::kernel;
-using mindspore::lite;
-using mindspore;
 
 #define SAFE_DELETE_ARRAY(a) \
   if (a != nullptr) {        \
@@ -39,12 +36,12 @@ using mindspore;
   }
 
 namespace mindspore {
-class TestConvolutionDwOpenCL : public UT::Common {
+class TestConvolutionDwOpenCL : public mindspore::Common {
  public:
   TestConvolutionDwOpenCL(){}
 };
 
-void DepthWiseTestMain(const ConvParameter *conv_param, float_t *input_data, float_t *weight_data, float_t *gnd_data,
+void DepthWiseTestMain(ConvParameter *conv_param, float_t *input_data, float_t *weight_data, float_t *gnd_data,
                        schema::Format format, bool is_compare = true) {
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
   ocl_runtime->Init();
@@ -92,13 +89,13 @@ void DepthWiseTestMain(const ConvParameter *conv_param, float_t *input_data, flo
   inputs[1]->SetData(packed_weight);
   inputs[2]->SetData(bias_data);
 
-  OpParameter * parameter = conv_param;
-  auto *pKernel = new DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs);
+  OpParameter * parameter = reinterpret_cast<OpParameter *>(conv_param);
+  auto *pKernel = new kernel::DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs);
   pKernel->Init();
 
-  std::vector<LiteKernel *> kernels{pKernel};
+  std::vector<kernel::LiteKernel *> kernels{pKernel};
   std::vector<lite::tensor::Tensor *> inputs_{tensor_a};
-  auto *pGraph = new SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels);
+  auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels);
   pGraph->Init();
 
   // freamework to do!!!
@@ -141,7 +138,7 @@ void DepthWiseTestMain(const ConvParameter *conv_param, float_t *input_data, flo
   }
   std::cout << std::endl;
   // compare
-  UT::Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001);
+  Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001);
 
   SAFE_DELETE_ARRAY(packed_correct_data)
 }
@@ -202,7 +199,7 @@ TEST_F(TestConvolutionDwOpenCL, NoPadNC4HW4Fp32) {
                         2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988};
 
   DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4);
-  opencl::OpenCLRuntime::DeleteInstance();
+  lite::opencl::OpenCLRuntime::DeleteInstance();
 }
 
 TEST_F(TestConvolutionDwOpenCL, PadNC4HW4Fp32) {
@@ -275,7 +272,7 @@ TEST_F(TestConvolutionDwOpenCL, PadNC4HW4Fp32) {
                         1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203};
 
   DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4);
-  opencl::OpenCLRuntime::DeleteInstance();
+  lite::opencl::OpenCLRuntime::DeleteInstance();
 }
 
 TEST_F(TestConvolutionDwOpenCL, NoPadNHWC4Fp32) {
@@ -321,7 +318,7 @@ TEST_F(TestConvolutionDwOpenCL, NoPadNHWC4Fp32) {
                         2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988};
 
   DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4);
-  opencl::OpenCLRuntime::DeleteInstance();
+  lite::opencl::OpenCLRuntime::DeleteInstance();
 }
 
 TEST_F(TestConvolutionDwOpenCL, PadNHWC4Fp32) {
@@ -394,7 +391,7 @@ TEST_F(TestConvolutionDwOpenCL, PadNHWC4Fp32) {
                         1.0517888, 0.59817517, 0.75649744, 1.2075498, 0.38804203};
 
   DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4);
-  opencl::OpenCLRuntime::DeleteInstance();
+  lite::opencl::OpenCLRuntime::DeleteInstance();
 }
 
 
@@ -474,13 +471,13 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) {
   inputs[1]->SetData(packed_weight);
   inputs[2]->SetData(bias_data);
 
-  OpParameter * parameter = conv_param;
-  auto *pKernel = new DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs);
+  OpParameter * parameter = reinterpret_cast<OpParameter *>(conv_param);
+  auto *pKernel = new kernel::DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs);
   pKernel->Init();
 
-  std::vector<LiteKernel *> kernels{pKernel};
+  std::vector<kernel::LiteKernel *> kernels{pKernel};
   std::vector<lite::tensor::Tensor *> inputs_{tensor_a};
-  auto *pGraph = new SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels);
+  auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels);
   pGraph->Init();
 
   // freamework to do!!!
@@ -517,7 +514,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) {
   }
   std::cout << std::endl;
   // compare
-  CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001);
+  Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001);
 
   for (auto tensor : inputs) {
     tensor->SetData(nullptr);
@@ -530,7 +527,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) {
   SAFE_DELETE_PTR(pKernel)
   SAFE_DELETE_PTR(pGraph)
   MS_LOG(INFO) << "TestConvolutionDwNoPadFp32 passed";
-  opencl::OpenCLRuntime::DeleteInstance();
+  lite::opencl::OpenCLRuntime::DeleteInstance();
 }
 
 TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) {
@@ -637,13 +634,13 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) {
   inputs[1]->SetData(packed_weight);
   inputs[2]->SetData(bias_data);
 
-  OpParameter * parameter = conv_param;
-  auto *pKernel = new DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs);
+  OpParameter * parameter = reinterpret_cast<OpParameter *>(conv_param);
+  auto *pKernel = new kernel::DepthwiseConv2dOpenCLKernel(parameter, inputs, outputs);
   pKernel->Init();
 
-  std::vector<LiteKernel *> kernels{pKernel};
+  std::vector<kernel::LiteKernel *> kernels{pKernel};
   std::vector<lite::tensor::Tensor *> inputs_{tensor_a};
-  auto *pGraph = new SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels);
+  auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels);
   pGraph->Init();
 
   // freamework to do!!!
@@ -688,7 +685,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) {
   }
   std::cout << std::endl;
   // compare
-  CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001);
+  Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001);
 
   SAFE_DELETE_ARRAY(packed_input);
   SAFE_DELETE_ARRAY(packed_correct_data)
@@ -703,7 +700,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) {
   SAFE_DELETE_PTR(pKernel)
   SAFE_DELETE_PTR(pGraph)
   MS_LOG(INFO) << "TestConvolutionDwPadFp32 passed";
-  opencl::OpenCLRuntime::DeleteInstance();
+  lite::opencl::OpenCLRuntime::DeleteInstance();
 }
 
 TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) {
@@ -803,7 +800,7 @@ TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) {
   }
   SAFE_DELETE_ARRAY(input_data);
   SAFE_DELETE_ARRAY(weight_data);
-  opencl::OpenCLRuntime::DeleteInstance();
+  lite::opencl::OpenCLRuntime::DeleteInstance();
 }
 }  // namespace mindspore
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc
index d38a192390..e4f7e91769 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc
@@ -22,10 +22,6 @@
 #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
 #include "mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h"
 
-// using namespace mindspore::kernel;
-// using namespace mindspore::lite;
-// using namespace mindspore;
-
 namespace mindspore {
 class TestMatMulOpenCL : public mindspore::Common {
  public:
@@ -53,11 +49,11 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
   lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, co});
   std::vector<lite::tensor::Tensor *> inputs{tensor_x, tensor_w};
   std::vector<lite::tensor::Tensor *> outputs{tensor_out};
-  auto *arith_kernel = new MatMulOpenCLKernel(nullptr, inputs, outputs, false);
+  auto *arith_kernel = new kernel::MatMulOpenCLKernel(nullptr, inputs, outputs, false);
   arith_kernel->Init();
 
   std::vector<kernel::LiteKernel *> kernels{arith_kernel};
-  auto *pGraph = new SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
+  auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
   pGraph->Init();
 
   memcpy(inputs[0]->Data(), input_data, sizeof(float) * ci);
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_cl_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_cl_tests.cc
index eb891d567d..20bcfcefdb 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_cl_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_cl_tests.cc
@@ -22,10 +22,6 @@
 #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
 #include "mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h"
 
-// using namespace mindspore::kernel;
-// using namespace mindspore::lite;
-// using namespace mindspore;
-
 namespace mindspore {
 class TestSoftmaxOpenCL : public mindspore::Common {};
 
@@ -53,12 +49,12 @@ TEST_F(TestSoftmaxOpenCL, SoftmaxFp32) {
   std::vector<lite::tensor::Tensor *> outputs{tensor_out};
 
   MS_LOG(INFO) << "create OpenCL Kernel";
-  auto *Softmax_kernel = new SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
+  auto *Softmax_kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
   Softmax_kernel->Init();
   std::vector<kernel::LiteKernel *> kernels{Softmax_kernel};
 
   MS_LOG(INFO) << "create SubGraphOpenCLKernel";
-  auto *pGraph = new SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
+  auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
   pGraph->Init();
 
   MS_LOG(INFO) << "initialize data";