From c515fcd36114906dcfaff3aa75d9d3087cbb2b4b Mon Sep 17 00:00:00 2001 From: chaijun Date: Mon, 19 Jul 2021 17:23:23 +0800 Subject: [PATCH] gpu demo --- .jenkins/check/config/filter_cpplint.txt | 3 + cmake/external_libs/opencl.cmake | 7 +- include/api/allocator.h | 9 + include/api/types.h | 5 + mindspore/lite/CMakeLists.txt | 3 +- .../runtime_extend/src/custom_common.h | 4 +- .../runtime_gpu_extend/CMakeLists.txt | 45 ++ .../lite/examples/runtime_gpu_extend/build.sh | 47 ++ .../lite/examples/runtime_gpu_extend/main.cc | 200 +++++++ .../runtime_gpu_extend/src/cl/arithmetic.cl | 17 + .../src/custom_add_infer.cc | 50 ++ .../src/custom_add_kernel_gpu.cc | 267 +++++++++ .../runtime_gpu_extend/src/custom_common.cc | 76 +++ .../runtime_gpu_extend/src/custom_common.h | 130 +++++ mindspore/lite/include/ms_tensor.h | 5 + .../include/registry/opencl_runtime_wrapper.h | 119 ++++ .../component/const_blocks/mtensor.cc | 1 + .../lite/src/cxx_api/tensor/tensor_impl.h | 7 + mindspore/lite/src/cxx_api/types.cc | 8 + mindspore/lite/src/inner_context.cc | 2 +- mindspore/lite/src/kernel_registry.cc | 3 + mindspore/lite/src/lite_kernel.h | 2 +- mindspore/lite/src/lite_session.cc | 4 +- mindspore/lite/src/lite_session.h | 2 +- .../lite/src/registry/register_kernel_impl.h | 1 + .../runtime/gpu/opencl/opencl_allocator.cc | 69 ++- .../src/runtime/gpu/opencl/opencl_allocator.h | 12 +- .../src/runtime/gpu/opencl/opencl_executor.cc | 86 +-- .../src/runtime/gpu/opencl/opencl_executor.h | 8 +- .../src/runtime/gpu/opencl/opencl_runtime.cc | 35 +- .../src/runtime/gpu/opencl/opencl_runtime.h | 76 +-- .../gpu/opencl/opencl_runtime_wrapper.cc | 155 +++++ .../runtime/kernel/opencl/kernel/argminmax.cc | 8 +- .../runtime/kernel/opencl/kernel/batchnorm.cc | 8 +- .../runtime/kernel/opencl/kernel/concat.cc | 7 +- .../runtime/kernel/opencl/kernel/conv2d.cc | 5 +- .../kernel/opencl/kernel/conv2d_transpose.cc | 2 +- .../kernel/opencl/kernel/depthwise_conv2d.cc | 5 +- .../src/runtime/kernel/opencl/kernel/fill.cc | 4 +- .../kernel/opencl/kernel/fullconnection.cc | 2 +- .../kernel/opencl/kernel/fusion_eltwise.cc | 3 +- .../runtime/kernel/opencl/kernel/gather.cc | 2 +- .../kernel/opencl/kernel/layer_norm.cc | 12 +- .../runtime/kernel/opencl/kernel/matmul.cc | 2 +- .../src/runtime/kernel/opencl/kernel/prelu.cc | 2 +- .../kernel/opencl/kernel/sparse_to_dense.cc | 7 +- .../src/runtime/kernel/opencl/kernel/split.cc | 9 +- .../src/runtime/kernel/opencl/kernel/stack.cc | 10 +- .../runtime/kernel/opencl/kernel/strassen.cc | 14 +- .../runtime/kernel/opencl/kernel/to_format.cc | 6 +- .../runtime/kernel/opencl/kernel/winograd.cc | 5 +- .../runtime/kernel/opencl/opencl_fusion.cc | 9 +- .../runtime/kernel/opencl/opencl_kernel.cc | 12 +- .../src/runtime/kernel/opencl/opencl_kernel.h | 6 +- .../runtime/kernel/opencl/opencl_subgraph.cc | 28 +- .../runtime/kernel/opencl/opencl_subgraph.h | 2 +- mindspore/lite/src/scheduler.cc | 3 + mindspore/lite/src/tensor.cc | 30 +- mindspore/lite/src/tensor.h | 2 +- mindspore/lite/test/CMakeLists.txt | 1 + mindspore/lite/test/config/ut_arm64.cfg | 1 + .../registry/registry_gpu_custom_op_test.cc | 530 ++++++++++++++++++ .../src/runtime/kernel/opencl/cast_tests.cc | 4 +- .../ut/src/runtime/kernel/opencl/common.cc | 4 +- .../src/runtime/kernel/opencl/fill_tests.cc | 4 +- 65 files changed, 1969 insertions(+), 238 deletions(-) create mode 100644 mindspore/lite/examples/runtime_gpu_extend/CMakeLists.txt create mode 100644 mindspore/lite/examples/runtime_gpu_extend/build.sh create mode 100644 
mindspore/lite/examples/runtime_gpu_extend/main.cc create mode 100644 mindspore/lite/examples/runtime_gpu_extend/src/cl/arithmetic.cl create mode 100644 mindspore/lite/examples/runtime_gpu_extend/src/custom_add_infer.cc create mode 100644 mindspore/lite/examples/runtime_gpu_extend/src/custom_add_kernel_gpu.cc create mode 100644 mindspore/lite/examples/runtime_gpu_extend/src/custom_common.cc create mode 100644 mindspore/lite/examples/runtime_gpu_extend/src/custom_common.h create mode 100644 mindspore/lite/include/registry/opencl_runtime_wrapper.h create mode 100644 mindspore/lite/src/runtime/gpu/opencl/opencl_runtime_wrapper.cc create mode 100644 mindspore/lite/test/ut/src/registry/registry_gpu_custom_op_test.cc diff --git a/.jenkins/check/config/filter_cpplint.txt b/.jenkins/check/config/filter_cpplint.txt index c0f1a9e03df..9299ae5fc71 100644 --- a/.jenkins/check/config/filter_cpplint.txt +++ b/.jenkins/check/config/filter_cpplint.txt @@ -58,3 +58,6 @@ "mindspore/mindspore/lite/src/runtime/thread_pool.c" "runtime/arrays" "mindspore/mindspore/lite/src/runtime/thread_pool.c" "runtime/int" "mindspore/mindspore/lite/src/ops/ops_def.cc" "runtime/int" +"mindspore/mindspore/lite/examples/runtime_gpu_extend/src/cl" "legal/copyright" +"mindspore/mindspore/lite/examples/runtime_gpu_extend/src/cl" "readability/casting" +"mindspore/mindspore/lite/examples/runtime_gpu_extend/src/cl" "readability/fn_size" \ No newline at end of file diff --git a/cmake/external_libs/opencl.cmake b/cmake/external_libs/opencl.cmake index 01fa3c5fe1e..5301757c901 100644 --- a/cmake/external_libs/opencl.cmake +++ b/cmake/external_libs/opencl.cmake @@ -16,13 +16,12 @@ else() __download_pkg(OpenCL-CLHPP ${REQ_URL} ${MD5}) endif() -function(gene_opencl BASEPATH) - string(CONCAT CL_SRC_DIR "${BASEPATH}" "/src/runtime/kernel/opencl/cl") - message(STATUS "**********gene opencl*********base path: " "${BASEPATH}" ", cl path: " "${CL_SRC_DIR}") +function(gene_opencl CL_SRC_DIR) + message(STATUS "**********gene opencl********* cl path: " "${CL_SRC_DIR}") if(NOT EXISTS ${CL_SRC_DIR}) return() endif() - file(GLOB_RECURSE CL_LIST ${CL_SRC_DIR}/*.cl ${CL_SRC_DIR}/int8/*.cl) + file(GLOB_RECURSE CL_LIST ${CL_SRC_DIR}/*.cl) foreach(file_path ${CL_LIST}) file(REMOVE ${file_path}.inc) string(REGEX REPLACE ".+/(.+)\\..*" "\\1" kernel_name "${file_path}") diff --git a/include/api/allocator.h b/include/api/allocator.h index e78cf770b33..9c8a16a637a 100644 --- a/include/api/allocator.h +++ b/include/api/allocator.h @@ -32,6 +32,15 @@ class MS_API Allocator { /// \param[in] size Define the memory size to request. virtual void *Malloc(size_t size) = 0; + /// \brief Method to request memory. + /// + /// \param[in] weight Defines the width of memory to request + /// \param[in] height Defines the height of memory to request + /// \param[in] type Defines the data type of memory to request + virtual void *Malloc(size_t weight, size_t height, DataType type) { + return nullptr; + } + /// \brief Method to free memory. /// /// \param[in] ptr Define the pointer of a certain memory. diff --git a/include/api/types.h b/include/api/types.h index 702d79d142c..c652b9e5c4f 100644 --- a/include/api/types.h +++ b/include/api/types.h @@ -169,6 +169,11 @@ class MS_API MSTensor { /// \return The length of the data of the MSTensor, in bytes. 
size_t DataSize() const; + /// \brief Get whether the MSTensor data is const data + /// + /// \return Const flag of MSTensor + bool IsConst() const; + /// \brief Gets the boolean value that indicates whether the memory of MSTensor is on device. /// /// \return The boolean value that indicates whether the memory of MSTensor is on device. diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index 7be2858ff7f..4f56795a9cf 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -358,7 +358,8 @@ if(MSLITE_ENABLE_FP16) endif() if(MSLITE_GPU_BACKEND STREQUAL opencl) add_definitions(-DGPU_OPENCL) - gene_opencl(${CMAKE_CURRENT_SOURCE_DIR}) + gene_opencl(${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/kernel/opencl/cl) + gene_opencl(${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/kernel/opencl/cl/int8) add_definitions(-DUSE_OPENCL_WRAPPER) add_definitions(-DMS_OPENCL_PROFILE=false) add_definitions(-DCL_TARGET_OPENCL_VERSION=200) diff --git a/mindspore/lite/examples/runtime_extend/src/custom_common.h b/mindspore/lite/examples/runtime_extend/src/custom_common.h index 57ef36e2c65..c784d796f4f 100644 --- a/mindspore/lite/examples/runtime_extend/src/custom_common.h +++ b/mindspore/lite/examples/runtime_extend/src/custom_common.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_EXAMPLES_RUNTIME_REGISTRY_SRC_CUSTOM_COMMON_H -#define MINDSPORE_LITE_EXAMPLES_RUNTIME_REGISTRY_SRC_CUSTOM_COMMON_H +#ifndef MINDSPORE_LITE_EXAMPLES_RUNTIME_EXTEND_SRC_CUSTOM_COMMON_H +#define MINDSPORE_LITE_EXAMPLES_RUNTIME_EXTEND_SRC_CUSTOM_COMMON_H #include #include "include/api/types.h" diff --git a/mindspore/lite/examples/runtime_gpu_extend/CMakeLists.txt b/mindspore/lite/examples/runtime_gpu_extend/CMakeLists.txt new file mode 100644 index 00000000000..5df86185f6e --- /dev/null +++ b/mindspore/lite/examples/runtime_gpu_extend/CMakeLists.txt @@ -0,0 +1,45 @@ +cmake_minimum_required(VERSION 3.14) +project(RuntimeGPUExtendTutorial) + +message(STATUS "Using toolchain file: ${CMAKE_TOOLCHAIN_FILE}.") + +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0) + message(FATAL_ERROR "GCC version ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0") +endif() + +add_definitions(-DCL_TARGET_OPENCL_VERSION=200) +add_definitions(-DCL_HPP_TARGET_OPENCL_VERSION=120) +add_definitions(-DCL_HPP_MINIMUM_OPENCL_VERSION=120) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") + +include(${CMAKE_CURRENT_SOURCE_DIR}/../../../../cmake/utils.cmake) +include(${CMAKE_CURRENT_SOURCE_DIR}/../../../../cmake/external_libs/opencl.cmake) +gene_opencl(${CMAKE_CURRENT_SOURCE_DIR}/src/cl) + +# Add directory to include search path +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/runtime/) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/runtime/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/runtime/include/third_party) +include_directories(${CMAKE_BINARY_DIR}/_deps/opencl-headers-src/) +include_directories(${CMAKE_BINARY_DIR}/_deps/opencl-clhpp-src/include) + +# Add directory to linker search path +link_directories(${CMAKE_CURRENT_SOURCE_DIR}/runtime/lib) + +file(GLOB_RECURSE RUNTIME_REGISTRY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) + +add_executable(runtime_extend_tutorial ${RUNTIME_REGISTRY_SRC}) +target_link_libraries( + runtime_extend_tutorial + mindspore-lite + log +) + +add_executable(runtime_extend_tutorial_static ${RUNTIME_REGISTRY_SRC}) +target_link_libraries( + 
runtime_extend_tutorial_static + -Wl,--whole-archive libmindspore-lite.a -Wl,--no-whole-archive + log +) diff --git a/mindspore/lite/examples/runtime_gpu_extend/build.sh b/mindspore/lite/examples/runtime_gpu_extend/build.sh new file mode 100644 index 00000000000..b8eca256444 --- /dev/null +++ b/mindspore/lite/examples/runtime_gpu_extend/build.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +BASEPATH=$(cd "$(dirname $0)" || exit; pwd) +get_version() { + VERSION_MAJOR=$(grep "const int ms_version_major =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]") + VERSION_MINOR=$(grep "const int ms_version_minor =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]") + VERSION_REVISION=$(grep "const int ms_version_revision =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]") + VERSION_STR=${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION} +} +get_version +MODEL_DOWNLOAD_URL="https://download.mindspore.cn/model_zoo/official/lite/quick_start/add_extend.ms" +MODEL_DOWNLOAD_URL2="https://download.mindspore.cn/model_zoo/official/lite/quick_start/add.ms" +MINDSPORE_FILE_NAME="mindspore-lite-${VERSION_STR}-linux-x64" +MINDSPORE_FILE="${MINDSPORE_FILE_NAME}.tar.gz" +MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}/MindSpore/lite/release/linux/${MINDSPORE_FILE}" + +mkdir -p build +mkdir -p model +if [ ! -e ${BASEPATH}/model/add_extend.ms ]; then + wget -c -O ${BASEPATH}/model/add_extend.ms --no-check-certificate ${MODEL_DOWNLOAD_URL} +fi +if [ ! -e ${BASEPATH}/model/add.ms ]; then + wget -c -O ${BASEPATH}/model/add.ms --no-check-certificate ${MODEL_DOWNLOAD_URL2} +fi +if [ ! -e ${BASEPATH}/build/${MINDSPORE_FILE} ]; then + wget -c -O ${BASEPATH}/build/${MINDSPORE_FILE} --no-check-certificate ${MINDSPORE_LITE_DOWNLOAD_URL} +fi +tar -xzf ${BASEPATH}/build/${MINDSPORE_FILE} +cp -r ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/runtime ${BASEPATH}/ +cd ${BASEPATH}/build || exit +cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \ + -DANDROID_ABI="arm64-v8a" -DCMAKE_BUILD_TYPE="Release" ${BASEPATH} +make diff --git a/mindspore/lite/examples/runtime_gpu_extend/main.cc b/mindspore/lite/examples/runtime_gpu_extend/main.cc new file mode 100644 index 00000000000..bb5df63a1d2 --- /dev/null +++ b/mindspore/lite/examples/runtime_gpu_extend/main.cc @@ -0,0 +1,200 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <random>
+#include <string>
+#include <vector>
+#include "include/errorcode.h"
+#include "include/context.h"
+#include "include/api/types.h"
+#include "include/api/model.h"
+
+namespace mindspore {
+namespace lite {
+namespace {
+constexpr int kNumPrintOfOutData = 20;
+std::string RealPath(const char *path) {
+  const size_t max = 4096;
+  if (path == nullptr) {
+    std::cerr << "path is nullptr" << std::endl;
+    return "";
+  }
+  if ((strlen(path)) >= max) {
+    std::cerr << "path is too long" << std::endl;
+    return "";
+  }
+  auto resolved_path = std::make_unique<char[]>(max);
+  if (resolved_path == nullptr) {
+    std::cerr << "new resolved_path failed" << std::endl;
+    return "";
+  }
+
+  char *real_path = realpath(path, resolved_path.get());
+  if (real_path == nullptr || strlen(real_path) == 0) {
+    std::cerr << "file path is not valid : " << path << std::endl;
+    return "";
+  }
+  std::string res = resolved_path.get();
+  return res;
+}
+
+char *ReadFile(const char *file, size_t *size) {
+  if (file == nullptr) {
+    std::cerr << "file is nullptr." << std::endl;
+    return nullptr;
+  }
+
+  std::ifstream ifs(file);
+  if (!ifs.good()) {
+    std::cerr << "file: " << file << " does not exist." << std::endl;
+    return nullptr;
+  }
+
+  if (!ifs.is_open()) {
+    std::cerr << "file: " << file << " open failed." << std::endl;
+    return nullptr;
+  }
+
+  ifs.seekg(0, std::ios::end);
+  *size = ifs.tellg();
+  std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
+  if (buf == nullptr) {
+    std::cerr << "malloc buf failed, file: " << file << std::endl;
+    ifs.close();
+    return nullptr;
+  }
+
+  ifs.seekg(0, std::ios::beg);
+  ifs.read(buf.get(), *size);
+  ifs.close();
+
+  return buf.release();
+}
+}  // namespace
+
+template <typename T, typename Distribution>
+void GenerateRandomData(int size, void *data, Distribution distribution) {
+  std::mt19937 random_engine;
+  int elements_num = size / sizeof(T);
+  (void)std::generate_n(static_cast<T *>(data), elements_num,
+                        [&distribution, &random_engine]() { return static_cast<T>(distribution(random_engine)); });
+}
+
+void InitMSContext(const std::shared_ptr<mindspore::Context> &context) {
+  context->SetThreadNum(1);
+  context->SetEnableParallel(false);
+  context->SetThreadAffinity(HIGHER_CPU);
+  auto &device_list = context->MutableDeviceInfo();
+
+  std::shared_ptr<mindspore::CPUDeviceInfo> device_info = std::make_shared<mindspore::CPUDeviceInfo>();
+  device_info->SetEnableFP16(false);
+  device_list.push_back(device_info);
+
+  std::shared_ptr<mindspore::GPUDeviceInfo> provider_gpu_device_info = std::make_shared<mindspore::GPUDeviceInfo>();
+  provider_gpu_device_info->SetEnableFP16(false);
+  provider_gpu_device_info->SetProviderDevice("GPU");
+  provider_gpu_device_info->SetProvider("Tutorial");
+  device_list.push_back(provider_gpu_device_info);
+}
+
+int CompileAndRun(int argc, const char **argv) {
+  if (argc < 2) {
+    std::cerr << "Model file must be provided.\n";
+    return RET_ERROR;
+  }
+  // Read model file.
+  auto model_path = RealPath(argv[1]);
+  if (model_path.empty()) {
+    std::cerr << "model path " << argv[1] << " is invalid.";
+    return RET_ERROR;
+  }
+
+  auto context = std::make_shared<mindspore::Context>();
+  if (context == nullptr) {
+    std::cerr << "New context failed." << std::endl;
+    return RET_ERROR;
+  }
+
+  (void)InitMSContext(context);
+
+  mindspore::Model ms_model;
+  size_t size = 0;
+  char *model_buf = ReadFile(model_path.c_str(), &size);
+  if (model_buf == nullptr) {
+    std::cerr << "Read model file failed." << std::endl;
+    return RET_ERROR;
+  }
+  auto ret = ms_model.Build(model_buf, size, kMindIR, context);
+  delete[](model_buf);
+  if (ret != kSuccess) {
+    std::cerr << "ms_model.Build failed." << std::endl;
+    return RET_ERROR;
+  }
+  std::vector<MSTensor> ms_inputs_for_api = ms_model.GetInputs();
+  for (auto tensor : ms_inputs_for_api) {
+    auto input_data = tensor.MutableData();
+    if (input_data == nullptr) {
+      std::cerr << "MallocData for inTensor failed." << std::endl;
+      return RET_ERROR;
+    }
+    GenerateRandomData<float>(tensor.DataSize(), input_data, std::uniform_real_distribution<float>(1.0f, 1.0f));
+  }
+
+  std::cout << "\n------- print inputs ----------" << std::endl;
+  for (auto tensor : ms_inputs_for_api) {
+    std::cout << "in tensor name is:" << tensor.Name() << "\nin tensor size is:" << tensor.DataSize()
+              << "\nin tensor elements num is:" << tensor.ElementNum() << std::endl;
+    auto in_data = reinterpret_cast<float *>(tensor.MutableData());
+    std::cout << "input data is:";
+    for (int i = 0; i < tensor.ElementNum() && i <= kNumPrintOfOutData; i++) {
+      std::cout << in_data[i] << " ";
+    }
+    std::cout << std::endl;
+  }
+  std::cout << "------- print end ----------\n" << std::endl;
+
+  std::vector<MSTensor> outputs;
+  auto status = ms_model.Predict(ms_inputs_for_api, &outputs);
+  if (status != kSuccess) {
+    std::cerr << "Inference error." << std::endl;
+    return RET_ERROR;
+  }
+
+  // Get output tensor data.
+  auto out_tensors = ms_model.GetOutputs();
+  std::cout << "\n------- print outputs ----------" << std::endl;
+  for (auto tensor : out_tensors) {
+    std::cout << "out tensor name is:" << tensor.Name() << "\nout tensor size is:" << tensor.DataSize()
+              << "\nout tensor elements num is:" << tensor.ElementNum() << std::endl;
+    auto out_data = reinterpret_cast<float *>(tensor.MutableData());
+    std::cout << "output data is:";
+    for (int i = 0; i < tensor.ElementNum() && i <= kNumPrintOfOutData; i++) {
+      std::cout << out_data[i] << " ";
+    }
+    std::cout << std::endl;
+  }
+  std::cout << "------- print end ----------\n" << std::endl;
+  return RET_OK;
+}
+}  // namespace lite
+}  // namespace mindspore
+
+int main(int argc, const char **argv) { return mindspore::lite::CompileAndRun(argc, argv); }
diff --git a/mindspore/lite/examples/runtime_gpu_extend/src/cl/arithmetic.cl b/mindspore/lite/examples/runtime_gpu_extend/src/cl/arithmetic.cl
new file mode 100644
index 00000000000..0b34d4dab09
--- /dev/null
+++ b/mindspore/lite/examples/runtime_gpu_extend/src/cl/arithmetic.cl
@@ -0,0 +1,17 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
+
+__kernel void ElementAdd(__read_only image2d_t input_a, __read_only image2d_t input_b, __write_only image2d_t output,
+                         const int2 output_shape) {
+  int X = get_global_id(0);
+  int Y = get_global_id(1);
+  if (X >= output_shape.x || Y >= output_shape.y) {
+    return;
+  }
+
+  FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));
+  FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));
+  FLT4 result = a + b;
+
+  WRITE_IMAGE(output, (int2)(X, Y), result);
+}
diff --git a/mindspore/lite/examples/runtime_gpu_extend/src/custom_add_infer.cc b/mindspore/lite/examples/runtime_gpu_extend/src/custom_add_infer.cc
new file mode 100644
index 00000000000..43a435f1993
--- /dev/null
+++ b/mindspore/lite/examples/runtime_gpu_extend/src/custom_add_infer.cc
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/custom_common.h"
+#include "include/errorcode.h"
+#include "include/registry/register_kernel_interface.h"
+
+namespace mindspore {
+/**
+ * CustomAddInfer infers the current node's output information: format, data type and shape.
+ * If an input shape contains -1, that dimension is unknown here and will be inferred at runtime.
+ */
+class CustomAddInfer : public kernel::KernelInterface {
+ public:
+  CustomAddInfer() = default;
+  ~CustomAddInfer() = default;
+
+  Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
+               const schema::Primitive *primitive) override {
+    (*outputs)[0].SetFormat((*inputs)[0].format());
+    (*outputs)[0].SetDataType((*inputs)[0].DataType());
+    auto ret = custom_common::CheckInputs(*inputs);
+    if (ret != lite::RET_OK) {
+      if (ret == lite::RET_INFER_INVALID) {
+        (*outputs)[0].SetShape({-1});  // A shape of {-1} means the shape will be inferred at runtime.
+        return kLiteInferInvalid;
+      } else {
+        return kLiteError;
+      }
+    }
+    (*outputs)[0].SetShape((*inputs)[0].Shape());
+    return kSuccess;
+  }
+};
+std::shared_ptr<kernel::KernelInterface> CustomAddInferCreator() { return std::make_shared<CustomAddInfer>(); }
+REGISTER_CUSTOM_KERNEL_INTERFACE(Tutorial, Custom_Add, CustomAddInferCreator)
+}  // namespace mindspore
diff --git a/mindspore/lite/examples/runtime_gpu_extend/src/custom_add_kernel_gpu.cc b/mindspore/lite/examples/runtime_gpu_extend/src/custom_add_kernel_gpu.cc
new file mode 100644
index 00000000000..b650494ba6b
--- /dev/null
+++ b/mindspore/lite/examples/runtime_gpu_extend/src/custom_add_kernel_gpu.cc
@@ -0,0 +1,267 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+#include "src/custom_common.h"
+#include "include/errorcode.h"
+#include "include/registry/register_kernel_interface.h"
+#include "include/registry/register_kernel.h"
+#include "include/registry/opencl_runtime_wrapper.h"
+#include "src/cl/arithmetic.cl.inc"
+#include "include/api/data_type.h"
+#include "include/schema/ops_generated.h"
+
+#define UP_ROUND(x, y) (((x) + (y) - (1)) / (y) * (y))
+
+namespace mindspore {
+namespace custom_gpu_demo {
+
+class CustomAddKernel : public kernel::Kernel {
+ public:
+  CustomAddKernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
+                  const schema::Primitive *primitive, const mindspore::Context *ctx, const std::string &build_options,
+                  bool fp16_enable)
+      : Kernel(inputs, outputs, primitive, ctx), build_options_(build_options), fp16_enable_(fp16_enable) {
+    opencl_runtime_ = new registry::opencl::OpenCLRuntimeWrapper();
+  }
+  ~CustomAddKernel() override {
+    FreeWeight();
+    delete opencl_runtime_;  // created with new in the constructor, so it must be released here
+  }
+  // Prepare will be called during graph compilation.
+  int Prepare() override {
+    const std::string kernel_name = "ElementAdd";
+    const std::string program_name = "Arithmetic";
+    std::string source = arithmetic_source;
+    if (opencl_runtime_->LoadSource(program_name, source) != kSuccess) {
+      std::cerr << "Load source failed.";
+      return lite::RET_ERROR;
+    }
+    std::vector<std::string> build_options_ext = {"-cl-mad-enable -cl-fast-relaxed-math -Werror"};
+    build_options_ext.push_back(build_options_);
+    if (opencl_runtime_->BuildKernel(&kernel_, program_name, kernel_name, build_options_ext) != kSuccess) {
+      std::cerr << "Build kernel failed.";
+      return lite::RET_ERROR;
+    }
+
+    auto out_shape = custom_common::GpuTensorInfo(&outputs_[0], opencl_runtime_);
+    local_range_ = cl::NullRange;
+    global_range_ = cl::NDRange(out_shape.width, out_shape.height);
+    for (size_t i = 0; i < inputs_.size(); ++i) {
+      auto &in_tensor = inputs_.at(i);
+      custom_common::GpuTensorInfo in_shape = custom_common::GpuTensorInfo(&in_tensor, opencl_runtime_);
+      if (in_tensor.IsConst()) {
+        std::vector<char> weight(in_shape.Image2DSize, 0);
+        bool src_is_fp16 = in_tensor.DataType() == mindspore::DataType::kNumberTypeFloat16;
+        PackNHWCToNHWC4(in_tensor.MutableData(), weight.data(), src_is_fp16, fp16_enable_, in_shape,
+                        in_tensor.DataType());
+        DataType dtype =
+          fp16_enable_ ? mindspore::DataType::kNumberTypeFloat16 : mindspore::DataType::kNumberTypeFloat32;
+        auto allocator = opencl_runtime_->GetAllocator();
+        if (allocator == nullptr) {
+          std::cerr << "GetAllocator fail.";
+          FreeWeight();
+          return lite::RET_ERROR;
+        }
+        auto weight_ptr = allocator->Malloc(in_shape.width, in_shape.height, dtype);
+        if (weight_ptr == nullptr) {
+          std::cerr << "Malloc fail.";
+          FreeWeight();
+          return lite::RET_ERROR;
+        }
+        weight_ptrs_.push_back(weight_ptr);
+        // Use the wrapper API to write the packed weight into GPU memory.
+        if (opencl_runtime_->WriteImage(weight_ptr, weight.data()) != kSuccess) {
+          std::cerr << "WriteImage fail.";
+          FreeWeight();
+          return lite::RET_ERROR;
+        }
+      } else {
+        weight_ptrs_.push_back(nullptr);
+      }
+    }
+
+    // Kernel args 0..2 (the input/output images) are bound in Execute; only the shape arg is fixed here.
+    int arg_idx = 3;
+    cl_int2 output_shape{static_cast<int>(global_range_[0]), static_cast<int>(global_range_[1])};
+    if (opencl_runtime_->SetKernelArg(kernel_, arg_idx, output_shape) != kSuccess) {
+      std::cerr << "Set kernel arg " << arg_idx << " failed.";
+      FreeWeight();
+      return lite::RET_ERROR;
+    }
+
+    std::cout << kernel_name << " Init Done!" << std::endl;
+    return lite::RET_OK;
+  }
+
+  // Execute is called to compute.
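+  // A brief sketch of what follows (drawn from the code below): Execute binds kernel args 0 and 1 to
+  // the two input images (or to the packed constant-weight images prepared above), binds arg 2 to the
+  // output image, then enqueues the kernel over global_range_, i.e. the (width, height) of the NHWC4
+  // output image computed in Prepare.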
+  int Execute() override {
+    if (inputs_.size() != 2) {
+      return lite::RET_PARAM_INVALID;
+    }
+    auto pre_ret = PreProcess();
+    if (pre_ret != lite::RET_OK) {
+      return pre_ret;
+    }
+    std::cout << this->name() << " Running!" << std::endl;
+    auto input_0_ptr = weight_ptrs_[0] == nullptr ? inputs_[0].MutableData() : weight_ptrs_[0];
+    auto input_1_ptr = weight_ptrs_[1] == nullptr ? inputs_[1].MutableData() : weight_ptrs_[1];
+    int arg_idx = 0;
+    if (opencl_runtime_->SetKernelArg(kernel_, arg_idx++, input_0_ptr) != kSuccess) {
+      std::cerr << "Set kernel arg " << arg_idx - 1 << " failed.";
+      return lite::RET_ERROR;
+    }
+    if (opencl_runtime_->SetKernelArg(kernel_, arg_idx++, input_1_ptr) != kSuccess) {
+      std::cerr << "Set kernel arg " << arg_idx - 1 << " failed.";
+      return lite::RET_ERROR;
+    }
+    if (opencl_runtime_->SetKernelArg(kernel_, arg_idx++, outputs_[0].MutableData()) != kSuccess) {
+      std::cerr << "Set kernel arg " << arg_idx - 1 << " failed.";
+      return lite::RET_ERROR;
+    }
+    if (opencl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_) != kSuccess) {
+      std::cerr << "Run kernel failed.";
+      return lite::RET_ERROR;
+    }
+
+    return lite::RET_OK;
+  }
+
+  int CheckSpecs() {
+    for (auto &tensor : inputs_) {
+      if (tensor.DataType() != DataType::kNumberTypeFloat32 && tensor.DataType() != DataType::kNumberTypeFloat16) {
+        std::cerr << "ArithmeticOpenCLKernel only supports fp32/fp16 input";
+        return lite::RET_ERROR;
+      }
+    }
+    for (auto &tensor : outputs_) {
+      if (tensor.DataType() != DataType::kNumberTypeFloat32 && tensor.DataType() != DataType::kNumberTypeFloat16) {
+        std::cerr << "ArithmeticOpenCLKernel only supports fp32/fp16 output";
+        return lite::RET_ERROR;
+      }
+    }
+
+    if (inputs_.size() != 2 || outputs_.size() != 1) {
+      std::cerr << "in size: " << inputs_.size() << ", out size: " << outputs_.size();
+      return lite::RET_ERROR;
+    }
+
+    return lite::RET_OK;
+  }
+
+  // ReSize updates the kernel's parameters when the node's input shapes can change between runs.
+  int ReSize() override {
+    if (custom_common::CheckOutputs(outputs_) == lite::RET_OK) {
+      return lite::RET_OK;
+    }
+    auto status =
+      registry::RegisterKernelInterface::GetKernelInterface({}, primitive_)->Infer(&inputs_, &outputs_, primitive_);
+    if (status != kSuccess) {
+      std::cerr << "infer failed." << std::endl;
+      return lite::RET_ERROR;
+    }
+    auto ret = CheckSpecs();
+    if (ret != lite::RET_OK) {
+      std::cerr << "ReSize failed for check kernel specs!";
+      return ret;
+    }
+    ret = Prepare();
+    if (ret != lite::RET_OK) {
+      std::cerr << "ReSize failed for kernel prepare!";
+      return ret;
+    }
+    return lite::RET_OK;
+  }
+
+ private:
+  std::string build_options_;
+  bool fp16_enable_;
+  cl::Kernel kernel_;
+  cl::Event event_;
+  cl::NDRange global_range_{cl::NullRange};
+  cl::NDRange local_range_{cl::NullRange};
+  std::vector<void *> weight_ptrs_;
+  registry::opencl::OpenCLRuntimeWrapper *opencl_runtime_;
+
+  int PreProcess() {
+    int ret;
+    ret = ReSize();
+    if (ret != lite::RET_OK) {
+      return ret;
+    }
+    for (size_t i = 0; i < outputs_.size(); ++i) {
+      auto *output = &outputs_.at(i);
+      auto img_info = custom_common::GpuTensorInfo(output, opencl_runtime_);
+      auto allocator = output->allocator();
+      if (allocator == nullptr) {
+        std::cerr << "The output tensor of an OpenCL kernel must have an allocator.";
+        return lite::RET_ERROR;
+      }
+      auto data_ptr = allocator->Malloc(img_info.width, img_info.height, output->DataType());
+      if (data_ptr == nullptr) {
+        std::cerr << "Malloc data failed";
+        return lite::RET_ERROR;
+      }
+      output->SetData(data_ptr);
+    }
+    return lite::RET_OK;
+  }
+
+  void FreeWeight() {
+    auto allocator = opencl_runtime_->GetAllocator();
+    if (allocator == nullptr) {
+      std::cerr << "GetAllocator fail.";
+      return;
+    }
+    for (auto &weight_ptr : weight_ptrs_) {
+      if (weight_ptr != nullptr) {
+        allocator->Free(weight_ptr);
+        weight_ptr = nullptr;
+      }
+    }
+  }
+};
+
+std::shared_ptr<kernel::Kernel> CustomAddCreator(const std::vector<mindspore::MSTensor> &inputs,
+                                                 const std::vector<mindspore::MSTensor> &outputs,
+                                                 const schema::Primitive *primitive, const mindspore::Context *ctx) {
+  const std::string build_options = " -DFLT4=float4 -DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef ";
+  bool fp16_enable = false;
+
+  std::cout << "using fp32 add." << std::endl;
+  return std::make_shared<CustomAddKernel>(inputs, outputs, primitive, ctx, build_options, fp16_enable);
+}
+
+std::shared_ptr<kernel::Kernel> CustomAddFP16Creator(const std::vector<mindspore::MSTensor> &inputs,
+                                                     const std::vector<mindspore::MSTensor> &outputs,
+                                                     const schema::Primitive *primitive,
+                                                     const mindspore::Context *ctx) {
+  const std::string build_options = " -DFLT4=half4 -DWRITE_IMAGE=write_imageh -DREAD_IMAGE=read_imageh";
+  bool fp16_enable = true;
+
+  std::cout << "using fp16 add." << std::endl;
+  return std::make_shared<CustomAddKernel>(inputs, outputs, primitive, ctx, build_options, fp16_enable);
+}
+
+}  // namespace custom_gpu_demo
+const auto kFloat32 = DataType::kNumberTypeFloat32;
+const auto kFloat16 = DataType::kNumberTypeFloat16;
+// Register the custom "Custom_Add" operator.
+REGISTER_CUSTOM_KERNEL(GPU, Tutorial, kFloat32, Custom_Add, custom_gpu_demo::CustomAddCreator)
+REGISTER_CUSTOM_KERNEL(GPU, Tutorial, kFloat16, Custom_Add, custom_gpu_demo::CustomAddFP16Creator)
+using schema::PrimitiveType_AddFusion;
+// Register the add operator to replace the built-in add operator of MindSpore Lite.
+REGISTER_KERNEL(GPU, Tutorial, kFloat32, PrimitiveType_AddFusion, custom_gpu_demo::CustomAddCreator)
+REGISTER_KERNEL(GPU, Tutorial, kFloat16, PrimitiveType_AddFusion, custom_gpu_demo::CustomAddFP16Creator)
+}  // namespace mindspore
diff --git a/mindspore/lite/examples/runtime_gpu_extend/src/custom_common.cc b/mindspore/lite/examples/runtime_gpu_extend/src/custom_common.cc
new file mode 100644
index 00000000000..8201575d74b
--- /dev/null
+++ b/mindspore/lite/examples/runtime_gpu_extend/src/custom_common.cc
@@ -0,0 +1,76 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/custom_common.h"
+
+namespace mindspore {
+namespace custom_common {
+int CheckInputs(const std::vector<mindspore::MSTensor> &inputs) {
+  for (auto &input : inputs) {
+    auto input_shape = input.Shape();
+    if (std::find(input_shape.begin(), input_shape.end(), -1) != input_shape.end()) {
+      return lite::RET_INFER_INVALID;
+    }
+  }
+  return lite::RET_OK;
+}
+
+int CheckOutputs(const std::vector<mindspore::MSTensor> &outputs) {
+  for (auto &output : outputs) {
+    auto output_shape = output.Shape();
+    if (std::find(output_shape.begin(), output_shape.end(), -1) != output_shape.end()) {
+      return lite::RET_INFER_INVALID;
+    }
+  }
+  return lite::RET_OK;
+}
+
+void PackNHWCToNHWC4(void *src, void *dst, bool src_is_fp16, bool dst_is_fp16, const GpuTensorInfo &tensor,
+                     mindspore::DataType data_type) {
+  auto src_fp16 = reinterpret_cast<float16_t *>(src);
+  auto src_fp32 = reinterpret_cast<float *>(src);
+  auto src_int32 = reinterpret_cast<int32_t *>(src);
+  auto dst_fp16 = reinterpret_cast<float16_t *>(dst);
+  auto dst_fp32 = reinterpret_cast<float *>(dst);
+  auto dst_int32 = reinterpret_cast<int32_t *>(dst);
+  for (int n = 0, src_idx = 0; n < tensor.N; n++) {
+    for (int h = 0; h < tensor.H; ++h) {
+      for (int w = 0; w < tensor.W; ++w) {
+        for (int c = 0; c < tensor.C; ++c, ++src_idx) {
+          int dst_idx = ((n * tensor.H + h) * tensor.W + w) * tensor.Slice * C4NUM + c;
+          if (data_type == mindspore::DataType::kNumberTypeInt32) {
+            dst_int32[dst_idx] = src_int32[src_idx];
+          } else if (dst_is_fp16) {
+            dst_fp16[dst_idx] = src_is_fp16 ? src_fp16[src_idx] : static_cast<float16_t>(src_fp32[src_idx]);
+          } else {
+            dst_fp32[dst_idx] = src_is_fp16 ? static_cast<float>(src_fp16[src_idx]) : src_fp32[src_idx];
+          }
+        }
+      }
+    }
+  }
+  // scalar: broadcast the single value into all four channels of its FLT4 slot
+  if (tensor.ElementsNum == 1) {
+    if (dst_is_fp16) {
+      dst_fp16[3] = dst_fp16[2] = dst_fp16[1] = dst_fp16[0];
+    } else {
+      dst_fp32[3] = dst_fp32[2] = dst_fp32[1] = dst_fp32[0];
+    }
+  }
+}
+
+}  // namespace custom_common
+}  // namespace mindspore
diff --git a/mindspore/lite/examples/runtime_gpu_extend/src/custom_common.h b/mindspore/lite/examples/runtime_gpu_extend/src/custom_common.h
new file mode 100644
index 00000000000..3a53ffa90b6
--- /dev/null
+++ b/mindspore/lite/examples/runtime_gpu_extend/src/custom_common.h
@@ -0,0 +1,130 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_EXAMPLES_RUNTIME_GPU_EXTEND_SRC_CUSTOM_COMMON_H
+#define MINDSPORE_LITE_EXAMPLES_RUNTIME_GPU_EXTEND_SRC_CUSTOM_COMMON_H
+
+#include <algorithm>
+#include <iostream>
+#include <vector>
+#include "include/api/types.h"
+#include "include/errorcode.h"
+#include "include/ms_tensor.h"
+#include "include/api/data_type.h"
+#include "include/registry/opencl_runtime_wrapper.h"
+
+#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
+#define C4NUM 4
+namespace mindspore {
+namespace custom_common {
+
+template <typename DstT, typename SrcT>
+void Broadcast2GpuShape(DstT *dst, const SrcT *src, int src_num) {
+  if (src == nullptr || src_num <= 0) {
+    return;
+  }
+  auto *N = dst;
+  auto *H = dst + 1;
+  auto *W = dst + 2;
+  auto *C = dst + 3;
+  if (src_num == 1) {  // 1 1 1 C
+    *C = src[0];
+  } else if (src_num == 2) {  // N 1 1 C
+    *N = src[0];
+    *C = src[1];
+  } else if (src_num == 3) {  // N 1 W C
+    *N = src[0];
+    *W = src[1];
+    *C = src[2];
+  } else if (src_num == 4) {  // N H W C
+    *N = src[0];
+    *H = src[1];
+    *W = src[2];
+    *C = src[3];
+  } else if (src_num > 4) {
+    std::cerr << "GPU doesn't support ndim>=" << src_num;
+  }
+}
+
+template <typename DstT, typename SrcT>
+void Broadcast2GpuShape(DstT *dst, const SrcT *src, int src_num, DstT default_value) {
+  for (int i = 0; i < 4; ++i) {
+    dst[i] = default_value;
+  }
+  if (src == nullptr || src_num <= 0) {
+    return;
+  }
+  Broadcast2GpuShape(dst, src, src_num);
+}
+
+struct GpuTensorInfo {
+  GpuTensorInfo() = default;
+  explicit GpuTensorInfo(const MSTensor *tensor, registry::opencl::OpenCLRuntimeWrapper *opencl_run) {
+    if (tensor == nullptr) {
+      return;
+    }
+    auto shape_ori = tensor->Shape();
+    int64_t shape[4];
+    Broadcast2GpuShape(shape, shape_ori.data(), shape_ori.size(), 1l);
+    N = shape[0];
+    H = shape[1];
+    W = shape[2];
+    C = shape[3];
+    Slice = UP_DIV(C, C4NUM);
+    if (tensor->DataType() == mindspore::DataType::kNumberTypeFloat16) {
+      FLT_size = sizeof(cl_half);
+    } else {
+      FLT_size = sizeof(cl_float);
+    }
+    FLT4_size = FLT_size * C4NUM;
+    if (W * Slice <= opencl_run->GetMaxImage2DWidth()) {
+      height = N * H;
+      width = W * Slice;
+    } else {
+      height = N * H * W;
+      width = Slice;
+      if (height > opencl_run->GetMaxImage2DHeight()) {
+        height = -1;
+        width = -1;
+      }
+    }
+
+    ElementsNum = N * H * W * C;
+    Image2DSize = height * width * FLT4_size;
+  }
+  size_t N{1};
+  size_t H{1};
+  size_t W{1};
+  size_t C{1};
+  size_t Slice{};
+  size_t width{};
+  size_t height{};
+  size_t FLT_size{4};
+  size_t FLT4_size{16};
+  size_t ElementsNum{};
+  size_t Image2DSize{};
+};
+// Verify that the inputs' shapes were inferred successfully before inferring the current node.
+int CheckInputs(const std::vector<mindspore::MSTensor> &inputs);
+
+// Verify that the outputs' shapes were inferred successfully before running the current node.
+int CheckOutputs(const std::vector<mindspore::MSTensor> &outputs);
+void PackNHWCToNHWC4(void *src, void *dst, bool src_is_fp16, bool dst_is_fp16, const GpuTensorInfo &tensor,
+                     mindspore::DataType data_type = mindspore::DataType::kNumberTypeFloat32);
+}  // namespace custom_common
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_EXAMPLES_RUNTIME_GPU_EXTEND_SRC_CUSTOM_COMMON_H
diff --git a/mindspore/lite/include/ms_tensor.h b/mindspore/lite/include/ms_tensor.h
index 28da1378ff6..3035422123b 100644
--- a/mindspore/lite/include/ms_tensor.h
+++ b/mindspore/lite/include/ms_tensor.h
@@ -123,6 +123,11 @@ class MS_API MSTensor {
   virtual Vector<QuantArg> quant_params() const = 0;
   virtual void set_quant_params(Vector<QuantArg>) = 0;
+
+  /// \brief Get whether the MSTensor data is const data.
+  ///
+  /// \return Const flag of MSTensor.
+  virtual bool IsConst() const = 0;
 };
 }  // namespace tensor
 }  // namespace mindspore
diff --git a/mindspore/lite/include/registry/opencl_runtime_wrapper.h b/mindspore/lite/include/registry/opencl_runtime_wrapper.h
new file mode 100644
index 00000000000..fdb00060e37
--- /dev/null
+++ b/mindspore/lite/include/registry/opencl_runtime_wrapper.h
@@ -0,0 +1,119 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
+#define MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <vector>
+#include "CL/cl2.hpp"
+#include "include/api/allocator.h"
+#include "include/api/status.h"
+
+namespace mindspore::registry::opencl {
+class OpenCLRuntimeWrapper {
+ public:
+  OpenCLRuntimeWrapper() = default;
+  ~OpenCLRuntimeWrapper() = default;
+
+  /// \brief Load the OpenCL source code and bind it to the given program name.
+  ///
+  /// \param[in] program_name Define the OpenCL source program name.
+  /// \param[in] source Define the OpenCL source.
+  ///
+  /// \return Status indicating whether the source was loaded successfully.
+  Status LoadSource(const std::string &program_name, const std::string &source);
+
+  /// \brief Build an OpenCL kernel from previously loaded source.
+  ///
+  /// \param[in] kernel Used to return the compiled kernel.
+  /// \param[in] program_name Define the OpenCL source program name.
+  /// \param[in] kernel_name Define the OpenCL kernel name.
+  /// \param[in] build_options_ext Define extra OpenCL kernel build options.
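+  /// \note A minimal usage sketch (program, kernel and option values taken from the bundled
+  /// custom-add demo, not a fixed contract):
+  /// \code
+  ///   registry::opencl::OpenCLRuntimeWrapper wrapper;
+  ///   cl::Kernel kernel;
+  ///   (void)wrapper.LoadSource("Arithmetic", arithmetic_source);
+  ///   (void)wrapper.BuildKernel(&kernel, "Arithmetic", "ElementAdd",
+  ///                             {"-DFLT4=float4 -DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef"});
+  /// \endcode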
+  ///
+  /// \return Status indicating whether the kernel was built successfully.
+  Status BuildKernel(cl::Kernel *kernel, const std::string &program_name, const std::string &kernel_name,
+                     const std::vector<std::string> &build_options_ext = {});
+
+  /// \brief Set a pointer (memory object) kernel argument.
+  ///
+  /// \param[in] kernel Define the OpenCL kernel.
+  /// \param[in] index Define the OpenCL kernel argument index.
+  /// \param[in] value Define the OpenCL kernel argument value pointer.
+  ///
+  /// \return Status indicating whether the argument was set successfully.
+  Status SetKernelArg(const cl::Kernel &kernel, uint32_t index, void *const value);
+
+  /// \brief Set a non-pointer (by-value) kernel argument.
+  ///
+  /// \param[in] kernel Define the OpenCL kernel.
+  /// \param[in] index Define the OpenCL kernel argument index.
+  /// \param[in] value Define the OpenCL kernel argument value.
+  ///
+  /// \return Status indicating whether the argument was set successfully.
+  template <typename T>
+  typename std::enable_if<!std::is_pointer<T>::value, Status>::type SetKernelArg(const cl::Kernel &kernel,
+                                                                                 uint32_t index, const T value) {
+    if (const_cast<cl::Kernel &>(kernel).setArg(index, value) != CL_SUCCESS) {
+      return kLiteError;
+    } else {
+      return kSuccess;
+    }
+  }
+
+  /// \brief Run an OpenCL kernel.
+  ///
+  /// \param[in] kernel Define the OpenCL kernel.
+  /// \param[in] global Define the number of work items.
+  /// \param[in] local Define the number of work items in a work group.
+  /// \param[in] command_queue Define the command queue.
+  /// \param[in] event Define the event that tracks the kernel run.
+  ///
+  /// \return Status indicating whether the kernel ran successfully.
+  Status RunKernel(const cl::Kernel &kernel, const cl::NDRange &global, const cl::NDRange &local,
+                   cl::CommandQueue *command_queue = nullptr, cl::Event *event = nullptr);
+
+  /// \brief Synchronize the command queue.
+  ///
+  /// \return Status indicating whether the synchronization succeeded.
+  Status SyncCommandQueue();
+
+  void *MapBuffer(void *host_ptr, int flags, bool sync = true);
+
+  Status UnmapBuffer(void *host_ptr);
+
+  Status ReadImage(void *buffer, void *dst_data);
+
+  Status WriteImage(void *buffer, void *src_data);
+
+  std::shared_ptr<Allocator> GetAllocator();
+
+  uint64_t DeviceMaxWorkGroupSize();
+
+  uint64_t GetMaxImage2DWidth();
+
+  uint64_t GetMaxImage2DHeight();
+
+  uint64_t GetImagePitchAlignment();
+};
+}  // namespace mindspore::registry::opencl
+#endif  // MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
diff --git a/mindspore/lite/micro/coder/generator/component/const_blocks/mtensor.cc b/mindspore/lite/micro/coder/generator/component/const_blocks/mtensor.cc
index 31bf11496fc..cea8a272a5f 100644
--- a/mindspore/lite/micro/coder/generator/component/const_blocks/mtensor.cc
+++ b/mindspore/lite/micro/coder/generator/component/const_blocks/mtensor.cc
@@ -78,6 +78,7 @@ class MTensor : public mindspore::tensor::MSTensor {
   void set_data(void *data) override { data_ = data; }
   Vector<QuantArg> quant_params() const override { return this->quant_params_; }
   void set_quant_params(const Vector<QuantArg> quant_params) override { this->quant_params_ = quant_params; }
+  bool IsConst() const override { return this->data_ != nullptr; }
 
  private:
   String tensor_name_;
diff --git a/mindspore/lite/src/cxx_api/tensor/tensor_impl.h b/mindspore/lite/src/cxx_api/tensor/tensor_impl.h
index 39de87c31d2..6a4c3f625ab 100644
--- a/mindspore/lite/src/cxx_api/tensor/tensor_impl.h
+++ b/mindspore/lite/src/cxx_api/tensor/tensor_impl.h
@@ -181,6 +181,13 @@ class
MSTensor::Impl { } return lite_tensor_->MutableData(); } + virtual bool IsConst() const { + if (lite_tensor_ == nullptr) { + MS_LOG(ERROR) << "Invalid tensor."; + return false; + } + return lite_tensor_->IsConst(); + } virtual size_t DataSize() const { if (lite_tensor_ == nullptr) { diff --git a/mindspore/lite/src/cxx_api/types.cc b/mindspore/lite/src/cxx_api/types.cc index aefc768293c..4344700de81 100644 --- a/mindspore/lite/src/cxx_api/types.cc +++ b/mindspore/lite/src/cxx_api/types.cc @@ -259,6 +259,14 @@ void *MSTensor::MutableData() { return impl_->MutableData(); } +bool MSTensor::IsConst() const { + if (impl_ == nullptr) { + MS_LOG(ERROR) << "Invalid tensor implement."; + return false; + } + return impl_->IsConst(); +} + size_t MSTensor::DataSize() const { if (impl_ == nullptr) { MS_LOG(ERROR) << "Invalid tensor implement."; diff --git a/mindspore/lite/src/inner_context.cc b/mindspore/lite/src/inner_context.cc index 9887f30d3e4..b252ab887ce 100644 --- a/mindspore/lite/src/inner_context.cc +++ b/mindspore/lite/src/inner_context.cc @@ -215,7 +215,7 @@ bool InnerContext::IsGpuFloat16Enabled() const { if (!IsGpuEnabled()) { return false; } - opencl::OpenCLRuntimeWrapper wrapper; + opencl::OpenCLRuntimeInnerWrapper wrapper; if (!wrapper.GetInstance()->GetFp16Enable()) { return false; } diff --git a/mindspore/lite/src/kernel_registry.cc b/mindspore/lite/src/kernel_registry.cc index 43c2b477d8e..2eaa594aeae 100644 --- a/mindspore/lite/src/kernel_registry.cc +++ b/mindspore/lite/src/kernel_registry.cc @@ -47,6 +47,7 @@ namespace mindspore::lite { #ifndef CUSTOM_KERNEL_REGISTRY_CLIP namespace { const char *const kArchCPU = "CPU"; +const char *const kArchGPU = "GPU"; void KernelKeyToKernelDesc(const KernelKey &key, KernelDesc *desc) { MS_ASSERT(desc != nullptr); desc->data_type = static_cast(key.data_type); @@ -159,6 +160,8 @@ int KernelRegistry::GetCustomKernel(const std::vector &in_tensors, con kernel::KernelKey tmp_key = key; if (desc.arch == kArchCPU) { tmp_key.arch = kernel::kCPU; + } else if (desc.arch == kArchGPU) { + tmp_key.arch = kernel::kGPU; } else { tmp_key.arch = kernel::kCustom; } diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index 8385209e572..405e829c1ff 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -133,7 +133,7 @@ class LiteKernel { } return mindspore::lite::RET_OK; } - + bool IsBuiltin() { return desc_.provider == kBuiltin; } virtual int ReSize() { MS_ASSERT(kernel_ != nullptr); return kernel_->ReSize(); diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index d18464a805a..1e5c204f2ef 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -962,9 +962,9 @@ int LiteSession::InitGPURuntime() { } #if GPU_OPENCL if (this->context_->IsGpuEnabled()) { - opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper(); + opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeInnerWrapper(); if (opencl_runtime_wrapper_ == nullptr) { - MS_LOG(ERROR) << "create OpenCLRuntimeWrapper failed"; + MS_LOG(ERROR) << "create OpenCLRuntimeInnerWrapper failed"; return RET_ERROR; } auto gpu_device_info = this->context_->GetGpuInfo(); diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h index 5f875bc87ce..bce45f41d1a 100644 --- a/mindspore/lite/src/lite_session.h +++ b/mindspore/lite/src/lite_session.h @@ -155,7 +155,7 @@ class LiteSession : public session::LiteSession { bool is_train_session_ = 
false; friend class TransferSession; #if GPU_OPENCL - opencl::OpenCLRuntimeWrapper *opencl_runtime_wrapper_{nullptr}; + opencl::OpenCLRuntimeInnerWrapper *opencl_runtime_wrapper_{nullptr}; #endif std::unique_ptr sched_cb_; std::shared_ptr delegate_ = nullptr; diff --git a/mindspore/lite/src/registry/register_kernel_impl.h b/mindspore/lite/src/registry/register_kernel_impl.h index 37edb6f7421..2bcd8211a38 100644 --- a/mindspore/lite/src/registry/register_kernel_impl.h +++ b/mindspore/lite/src/registry/register_kernel_impl.h @@ -50,6 +50,7 @@ class RegistryKernelImpl { protected: std::map> kernel_creators_; + // keys:provider, arch, type std::map>> custom_kernel_creators_; diff --git a/mindspore/lite/src/runtime/gpu/opencl/opencl_allocator.cc b/mindspore/lite/src/runtime/gpu/opencl/opencl_allocator.cc index 18cfbd73011..5bf8e1bc9ba 100644 --- a/mindspore/lite/src/runtime/gpu/opencl/opencl_allocator.cc +++ b/mindspore/lite/src/runtime/gpu/opencl/opencl_allocator.cc @@ -94,8 +94,8 @@ void *OpenCLAllocator::CreateBuffer(size_t size, void *data, size_t flags, cl::B return host_ptr; } -void *OpenCLAllocator::CreateImage2D(size_t size, const ImageSize &img_size, void *data, size_t flags, bool is_map, - cl::Buffer **buffer, cl::Image2D **image) { +int OpenCLAllocator::CreateImage2D(size_t size, const ImageSize &img_size, void *data, size_t flags, bool is_map, + cl::Buffer **buffer, cl::Image2D **image, void **host_ptr) { cl_int ret = CL_SUCCESS; MS_ASSERT(buffer); MS_ASSERT(image); @@ -114,7 +114,7 @@ void *OpenCLAllocator::CreateImage2D(size_t size, const ImageSize &img_size, voi delete *buffer; *buffer = nullptr; MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << mindspore::kernel::CLErrorCode(ret) << ")"; - return nullptr; + return RET_ERROR; } if (ret != CL_SUCCESS) { delete *buffer; @@ -122,28 +122,28 @@ void *OpenCLAllocator::CreateImage2D(size_t size, const ImageSize &img_size, voi *buffer = nullptr; *image = nullptr; MS_LOG(ERROR) << "Create OpenCL Image2D (ERROR CODE: " << mindspore::kernel::CLErrorCode(ret) << ")"; - return nullptr; + return RET_ERROR; } MS_LOG(DEBUG) << "Malloc a new Image2D, width=" << img_size.width << ", height=" << img_size.height; - void *host_ptr = nullptr; + if (is_map) { std::vector region{img_size.width, img_size.height, 1}; - host_ptr = ocl_runtime_->MapBuffer(**image, true, CL_MAP_READ | CL_MAP_WRITE, region); - if (host_ptr == nullptr) { + *host_ptr = ocl_runtime_->MapBuffer(**image, true, CL_MAP_READ | CL_MAP_WRITE, region); + if (*host_ptr == nullptr) { delete *buffer; delete *image; *buffer = nullptr; *image = nullptr; - MS_LOG(ERROR) << "Map image failed, can not found image :" << *image << ", host_ptr=" << host_ptr; - return nullptr; + MS_LOG(ERROR) << "Map image failed, can not found image :" << *image << ", host_ptr=" << *host_ptr; + return RET_ERROR; } cl::Memory *mem = *image; - ret = ocl_runtime_->UnmapBuffer(*mem, host_ptr); + ret = ocl_runtime_->UnmapBuffer(*mem, *host_ptr); if (ret != CL_SUCCESS) { MS_LOG(WARNING) << "UnmapBuffer failed."; } } - return host_ptr; + return RET_OK; } int OpenCLAllocator::GetImgDtypeSize(const ImageSize &img_size) { @@ -165,6 +165,34 @@ int OpenCLAllocator::GetImgDtypeSize(const ImageSize &img_size) { return size; } +void *OpenCLAllocator::Malloc(size_t weight, size_t height, DataType type) { + ImageSize img_size = {weight, height}; + switch (type) { + case DataType::kNumberTypeFloat32: + img_size.dtype = CL_FLOAT; + break; + case DataType::kNumberTypeFloat16: + img_size.dtype = CL_HALF_FLOAT; + break; 
+ case DataType::kNumberTypeInt8: + img_size.dtype = CL_SIGNED_INT8; + break; + case DataType::kNumberTypeUInt8: + img_size.dtype = CL_UNSIGNED_INT8; + break; + case DataType::kNumberTypeInt32: + img_size.dtype = CL_SIGNED_INT32; + break; + case DataType::kNumberTypeUInt32: + img_size.dtype = CL_UNSIGNED_INT32; + break; + default: + MS_LOG(ERROR) << "Unsupported type " << static_cast(type); + return nullptr; + } + return _Malloc(MemType::IMG, nullptr, 0, img_size); +} + void *OpenCLAllocator::_Malloc(MemType mem_type, void *data, size_t size, const ImageSize &img_size) { auto svm_capabilities = ocl_runtime_->GetSVMCapabilities(); auto enable_arm_import_memory = ocl_runtime_->isExtensionEnable(EXT_ARM_IMPORT_MEMORY_HOST); @@ -208,9 +236,8 @@ void *OpenCLAllocator::_Malloc(MemType mem_type, void *data, size_t size, const UNLOCK_AND_RETURN_NULL(host_ptr == nullptr, nullptr); } if (mem_type == MemType::IMG) { - void *host_ptr_im = CreateImage2D(size, img_size, data, flags, data != nullptr, &buffer, &image); - UNLOCK_AND_RETURN_NULL(data != nullptr && host_ptr_im == nullptr, nullptr); - host_ptr = (data != nullptr) ? host_ptr_im : host_ptr; + auto ret = CreateImage2D(size, img_size, data, flags, data != nullptr, &buffer, &image, &host_ptr); + UNLOCK_AND_RETURN_NULL(ret != RET_OK, nullptr); } } } @@ -345,17 +372,25 @@ size_t OpenCLAllocator::total_size() { return totalSize; } -void *OpenCLAllocator::GetImage(void *buffer) { +cl::Image2D *OpenCLAllocator::GetImage(void *buffer) { auto it = allocated_list_.find(buffer); if (it != allocated_list_.end()) { - return it->second->image_ptr_; + if (it->second->mem_type_ != MemType::IMG) { + return nullptr; + } + return reinterpret_cast(it->second->image_ptr_); } return nullptr; } -void *OpenCLAllocator::GetBuffer(void *buffer) { +void *OpenCLAllocator::GetOpenclMemPtr(void *buffer, MemType *type, bool force_buffer) { auto it = allocated_list_.find(buffer); if (it != allocated_list_.end()) { + if ((it->second->mem_type_ == MemType::IMG) && !force_buffer) { + *type = MemType::IMG; + return it->second->image_ptr_; + } + *type = MemType::BUF; return it->second->device_ptr_; } return nullptr; diff --git a/mindspore/lite/src/runtime/gpu/opencl/opencl_allocator.h b/mindspore/lite/src/runtime/gpu/opencl/opencl_allocator.h index 2f061932c3f..3363192c279 100644 --- a/mindspore/lite/src/runtime/gpu/opencl/opencl_allocator.h +++ b/mindspore/lite/src/runtime/gpu/opencl/opencl_allocator.h @@ -28,6 +28,8 @@ #include "CL/cl2.hpp" namespace mindspore::lite::opencl { +// OpenCL memory type, SHARED only valid on Mali devices. 
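+// As used by this allocator (an observation, not a spec): BUF is backed by a cl::Buffer, IMG by a
+// cl::Image2D, and SHARED maps one allocation for both host and device via the ARM host-import path.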
+enum class MemType : char { BUF, IMG, SHARED }; #define UNLOCK_AND_RETURN_NULL(condition, ptr) \ do { \ if (condition) { \ @@ -37,7 +39,6 @@ namespace mindspore::lite::opencl { } while (0) class OpenCLRuntime; -enum class MemType : char { BUF, IMG, SHARED }; struct ImageSize { size_t width = 0; @@ -57,6 +58,7 @@ class OpenCLAllocator : public mindspore::Allocator { // malloc shared void *Malloc(size_t size) override { return _Malloc(MemType::SHARED, nullptr, size); } + void *Malloc(size_t weight, size_t height, DataType type) override; // malloc buffer void *Malloc(size_t size, void *data) { return _Malloc(MemType::BUF, data, size); } // malloc image @@ -69,8 +71,8 @@ class OpenCLAllocator : public mindspore::Allocator { size_t total_size(); void Clear(); - void *GetImage(void *host_ptr); - void *GetBuffer(void *host_ptr); + cl::Image2D *GetImage(void *host_ptr); + void *GetOpenclMemPtr(void *buffer, MemType *type, bool force_buffer = false); void *MapBuffer(void *host_ptr, int flags, void *command_queue = nullptr, bool sync = true); int UnmapBuffer(void *host_ptr, void *command_queue = nullptr); MemType GetMemType(void *host_ptr); @@ -88,8 +90,8 @@ class OpenCLAllocator : public mindspore::Allocator { void *MinimumFit(MemType mem_type, size_t size, const ImageSize &img_size); void *_Malloc(MemType mem_type, void *data, size_t size = 0, const ImageSize &img_size = ImageSize()); void *CreateBuffer(size_t size, void *data, size_t flags, cl::Buffer **buffer); - void *CreateImage2D(size_t size, const ImageSize &img_size, void *data, size_t flags, bool is_map, - cl::Buffer **buffer, cl::Image2D **image); + int CreateImage2D(size_t size, const ImageSize &img_size, void *data, size_t flags, bool is_map, cl::Buffer **buffer, + cl::Image2D **image, void **host_ptr); int GetImgDtypeSize(const ImageSize &img_size); template void ClearMemList(T *list); diff --git a/mindspore/lite/src/runtime/gpu/opencl/opencl_executor.cc b/mindspore/lite/src/runtime/gpu/opencl/opencl_executor.cc index 6b9143866dc..9ee61514d5f 100644 --- a/mindspore/lite/src/runtime/gpu/opencl/opencl_executor.cc +++ b/mindspore/lite/src/runtime/gpu/opencl/opencl_executor.cc @@ -23,6 +23,9 @@ namespace mindspore::lite::opencl { int OpenCLExecutor::Run(const std::vector &inputs, const std::vector &outputs, const std::vector &kernels, const KernelCallBack &before, const KernelCallBack &after) { + if (before != nullptr && after != nullptr) { + ocl_runtime_.GetInstance()->SetProfiling(true); + } return RunOrTune(inputs, outputs, kernels, before, after, false); } @@ -30,10 +33,7 @@ int OpenCLExecutor::RunOrTune(const std::vector &inputs, const std::ve const std::vector &kernels, const KernelCallBack &before, const KernelCallBack &after, bool is_tune) { int ret{RET_OK}; - auto opencl_runtime_ins = ocl_runtime.GetInstance(); - if (before != nullptr && after != nullptr) { - opencl_runtime_ins->SetProfiling(true); - } + auto opencl_runtime_ins = ocl_runtime_.GetInstance(); auto profiling_tmp = opencl_runtime_ins->isProfiling(); if (is_tune) { opencl_runtime_ins->SetProfiling(true); @@ -43,12 +43,10 @@ int OpenCLExecutor::RunOrTune(const std::vector &inputs, const std::ve GPUCallBackParam callbackParam; callbackParam.node_name = kernel->name(); callbackParam.node_type = kernel->type_str(); - if (before != nullptr) { - if (!before(TensorVectorCast(kernel->in_tensors()), TensorVectorCast(kernel->out_tensors()), callbackParam)) { - MS_LOG(ERROR) << "run kernel before_callback failed, name: " << kernel->name(); - } + if ((before != nullptr) && + 
!before(TensorVectorCast(kernel->in_tensors()), TensorVectorCast(kernel->out_tensors()), callbackParam)) { + MS_LOG(ERROR) << "run kernel before_callback failed, name: " << kernel->name(); } - auto *op_kernel = reinterpret_cast(kernel->kernel()); // Don't support ZeroShape for (auto tensor : kernel->out_tensors()) { for (size_t i = 0; i < tensor->shape().size(); i++) { @@ -58,38 +56,58 @@ int OpenCLExecutor::RunOrTune(const std::vector &inputs, const std::ve } } } - if (is_tune) { - ret = op_kernel->PreProcess(); - if (RET_OK != ret) { - MS_LOG(WARNING) << "PreProcess kernel failed, name: " << kernel->name() << " in tuning"; - opencl_runtime_ins->SetProfiling(profiling_tmp); - return RET_OK; - } - ret = op_kernel->Tune(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "tuning kernel failed, name: " << kernel->name(); - return ret; + if (kernel->IsBuiltin()) { + auto *op_kernel = reinterpret_cast(kernel->kernel()); + + if (is_tune) { + ret = Tune(op_kernel); + if (ret != RET_OK) { + opencl_runtime_ins->SetProfiling(profiling_tmp); + return RET_OK; + } + } else { + ret = kernel->Execute(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name(); + return ret; + } + if (profiling_tmp) { + auto execute_time = op_kernel->GetProfilingTimeMs(); + MS_LOG(INFO) << "OpenCl kernel " << kernel->name() << "(" << kernel->type_str() + << ") execute time is: " << op_kernel->GetProfilingTimeMs() << "ms"; + callbackParam.execute_time = execute_time; + } } } else { - ret = kernel->Execute(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name(); - return ret; - } - if (profiling_tmp) { - auto execute_time = op_kernel->GetProfilingTimeMs(); - MS_LOG(INFO) << "OpenCl kernel " << kernel->name() << "(" << kernel->type_str() - << ") execute time is: " << op_kernel->GetProfilingTimeMs() << "ms"; - callbackParam.execute_time = execute_time; + if (!is_tune) { + ret = kernel->Execute(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name(); + return ret; + } } } - if (after != nullptr) { - if (!after(TensorVectorCast(kernel->in_tensors()), TensorVectorCast(kernel->out_tensors()), callbackParam)) { - MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->name(); - } + + if ((after != nullptr) && + !after(TensorVectorCast(kernel->in_tensors()), TensorVectorCast(kernel->out_tensors()), callbackParam)) { + MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->name(); } } opencl_runtime_ins->SetProfiling(profiling_tmp); return ret; } + +int OpenCLExecutor::Tune(kernel::OpenCLKernel *op_kernel) { + auto ret = op_kernel->PreProcess(); + if (ret != RET_OK) { + MS_LOG(WARNING) << "PreProcess kernel failed, name: " << op_kernel->name() << " in tuning"; + return ret; + } + ret = op_kernel->Tune(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "tuning kernel failed, name: " << op_kernel->name(); + return ret; + } + return RET_OK; +} } // namespace mindspore::lite::opencl diff --git a/mindspore/lite/src/runtime/gpu/opencl/opencl_executor.h b/mindspore/lite/src/runtime/gpu/opencl/opencl_executor.h index 4563b78abec..4d902a8946d 100644 --- a/mindspore/lite/src/runtime/gpu/opencl/opencl_executor.h +++ b/mindspore/lite/src/runtime/gpu/opencl/opencl_executor.h @@ -27,7 +27,7 @@ namespace mindspore::lite::opencl { class OpenCLExecutor : public Executor { public: - OpenCLExecutor() : Executor() { allocator_ = ocl_runtime.GetInstance()->GetAllocator().get(); } + OpenCLExecutor() : Executor() { allocator_ = 
ocl_runtime_.GetInstance()->GetAllocator().get(); } ~OpenCLExecutor() override = default; @@ -43,10 +43,10 @@ class OpenCLExecutor : public Executor { const std::vector<kernel::LiteKernel *> &kernels, const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr, bool is_tune = false); - protected: - InnerContext *context = nullptr; + private: + int Tune(kernel::OpenCLKernel *op_kernel); OpenCLAllocator *allocator_ = nullptr; - OpenCLRuntimeWrapper ocl_runtime; + OpenCLRuntimeInnerWrapper ocl_runtime_; }; } // namespace mindspore::lite::opencl #endif diff --git a/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime.cc b/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime.cc index 94150ff5468..248cfcbf541 100644 --- a/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime.cc +++ b/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime.cc @@ -204,6 +204,9 @@ int OpenCLRuntime::InitQueue(std::vector<cl::Platform> *platforms) { 0}; context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, ctx_properties.data(), nullptr, nullptr, &ret); + if (context_ == nullptr || ret != CL_SUCCESS) { + context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret); + } #else context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret); #endif @@ -334,7 +337,7 @@ cl::Device *OpenCLRuntime::Device() { return device_; } uint64_t OpenCLRuntime::DeviceGlobalMemoryCacheSize() const { return global_memery_cachesize_; } -int OpenCLRuntime::DeviceMaxWorkGroupSize() const { return max_work_group_size_; } +uint64_t OpenCLRuntime::DeviceMaxWorkGroupSize() const { return max_work_group_size_; } uint32_t OpenCLRuntime::DeviceComputeUnits() const { return compute_units_; } @@ -382,18 +385,24 @@ bool OpenCLRuntime::SetFp16Enable(bool enable) { } int OpenCLRuntime::BuildKernel(const cl::Kernel &kernel, const std::string &program_name, - const std::string &kernel_name, const std::vector<std::string> &build_options_ext) { - std::string build_option = default_build_option_; - if (fp16_enable_) { - build_option += - " -DFP16_ENABLE=1 -DFLT=half -DFLT4=half4 -DFLT16=half16 -DAS_FLT4=as_half4 -DAS_UINT4=as_ushort4 -DUINT4=ushort4" - " -DTO_FLT=convert_half -DTO_FLT4=convert_half4"; - } else { - build_option += - " -DFP16_ENABLE=0 -DFLT=float -DFLT4=float4 -DFLT16=float16 -DAS_FLT4=as_float4 -DAS_UINT4=as_uint4 -DUINT4=uint4" - " -DTO_FLT=convert_float -DTO_FLT4=convert_float4"; + const std::string &kernel_name, const std::vector<std::string> &build_options_ext, + const bool is_builtin) { + std::string build_option; + if (is_builtin) { + build_option = default_build_option_; + if (fp16_enable_) { + build_option += + " -DFP16_ENABLE=1 -DFLT=half -DFLT4=half4 -DFLT16=half16 -DAS_FLT4=as_half4 -DAS_UINT4=as_ushort4 " + "-DUINT4=ushort4" + " -DTO_FLT=convert_half -DTO_FLT4=convert_half4"; + } else { + build_option += + " -DFP16_ENABLE=0 -DFLT=float -DFLT4=float4 -DFLT16=float16 -DAS_FLT4=as_float4 -DAS_UINT4=as_uint4 " + "-DUINT4=uint4" + " -DTO_FLT=convert_float -DTO_FLT4=convert_float4"; + } + build_option += " -DMAX_IMAGE2D_WIDTH=" + std::to_string(max_image2d_width_); } - build_option += " -DMAX_IMAGE2D_WIDTH=" + std::to_string(max_image2d_width_); build_option = std::accumulate(build_options_ext.begin(), build_options_ext.end(), build_option, [](const std::string &options, const std::string &option) { return options + " " + option; });
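With is_builtin == false, BuildKernel above no longer injects the FLT/TO_FLT macro set or MAX_IMAGE2D_WIDTH, so external kernels must pass everything they need through build_options_ext; the public wrapper also prefixes the program name with "provider_" to keep provider sources out of the built-in cache. A sketch of the provider-side sequence, where the -D defines are illustrative (they match the macros used by the ElementAdd source in the test below, not a fixed API contract):

#include <string>
#include <vector>
#include "include/api/status.h"
#include "include/registry/opencl_runtime_wrapper.h"

// Load a provider CL source and build one kernel from it via the public wrapper.
mindspore::Status BuildProviderKernel(cl::Kernel *kernel, const std::string &cl_source) {
  mindspore::registry::opencl::OpenCLRuntimeWrapper wrapper;
  if (wrapper.LoadSource("Arithmetic", cl_source) != mindspore::kSuccess) {
    return mindspore::kLiteError;
  }
  // Everything the source needs must be supplied here, since is_builtin is false.
  std::vector<std::string> options = {
      "-cl-mad-enable -cl-fast-relaxed-math -Werror",
      "-DFLT4=float4 -DREAD_IMAGE=read_imagef -DWRITE_IMAGE=write_imagef"};
  return wrapper.BuildKernel(kernel, "Arithmetic", "ElementAdd", options);
}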
@@ -515,7 +524,7 @@ bool OpenCLRuntime::BuildProgram(const std::string &build_option, const cl::Prog int OpenCLRuntime::ReadOrWriteImage(void *buffer, void *data, bool is_read) { cl::CommandQueue *command_queue = profiling_ ? profiling_command_queue_ : default_command_queue_; - auto *image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(buffer)); + auto *image = allocator_->GetImage(buffer); if (image == nullptr) { MS_LOG(WARNING) << "Can't get Image2D for " << buffer; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime.h b/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime.h index e3fa70cdf7f..dd7e0a3bd3f 100644 --- a/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime.h +++ b/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime.h @@ -38,11 +38,12 @@ enum InitState { UnInit = 0, InitSuccess = 1, InitFailed = 2 }; struct GpuInfo { GpuType type = OTHER; }; +class OpenCLRuntimeInnerWrapper; class OpenCLRuntimeWrapper; class OpenCLRuntime { public: + friend OpenCLRuntimeInnerWrapper; friend OpenCLRuntimeWrapper; - ~OpenCLRuntime(); OpenCLRuntime(const OpenCLRuntime &) = delete; OpenCLRuntime &operator=(const OpenCLRuntime &) = delete; @@ -55,7 +56,7 @@ class OpenCLRuntime { std::shared_ptr<OpenCLAllocator> GetAllocator() { return allocator_; } cl::CommandQueue *GetDefaultCommandQueue() { return profiling_ ? profiling_command_queue_ : default_command_queue_; } uint64_t DeviceGlobalMemoryCacheSize() const; - int DeviceMaxWorkGroupSize() const; + uint64_t DeviceMaxWorkGroupSize() const; uint32_t DeviceComputeUnits() const; uint32_t DeviceMaxFreq() const; uint64_t GetMaxWorkGroupSize(const cl::Kernel &kernel); @@ -76,50 +77,35 @@ class OpenCLRuntime { template <typename T> typename std::enable_if<std::is_pointer<T>::value, cl_int>::type SetKernelArg(const cl::Kernel &kernel, uint32_t index, const T value, - const MemType mem_type = MemType::IMG) { + bool force_buffer = false) { if (value == nullptr) { MS_LOG(ERROR) << "value is nullptr."; return CL_INVALID_VALUE; } - switch (mem_type) { - case MemType::BUF: { - auto svm_capabilities = GetSVMCapabilities(); - if (svm_capabilities) { - MS_LOG(DEBUG) << "Set kernel arg[" << index << "] SVM pointer " << value; - return clSetKernelArgSVMPointer(kernel.get(), index, value); - } - cl::Buffer *buffer = reinterpret_cast<cl::Buffer *>(allocator_->GetBuffer(value)); - if (buffer == nullptr) { - MS_LOG(ERROR) << "buffer is nullptr."; - return CL_INVALID_VALUE; - } - MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Buffer " << buffer << ", host_ptr: " << value; - return const_cast<cl::Kernel &>(kernel).setArg(index, *buffer); - } - case MemType::IMG: { - cl::Image2D *image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(value)); - if (image == nullptr) { - MS_LOG(WARNING) << "Can't get Image2D, try to use Buffer. Please confirm the buffer type."; - cl::Buffer *buffer = reinterpret_cast<cl::Buffer *>(allocator_->GetBuffer(value)); - if (buffer == nullptr) { - MS_LOG(ERROR) << "buffer is nullptr."; - return CL_INVALID_VALUE; - } - MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Buffer " << buffer << ", host_ptr: " << value; - return const_cast<cl::Kernel &>(kernel).setArg(index, *buffer); - } - MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Image2D " << image << ", host_ptr: " << value; - return const_cast<cl::Kernel &>(kernel).setArg(index, *image); - } - default: - MS_LOG(ERROR) << "Unsupported opencl memory type: " << static_cast<int>(mem_type); - return CL_IMAGE_FORMAT_NOT_SUPPORTED; + auto svm_capabilities = GetSVMCapabilities(); + if (svm_capabilities) { + MS_LOG(DEBUG) << "Set kernel arg[" << index << "] SVM pointer " << value; + return clSetKernelArgSVMPointer(kernel.get(), index, value); + } + lite::opencl::MemType mem_type; + void *buffer = allocator_->GetOpenclMemPtr(value, &mem_type, force_buffer); + if (buffer == nullptr) { + MS_LOG(ERROR) << "buffer is nullptr."; + return CL_INVALID_VALUE; + } + MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL " + << (mem_type == lite::opencl::MemType::IMG ? "Image " : "Buffer ") << buffer + << ", host_ptr: " << value; + if (mem_type == lite::opencl::MemType::IMG) { + return const_cast<cl::Kernel &>(kernel).setArg(index, *reinterpret_cast<cl::Image2D *>(buffer)); + } else { + return const_cast<cl::Kernel &>(kernel).setArg(index, *reinterpret_cast<cl::Buffer *>(buffer)); } }
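At call sites, the old MemType argument becomes an optional bool that forces the buffer view of an allocation; the allocator itself now records whether a host pointer maps to an image or a buffer. A sketch of the two call patterns, mirroring the kernel updates later in this patch (rt is assumed to be a valid runtime instance, e.g. from OpenCLRuntimeInnerWrapper):

#include "src/runtime/gpu/opencl/opencl_runtime.h"

// image2D argument: the default resolves through GetOpenclMemPtr() to the image
// handle; buffer argument: pass true where MemType::BUF used to be passed.
cl_int SetArgsExample(mindspore::lite::opencl::OpenCLRuntime *rt, const cl::Kernel &kernel,
                      void *image_host_ptr, void *buffer_host_ptr) {
  cl_int ret = rt->SetKernelArg(kernel, 0, image_host_ptr);   // was: MemType::IMG
  if (ret != CL_SUCCESS) {
    return ret;
  }
  return rt->SetKernelArg(kernel, 1, buffer_host_ptr, true);  // was: MemType::BUF
}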
 template <typename T> - typename std::enable_if<!std::is_pointer<T>::value, cl_int>::type SetKernelArg( - const cl::Kernel &kernel, uint32_t index, const T value, const MemType mem_type = MemType::IMG) { + typename std::enable_if<!std::is_pointer<T>::value, cl_int>::type SetKernelArg(const cl::Kernel &kernel, + uint32_t index, const T value) { return const_cast<cl::Kernel &>(kernel).setArg(index, value); } @@ -129,7 +115,7 @@ class OpenCLRuntime { std::vector<unsigned char> GetProgramBinary(const cl::Program &program); bool LoadSource(const std::string &program_name, const std::string &source); int BuildKernel(const cl::Kernel &kernel, const std::string &program_name, const std::string &kernel_name, - const std::vector<std::string> &build_options_ext = {}); + const std::vector<std::string> &build_options_ext = {}, const bool is_builtin = true); int RunKernel(const cl::Kernel &kernel, const cl::NDRange &global, const cl::NDRange &local, cl::CommandQueue *command_queue = nullptr, cl::Event *event = nullptr); int ReadOrWriteImage(void *buffer, void *data, bool is_read); @@ -192,7 +178,7 @@ class OpenCLRuntime { uint64_t max_alloc_size_{0}; uint64_t max_image2d_width_{0}; uint64_t max_image2d_height_{0}; - int max_work_group_size_{1}; + uint64_t max_work_group_size_{1}; uint32_t compute_units_{0}; uint32_t max_freq_{0}; std::string default_build_option_{"-cl-mad-enable -cl-fast-relaxed-math -Werror"}; @@ -226,12 +212,12 @@ class OpenCLRuntime { const std::string cache_version_{"V0.1"}; }; -class OpenCLRuntimeWrapper { +class OpenCLRuntimeInnerWrapper { public: - OpenCLRuntimeWrapper() { ocl_runtime_ = OpenCLRuntime::GetInstance(); } - ~OpenCLRuntimeWrapper() { OpenCLRuntime::DeleteInstance(); } - OpenCLRuntimeWrapper(const OpenCLRuntimeWrapper &) = delete; - OpenCLRuntimeWrapper &operator=(const OpenCLRuntimeWrapper &) = delete; + OpenCLRuntimeInnerWrapper() { ocl_runtime_ = OpenCLRuntime::GetInstance(); } + ~OpenCLRuntimeInnerWrapper() { OpenCLRuntime::DeleteInstance(); } + OpenCLRuntimeInnerWrapper(const OpenCLRuntimeInnerWrapper &) = delete; + OpenCLRuntimeInnerWrapper &operator=(const OpenCLRuntimeInnerWrapper &) = delete; OpenCLRuntime *GetInstance() { return 
ocl_runtime_; } private: diff --git a/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime_wrapper.cc b/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime_wrapper.cc new file mode 100644 index 00000000000..661315133fb --- /dev/null +++ b/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime_wrapper.cc @@ -0,0 +1,155 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/registry/opencl_runtime_wrapper.h" +#include +#ifdef SHARING_MEM_WITH_OPENGL +#include +#endif +#include +#include +#include +#include "include/errorcode.h" +#include "src/runtime/kernel/opencl/utils.h" +#include "src/runtime/gpu/opencl/opencl_allocator.h" +#include "src/common/file_utils.h" +#include "src/runtime/gpu/opencl/opencl_runtime.h" + +using mindspore::kernel::CLErrorCode; + +namespace mindspore::registry::opencl { + +Status OpenCLRuntimeWrapper::LoadSource(const std::string &program_name, const std::string &source) { + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; + lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); + const std::string program_name_ext = "provider_" + program_name; + if (ocl_runtime->LoadSource(program_name_ext, source)) { + return kSuccess; + } else { + return kLiteError; + } +} + +Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::string &program_name, + const std::string &kernel_name, + const std::vector &build_options_ext) { + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; + lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); + const std::string program_name_ext = "provider_" + program_name; + if (ocl_runtime->BuildKernel(*kernel, program_name_ext, kernel_name, build_options_ext, false) == RET_OK) { + return kSuccess; + } else { + return kLiteError; + } +} + +Status OpenCLRuntimeWrapper::SetKernelArg(const cl::Kernel &kernel, uint32_t index, void *const value) { + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; + lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); + if (ocl_runtime->SetKernelArg(kernel, index, value) != CL_SUCCESS) { + return kLiteError; + } else { + return kSuccess; + } +} + +Status OpenCLRuntimeWrapper::RunKernel(const cl::Kernel &kernel, const cl::NDRange &global, const cl::NDRange &local, + cl::CommandQueue *command_queue, cl::Event *event) { + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; + lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); + if (ocl_runtime->RunKernel(kernel, global, local, command_queue, event) == RET_OK) { + return kSuccess; + } else { + return kLiteError; + } +} + +Status OpenCLRuntimeWrapper::SyncCommandQueue() { + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; + lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); + if (ocl_runtime->SyncCommandQueue()) { + return kSuccess; + } else { + return kLiteError; + } +} + +void *OpenCLRuntimeWrapper::MapBuffer(void *host_ptr, int 
flags, bool sync) { + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; + lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); + return ocl_runtime->GetAllocator()->MapBuffer(host_ptr, flags, nullptr, sync); +} + +Status OpenCLRuntimeWrapper::UnmapBuffer(void *host_ptr) { + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; + lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); + if (ocl_runtime->GetAllocator()->UnmapBuffer(host_ptr, nullptr) == RET_OK) { + return kSuccess; + } else { + return kLiteError; + } +} + +Status OpenCLRuntimeWrapper::ReadImage(void *buffer, void *dst_data) { + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; + lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); + if (ocl_runtime->ReadImage(buffer, dst_data) == RET_OK) { + return kSuccess; + } else { + return kLiteError; + } +} + +Status OpenCLRuntimeWrapper::WriteImage(void *buffer, void *src_data) { + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; + lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); + if (ocl_runtime->WriteImage(buffer, src_data) == RET_OK) { + return kSuccess; + } else { + return kLiteError; + } +} + +std::shared_ptr OpenCLRuntimeWrapper::GetAllocator() { + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; + lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); + return ocl_runtime->GetAllocator(); +} + +uint64_t OpenCLRuntimeWrapper::DeviceMaxWorkGroupSize() { + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; + lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); + return ocl_runtime->DeviceMaxWorkGroupSize(); +} + +uint64_t OpenCLRuntimeWrapper::GetMaxImage2DWidth() { + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; + lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); + return ocl_runtime->GetMaxImage2DWidth(); +} + +uint64_t OpenCLRuntimeWrapper::GetMaxImage2DHeight() { + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; + lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); + return ocl_runtime->GetMaxImage2DHeight(); +} + +uint64_t OpenCLRuntimeWrapper::GetImagePitchAlignment() { + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; + lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); + return ocl_runtime->GetImagePitchAlignment(); +} +} // namespace mindspore::registry::opencl diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc index d9a324cdc13..4ad6ac35737 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc @@ -68,11 +68,11 @@ int ArgMinMaxOpenCLKernel::SetConstArgs() { static_cast(im_in_.C)}; cl_int4 flags = {param->out_value_, param->get_max_, param->axis_, param->topk_}; int arg_cnt = 2; - if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, buff_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, buff_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } - if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, ids_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, ids_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } @@ -228,11 +228,11 @@ int ArgMinMaxOpenCLKernel::Prepare() { 
int ArgMinMaxOpenCLKernel::Run() { MS_LOG(DEBUG) << this->name() << " Running! "; - if (ocl_runtime_->SetKernelArg(kernel_, 0, in_tensors_[0]->data_c(), lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, 0, in_tensors_[0]->data_c(), true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } - if (ocl_runtime_->SetKernelArg(kernel_, 1, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, 1, out_tensors_[0]->data_c(), true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc index 8638c05d0e1..a299accdd0a 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc @@ -266,19 +266,19 @@ int BatchNormOpenCLKernel::Run() { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } // input tensor - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, scale_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, scale_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } // scale - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, offset_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, offset_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } // offset - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, mean_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, mean_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } // mean - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, variance_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, variance_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } // variance diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc index fb50e74c69b..5f93ff862c0 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc @@ -36,7 +36,7 @@ int ConcatOpenCLKernel::RunAxis0() { auto dst_data = out_tensors_[0]->data_c(); MS_ASSERT(dst_data); auto dst_origin = cl::array{0, 0, 0}; - auto *out_image = reinterpret_cast(allocator_->GetImage(dst_data)); + auto *out_image = allocator_->GetImage(dst_data); for (int i = 0; i < in_tensors_.size(); i++) { auto src_data = weight_ptrs_.at(i) == nullptr ? 
in_tensors_[i]->data_c() : weight_ptrs_.at(i); if (allocator_->GetImageSize(src_data, &img_size) != RET_OK) { @@ -45,7 +45,7 @@ int ConcatOpenCLKernel::RunAxis0() { } auto src_origin = cl::array{0, 0, 0}; auto region = cl::array{img_size.width, img_size.height, 1}; - auto *input_image = reinterpret_cast(allocator_->GetImage(src_data)); + auto *input_image = allocator_->GetImage(src_data); if (ocl_runtime_->GetDefaultCommandQueue()->enqueueCopyImage(*input_image, *out_image, src_origin, dst_origin, region) != CL_SUCCESS) { MS_LOG(WARNING) << "enqueueCopyImage failed."; @@ -290,8 +290,7 @@ int ConcatOpenCLKernel::Run() { } } if (axis_ == 3 && !Align_) { - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF) != - CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc index 9a5884e5a4d..9c4eea5bbc7 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc @@ -435,11 +435,12 @@ int Conv2DOpenCLKernel::SetConstArgs() { cl_int2 dilation = {param_->dilation_h_, param_->dilation_w_}; int arg_cn = 2; - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, packed_filter_, filter_type_) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, packed_filter_, (filter_type_ == lite::opencl::MemType::BUF)) != + CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, packed_bias_, MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, packed_bias_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc index 5e415c9036e..82cfd06244f 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc @@ -119,7 +119,7 @@ int Conv2dTransposeOpenCLKernel::SetConstArgs() { cl_int2 padding = {pad_h, pad_w}; cl_int4 src_size = {h, w, UP_DIV(ci, C4NUM), n}; cl_int4 dst_size = {oh, ow, UP_DIV(co, C4NUM), n}; - if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, padWeight_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, padWeight_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc index 9cbea18808f..bab7a465056 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc @@ -231,11 +231,12 @@ int DepthwiseConv2dOpenCLKernel::SetConstArgs() { cl_int4 dst_size = {(cl_int)out_info.W, (cl_int)out_info.H, (cl_int)CO4, (cl_int)out_info.N}; int arg_cnt = 2; - if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, packed_weight_, filter_type_) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, packed_weight_, (filter_type_ == lite::opencl::MemType::BUF)) != + CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } - if 
(ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, bias_data_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, bias_data_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc index 7e318d83ba6..8f494f3503d 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc @@ -43,7 +43,7 @@ int FillOpenCLKernel::RunFill() { } auto src_origin = cl::array{0, 0, 0}; auto region = cl::array{img_size.width, img_size.height, 1}; - cl::Image2D *out_image = reinterpret_cast(allocator_->GetImage(src_data)); + cl::Image2D *out_image = allocator_->GetImage(src_data); if (ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region) != CL_SUCCESS) { MS_LOG(ERROR) << "enqueueFillImage failed."; @@ -66,7 +66,7 @@ int FillOpenCLKernel::RunShape() { } auto src_origin = cl::array{0, 0, 0}; auto region = cl::array{1, 1, 1}; - cl::Image2D *out_image = reinterpret_cast(allocator_->GetImage(src_data)); + cl::Image2D *out_image = allocator_->GetImage(src_data); if (ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region) != CL_SUCCESS) { MS_LOG(ERROR) << "enqueueFillImage failed."; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc index 8bb4deebb73..6665bb36044 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc @@ -260,7 +260,7 @@ void FullConnectionOpenCLKernel::SetGlobalLocal() { int FullConnectionOpenCLKernel::SetConstArgs() { if (!weight_var_) { - if (ocl_runtime_->SetKernelArg(kernel_, 2, padWeight_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, 2, padWeight_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc index 4d17eba5093..c69f6d3763c 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.cc @@ -288,8 +288,7 @@ int FusionEltwiseOpenCLKernel::SetConstArgs() { } } } else { - if (ocl_runtime_->SetKernelArg(kernel_, arg_idx, buffer_weights_[buffer_idx++], lite::opencl::MemType::BUF) != - CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_idx, buffer_weights_[buffer_idx++], true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.cc index 3f1bf1d76e7..7958108147a 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.cc @@ -258,7 +258,7 @@ int GatherOpenCLKernel::Run() { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } - if (ocl_runtime_->SetKernelArg(kernel_, 2, indices_data_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, 2, indices_data_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git 
a/mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.cc index 77cf59eae40..fb5fdd2e8d5 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.cc @@ -254,11 +254,11 @@ int LayerNormOpenCLKernel::Run() { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } // input tensor - if (ocl_runtime_->SetKernelArg(kernel_mean_var_, arg1_cn++, mean_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_mean_var_, arg1_cn++, mean_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } - if (ocl_runtime_->SetKernelArg(kernel_mean_var_, arg1_cn++, var_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_mean_var_, arg1_cn++, var_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } @@ -273,19 +273,19 @@ int LayerNormOpenCLKernel::Run() { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } // out tensor - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, mean_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, mean_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } // mean_ - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, var_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, var_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } // var_ - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, gamma_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, gamma_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } // gamma_ - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, beta_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, beta_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } // beta_ diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc index dc5b5b6cd51..2a17e055771 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc @@ -268,7 +268,7 @@ int MatMulOpenCLKernel::SetConstArgs() { if (act_weight_) { arg_count++; } else { - if (ocl_runtime_->SetKernelArg(kernel_, arg_count++, padWeight_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_count++, padWeight_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc index 218b71ddffe..294300dc807 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc @@ -184,7 +184,7 @@ int PReluOpenCLKernel::Run() { return RET_ERROR; } } else { - if (ocl_runtime_->SetKernelArg(kernel_, arg_idx++, weight_vector_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_idx++, weight_vector_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.cc 
b/mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.cc index eda7fa0ce65..29ba4826958 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.cc @@ -44,7 +44,7 @@ int SparseToDenseOpenCLKernel::InitOutputToDefault() { } auto src_origin = cl::array{0, 0, 0}; auto region = cl::array{img_size.width, img_size.height, 1}; - cl::Image2D *out_image = reinterpret_cast(allocator_->GetImage(src_data)); + cl::Image2D *out_image = allocator_->GetImage(src_data); if (ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region) != CL_SUCCESS) { MS_LOG(ERROR) << "enqueueFillImage failed."; @@ -267,13 +267,12 @@ int SparseToDenseOpenCLKernel::Run() { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF) != - CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } if (!weight_scalar_) { - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, weight_vector_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, weight_vector_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/split.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/split.cc index f4f1974892c..32c5d238aab 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/split.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/split.cc @@ -33,7 +33,7 @@ int SplitOpenCLKernel::RunAxis0() { auto allocator_ = ocl_runtime_->GetAllocator(); auto src_data = in_tensors_[0]->data_c(); CHECK_NULL_RETURN(src_data); - cl::Image2D *in_image = reinterpret_cast(allocator_->GetImage(src_data)); + cl::Image2D *in_image = allocator_->GetImage(src_data); if (in_image == nullptr) { MS_LOG(ERROR) << "RunAxis0 in_image can not be nullptr"; return RET_ERROR; @@ -49,7 +49,7 @@ int SplitOpenCLKernel::RunAxis0() { } auto dst_area = cl::array{0, 0, 0}; auto region = cl::array{img_size.width, img_size.height, 1}; - cl::Image2D *out_image = reinterpret_cast(allocator_->GetImage(dst_data)); + cl::Image2D *out_image = allocator_->GetImage(dst_data); if (out_image == nullptr) { MS_LOG(ERROR) << "RunAxis0 out_image can not be nullptr"; return RET_ERROR; @@ -252,8 +252,7 @@ int SplitOpenCLKernel::Run() { return RET_ERROR; } } else { - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c(), lite::opencl::MemType::BUF) != - CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c(), true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } @@ -264,7 +263,7 @@ int SplitOpenCLKernel::Run() { return RET_ERROR; } } - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, split_sizes_, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, split_sizes_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/stack.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/stack.cc index 2302c2f4156..839c7875f1e 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/stack.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/stack.cc @@ -34,7 +34,7 @@ int 
StackOpenCLKernel::RunAxis0() { auto dst_data = out_tensors_[0]->data_c(); MS_ASSERT(dst_data); auto dst_origin = cl::array{0, 0, 0}; - cl::Image2D *out_image = reinterpret_cast(allocator_->GetImage(dst_data)); + cl::Image2D *out_image = allocator_->GetImage(dst_data); for (int i = 0; i < in_tensors_.size(); i++) { auto src_data = in_tensors_[i]->data_c(); MS_ASSERT(src_data); @@ -44,7 +44,7 @@ int StackOpenCLKernel::RunAxis0() { } auto src_origin = cl::array{0, 0, 0}; auto region = cl::array{img_size.width, img_size.height, 1}; - cl::Image2D *input_image = reinterpret_cast(allocator_->GetImage(src_data)); + cl::Image2D *input_image = allocator_->GetImage(src_data); if (ocl_runtime_->GetDefaultCommandQueue()->enqueueCopyImage(*input_image, *out_image, src_origin, dst_origin, region) != CL_SUCCESS) { MS_LOG(WARNING) << "enqueueCopyImage failed."; @@ -209,14 +209,12 @@ int StackOpenCLKernel::Run() { int arg_cn = 0; if (buffer_button_) { for (int i = 0; i < in_tensors_.size(); ++i) { - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[i]->data_c(), lite::opencl::MemType::BUF) != - CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[i]->data_c(), true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } } - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF) != - CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.cc index be61ca7b6f3..48a11477fd6 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.cc @@ -249,11 +249,11 @@ int StrassenOpenCLKernel::StrassenDataFilled(cl::Kernel *kernel, void *input, vo return RET_ERROR; } } else { - if (ocl_runtime_->SetKernelArg(*kernel, 0, input, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(*kernel, 0, input, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } - if (ocl_runtime_->SetKernelArg(*kernel, 1, output, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(*kernel, 1, output, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } @@ -277,20 +277,20 @@ int StrassenOpenCLKernel::StrassenAddSub(cl::Kernel *kernel, void *input, void * return RET_ERROR; } if (mem_type == lite::opencl::MemType::IMG) { - if (ocl_runtime_->SetKernelArg(*kernel, 0, input, lite::opencl::MemType::IMG) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(*kernel, 0, input) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } - if (ocl_runtime_->SetKernelArg(*kernel, 1, output, lite::opencl::MemType::IMG) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(*kernel, 1, output) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } } else { - if (ocl_runtime_->SetKernelArg(*kernel, 0, input, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(*kernel, 0, input, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } - if (ocl_runtime_->SetKernelArg(*kernel, 1, output, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(*kernel, 1, output, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg 
failed."; return RET_ERROR; } @@ -371,7 +371,7 @@ int StrassenOpenCLKernel::StrassenRunMmatmul(void *input, void *weight, void *ou MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } - if (ocl_runtime_->SetKernelArg(kernel_, 2, weight, lite::opencl::MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, 2, weight, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc index 0d6ff88d36d..37de08489df 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc @@ -108,11 +108,13 @@ int ToFormatOpenCLKernel::Run() { MS_LOG(DEBUG) << this->name() << " Running!"; auto src_mem_type = (out_mem_type_ == MemType::IMG) ? lite::opencl::MemType::BUF : lite::opencl::MemType::IMG; auto dst_mem_type = out_mem_type_; - if (ocl_runtime_->SetKernelArg(kernel_, 0, in_tensors_.front()->data_c(), src_mem_type) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, 0, in_tensors_.front()->data_c(), + (src_mem_type == lite::opencl::MemType::BUF)) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } - if (ocl_runtime_->SetKernelArg(kernel_, 1, out_tensors_.front()->data_c(), dst_mem_type) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, 1, out_tensors_.front()->data_c(), + (dst_mem_type == lite::opencl::MemType::BUF)) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/winograd.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/winograd.cc index b189213693e..4438bdb5864 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/winograd.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/winograd.cc @@ -240,7 +240,8 @@ int WinogradOpenCLKernel::SetConstArgs() { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } - if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, packed_filter_, filter_type_) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, packed_filter_, (filter_type_ == lite::opencl::MemType::BUF)) != + CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } @@ -263,7 +264,7 @@ int WinogradOpenCLKernel::SetConstArgs() { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } - if (ocl_runtime_->SetKernelArg(kernel_36to4x4_, arg_cn++, packed_bias_, MemType::BUF) != CL_SUCCESS) { + if (ocl_runtime_->SetKernelArg(kernel_36to4x4_, arg_cn++, packed_bias_, true) != CL_SUCCESS) { MS_LOG(ERROR) << "SetKernelArg failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc index ec4d603a732..a67012be960 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc @@ -59,7 +59,7 @@ inline bool PredIs(const LiteKernel *node, PrimitiveType type, std::vectorin_kernels().size() == 1) { LiteKernel *pred = node->in_kernels().front(); MS_ASSERT(pred); - if (AIsInB(pred, nodes) && pred->type() == type && pred->out_kernels().size() == 1) { + if (AIsInB(pred, nodes) && pred->type() == type && pred->out_kernels().size() == 1 && pred->IsBuiltin()) { MS_ASSERT(pred->out_kernels().front() == node); return true; } @@ -578,7 +578,7 @@ void CreateEltwiseKernelReplaceOld(FusionEltwiseParameter *param, LiteKernel *ol 
// Eltwise + Eltwise int TryMergeEltwiseEltwise(LiteKernel *node, std::set *removed_set, std::vector *nodes) { - if (!node->InferShapeDone()) { + if (!node->InferShapeDone() || !node->IsBuiltin()) { return RET_ERROR; } MS_ASSERT(node); @@ -598,6 +598,9 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::set *removed_set if (!pred->InferShapeDone()) { continue; } + if (!pred->IsBuiltin()) { + return RET_ERROR; + } if (AIsInB(pred, nodes) && IsEltwiseAndOperatorSupported(pred) && pred->out_kernels().size() == 1) { auto *tensor = pred->out_tensors().front(); MS_ASSERT(pred->out_kernels().front() == node); @@ -627,7 +630,7 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::set *removed_set } void DoSpecificFusion(LiteKernel *node, std::set *removed_set, std::vector *nodes) { - if (!node->InferShapeDone()) { + if (!node->InferShapeDone() || !node->IsBuiltin()) { return; } switch (node->type()) { diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc index bdab2eb6599..2aa7c53c3e8 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc @@ -105,7 +105,7 @@ void OpenCLKernel::PrintOutput(int print_num, const std::string &out_file) { GpuTensorInfo img_info(tensor); auto size = mem_type == lite::opencl::MemType::BUF ? img_info.OriginSize : img_info.Image2DSize; std::vector data(size); - auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper(); + auto runtime_wrapper = lite::opencl::OpenCLRuntimeInnerWrapper(); auto runtime = runtime_wrapper.GetInstance(); auto allocator = runtime->GetAllocator(); if (!runtime->SyncCommandQueue()) { @@ -158,10 +158,10 @@ int OpenCLKernel::PreProcess() { if (ret != RET_OK) { return ret; } - auto allocator = ocl_runtime_->GetAllocator(); for (auto i = 0; i < out_tensors_.size(); ++i) { auto *output = out_tensors_.at(i); - MS_ASSERT(output); + CHECK_NULL_RETURN(output); + CHECK_NULL_RETURN(output->allocator()); if (GetMemType() == lite::opencl::MemType::IMG) { ImageSize img_size; ret = GetImageSize(i, &img_size); @@ -169,20 +169,20 @@ int OpenCLKernel::PreProcess() { MS_LOG(ERROR) << "GetImageSize failed"; return ret; } - auto data_ptr = allocator->Malloc(img_size); + auto data_ptr = + output->allocator()->Malloc(img_size.width, img_size.height, static_cast(output->data_type())); if (data_ptr == nullptr) { MS_LOG(ERROR) << "Malloc data failed"; return RET_ERROR; } output->set_data(data_ptr); } else { - ret = output->MallocData(allocator); + ret = output->MallocData(); if (ret != RET_OK) { MS_LOG(ERROR) << "MallocData failed"; return ret; } } - output->set_allocator(allocator); output->ResetRefCount(); } return RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h index 17a7b09cf13..afeaa484b81 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h @@ -92,7 +92,7 @@ void Broadcast2GpuShape(DstT *dst, const SrcT *src, int src_num, DstT default_va struct GpuTensorInfo { GpuTensorInfo() = default; explicit GpuTensorInfo(const lite::Tensor *tensor) { - auto ocl_runtime_wrap_ = lite::opencl::OpenCLRuntimeWrapper(); + auto ocl_runtime_wrap_ = lite::opencl::OpenCLRuntimeInnerWrapper(); if (tensor == nullptr) { return; } @@ -131,7 +131,7 @@ struct GpuTensorInfo { } size_t RowPitch() const { - auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper(); + auto 
runtime_wrapper = lite::opencl::OpenCLRuntimeInnerWrapper(); int alignment = runtime_wrapper.GetInstance()->GetImagePitchAlignment(); MS_ASSERT(alignment); size_t row_pitch = UP_ROUND(width, alignment) * FLT4_size; @@ -238,7 +238,7 @@ class OpenCLKernel : public InnerKernel { bool dequant_flag_{false}; private: - lite::opencl::OpenCLRuntimeWrapper ocl_runtime_wrap_; + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap_; static inline std::map tuned_param_cache_; }; template diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc index 2b3323dcd06..7c46e56771f 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc @@ -316,16 +316,23 @@ int OpenCLSubGraph::Prepare() { MS_LOG(ERROR) << "node in Subgraph is nullptr"; return mindspore::lite::RET_NULL_PTR; } - auto opencl_kernel = reinterpret_cast(node->kernel()); - std::set pre_init_weight_list = {schema::PrimitiveType_MatMul, schema::PrimitiveType_BiasAdd}; - if (pre_init_weight_list.find(opencl_kernel->type()) != pre_init_weight_list.end()) { - auto ret = opencl_kernel->InitWeights(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "init weights " << node->name() << " failed"; - return ret; + for (const auto tensor : node->out_tensors()) { + CHECK_NULL_RETURN(tensor); + MS_CHECK_TRUE_RET(tensor->data_c() == nullptr, RET_ERROR); + tensor->set_allocator(allocator_); + } + if (desc_.provider == kBuiltin) { + auto opencl_kernel = reinterpret_cast(node->kernel()); + std::set pre_init_weight_list = {schema::PrimitiveType_MatMul, schema::PrimitiveType_BiasAdd}; + if (pre_init_weight_list.find(opencl_kernel->type()) != pre_init_weight_list.end()) { + auto ret = opencl_kernel->InitWeights(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "init weights " << node->name() << " failed"; + return ret; + } } } - if (opencl_kernel->InferShapeDone()) { + if (node->InferShapeDone()) { auto ret = node->Prepare(); if (ret != RET_OK) { MS_LOG(ERROR) << "prepare node " << node->name() << " failed"; @@ -382,10 +389,9 @@ int OpenCLSubGraph::ReSize(bool interrupt) { } } for (auto kernel : nodes_) { - auto opencl_kernel = reinterpret_cast(kernel->kernel()); - auto ret = opencl_kernel->ReSize(); + auto ret = kernel->ReSize(); if (ret != RET_OK) { - MS_LOG(WARNING) << "ReSize " << opencl_kernel->name() << "failed!"; + MS_LOG(WARNING) << "ReSize " << kernel->name() << "failed!"; if (interrupt) { return ret; } else { diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h index c48bb23a0e9..361d716c870 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h @@ -81,7 +81,7 @@ class OpenCLSubGraph : public SubGraphKernel { std::vector in_convert_ops_; std::vector out_convert_ops_; std::set nodes_set_; - lite::opencl::OpenCLRuntimeWrapper ocl_runtime_wrap_; + lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap_; lite::opencl::OpenCLRuntime *ocl_runtime_{nullptr}; bool all_kernels_infer_done_ = false; }; diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 5b8df30ef1e..b82bfc7426c 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -1163,6 +1163,9 @@ kernel::SubGraphType GetKernelSubGraphType(const kernel::LiteKernel *kernel, con auto desc = kernel->desc(); if (desc.provider != kernel::kBuiltin) 
{ + if (desc.arch == kernel::KERNEL_ARCH::kGPU) { + return kernel::kGpuSubGraph; + } return kernel::kCustomSubGraph; } if (desc.arch == kernel::KERNEL_ARCH::kGPU) { diff --git a/mindspore/lite/src/tensor.cc b/mindspore/lite/src/tensor.cc index e302d9746ef..7a1ba529030 100644 --- a/mindspore/lite/src/tensor.cc +++ b/mindspore/lite/src/tensor.cc @@ -77,14 +77,8 @@ Tensor *Tensor::CopyTensor(const Tensor &src_tensor, bool copy_data, AllocatorPt } Tensor::~Tensor() { - if (this->data_ != nullptr && this->own_data_) { - if (this->allocator_ != nullptr) { - this->allocator_->Free(this->data_); - } else { - free(this->data_); - } - this->data_ = nullptr; - } + FreeData(); + this->data_ = nullptr; } bool Tensor::operator==(const Tensor &tensor) { @@ -304,18 +298,14 @@ int Tensor::MallocData(const AllocatorPtr allocator) { } void Tensor::FreeData() { - if (this->data_ == nullptr) { - return; - } - if (!this->own_data_) { - return; - } - if (allocator_ == nullptr) { - free(this->data_); - this->data_ = nullptr; - } else { - allocator_->Free(this->data_); - if (!IS_STATIC_ALLOCATOR(allocator_) || (allocator_->RefCount(this->data_) != 0)) { + if (this->data_ != nullptr && this->own_data_) { + if (this->allocator_ != nullptr) { + this->allocator_->Free(this->data_); + if (!IS_STATIC_ALLOCATOR(allocator_) || (allocator_->RefCount(this->data_) != 0)) { + this->data_ = nullptr; + } + } else { + free(this->data_); this->data_ = nullptr; } } diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h index e00a11abd1d..395bc3369eb 100644 --- a/mindspore/lite/src/tensor.h +++ b/mindspore/lite/src/tensor.h @@ -168,7 +168,7 @@ class Tensor : public mindspore::tensor::MSTensor { void set_quant_clusters(const std::vector &clusters); - virtual bool IsConst() const { + bool IsConst() const override { return (this->category_ == CONST_TENSOR || this->category_ == CONST_SCALAR) && this->data_ != nullptr; } diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 26f0344a9fe..c8248c6114d 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -42,6 +42,7 @@ endif() if(MSLITE_GPU_BACKEND STREQUAL opencl) file(GLOB_RECURSE TEST_GPU_UT_SRC ${TEST_DIR}/ut/src/runtime/kernel/opencl/*.cc + ${TEST_DIR}/ut/src/registry/registry_gpu_custom_op_test.cc ) list(APPEND TEST_UT_SRC ${TEST_GPU_UT_SRC}) endif() diff --git a/mindspore/lite/test/config/ut_arm64.cfg b/mindspore/lite/test/config/ut_arm64.cfg index d2fb77e7ba0..e9868a888f5 100644 --- a/mindspore/lite/test/config/ut_arm64.cfg +++ b/mindspore/lite/test/config/ut_arm64.cfg @@ -146,4 +146,5 @@ MindrtRuntimeTest.Runtime MindrtRuntimeTest.RuntimeFp16 MixDataTypeTest.mix1 SchedulerTest.TestScheduleInt32OpToFp16Subgraph +TestGPURegistryCustomOp.TestGPUCustomAdd diff --git a/mindspore/lite/test/ut/src/registry/registry_gpu_custom_op_test.cc b/mindspore/lite/test/ut/src/registry/registry_gpu_custom_op_test.cc new file mode 100644 index 00000000000..7205fab789b --- /dev/null +++ b/mindspore/lite/test/ut/src/registry/registry_gpu_custom_op_test.cc @@ -0,0 +1,530 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "common/common_test.h" +#include "include/api/context.h" +#include "include/api/model.h" +#include "include/lite_session.h" +#include "include/context.h" +#include "include/errorcode.h" +#include "src/common/log_adapter.h" +#include "src/lite_session.h" +#include "include/registry/register_kernel_interface.h" +#include "include/registry/register_kernel.h" +#include "include/registry/opencl_runtime_wrapper.h" +#include "include/api/data_type.h" + +using mindspore::kernel::Kernel; +using mindspore::kernel::KernelInterface; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::lite::RET_PARAM_INVALID; +using mindspore::schema::PrimitiveType_AddFusion; +#define UP_ROUND(x, y) (((x) + (y) - (1)) / (y) * (y)) +#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) +#define C4NUM 4 + +namespace mindspore { +namespace { +constexpr auto kFloat32 = DataType::kNumberTypeFloat32; +static const char *arithmetic_source = + "\n" + "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" + "__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n" + "\n" + "__kernel void ElementAdd(__read_only image2d_t input_a, __read_only image2d_t input_b, __write_only image2d_t " + "output,\n" + " const int2 output_shape) {\n" + " int X = get_global_id(0);\n" + " int Y = get_global_id(1);\n" + " if (X >= output_shape.x || Y >= output_shape.y) {\n" + " return;\n" + " }\n" + "\n" + " FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" + " FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" + " FLT4 result = a + b;\n" + "\n" + " WRITE_IMAGE(output, (int2)(X, Y), result);\n" + "}\n"; + +template <typename DstT, typename SrcT> +void Broadcast2GpuShape(DstT *dst, const SrcT *src, int src_num) { + if (src == nullptr || src_num <= 0) { + return; + } + auto *N = dst; + auto *H = dst + 1; + auto *W = dst + 2; + auto *C = dst + 3; + if (src_num == 1) { // 1 1 1 C + *C = src[0]; + } else if (src_num == 2) { // N 1 1 C + *N = src[0]; + *C = src[1]; + } else if (src_num == 3) { // N 1 W C + *N = src[0]; + *W = src[1]; + *C = src[2]; + } else if (src_num == 4) { // N H W C + *N = src[0]; + *H = src[1]; + *W = src[2]; + *C = src[3]; + } else if (src_num > 4) { + std::cerr << "GPU doesn't support ndim>=" << src_num; + } +} + +template <typename DstT, typename SrcT> +void Broadcast2GpuShape(DstT *dst, const SrcT *src, int src_num, DstT default_value) { + for (int i = 0; i < 4; ++i) { + dst[i] = default_value; + } + if (src == nullptr || src_num <= 0) { + return; + } + Broadcast2GpuShape(dst, src, src_num); +} +#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) +#define C4NUM 4 +struct GpuTensorInfo { + GpuTensorInfo() = default; + explicit GpuTensorInfo(const MSTensor *tensor, registry::opencl::OpenCLRuntimeWrapper *opencl_run) { + if (tensor == nullptr) { + return; + } + auto shape_ori = tensor->Shape(); + int64_t shape[4]; + Broadcast2GpuShape(shape, shape_ori.data(), shape_ori.size(), 1l); + N = shape[0]; + H = shape[1]; + W = shape[2]; + C = shape[3]; + Slice = UP_DIV(C, C4NUM); + if (tensor->DataType() == mindspore::DataType::kNumberTypeFloat16) { + FLT_size = sizeof(cl_half); + } else { + FLT_size = sizeof(cl_float); + } + FLT4_size = FLT_size * 4; + if (W * Slice <= opencl_run->GetMaxImage2DWidth()) { + height = N * H; + width = W * Slice; + } else { + height = N * H * W; + width = Slice; + if (height > opencl_run->GetMaxImage2DHeight()) { + height = -1; + width = -1; + } + } + + ElementsNum = N * H * W * C; + Image2DSize = height * width * FLT4_size; + } + size_t N{1}; + size_t H{1}; + size_t W{1}; + size_t C{1}; + size_t Slice{}; + size_t width{}; + size_t height{}; + size_t FLT_size{4}; + size_t FLT4_size{16}; + size_t ElementsNum{}; + size_t Image2DSize{}; +}; +} // namespace +
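A quick numeric check of the packing logic above, as it would read inside this test file. It assumes an fp32 tensor of shape {1, 16, 16, 7} on a device whose max image2D width is at least 32; the helper name is illustrative:

#include <cassert>

void GpuTensorInfoLayoutExample(registry::opencl::OpenCLRuntimeWrapper *wrapper,
                                const MSTensor *tensor) {
  GpuTensorInfo info(tensor, wrapper);
  assert(info.Slice == 2);                   // UP_DIV(C = 7, C4NUM = 4)
  assert(info.width == 32);                  // W * Slice = 16 * 2, fits MaxImage2DWidth
  assert(info.height == 16);                 // N * H = 1 * 16
  assert(info.Image2DSize == 16 * 32 * 16);  // height * width * FLT4_size = 8192 bytes
}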
+    if (tensor->DataType() == mindspore::DataType::kNumberTypeFloat16) {
+      FLT_size = sizeof(cl_half);
+    } else {
+      FLT_size = sizeof(cl_float);
+    }
+    FLT4_size = FLT_size * 4;
+    if (W * Slice <= opencl_run->GetMaxImage2DWidth()) {
+      height = N * H;
+      width = W * Slice;
+    } else {
+      height = N * H * W;
+      width = Slice;
+      if (height > opencl_run->GetMaxImage2DHeight()) {
+        height = -1;
+        width = -1;
+      }
+    }
+
+    ElementsNum = N * H * W * C;
+    Image2DSize = height * width * FLT4_size;
+  }
+  size_t N{1};
+  size_t H{1};
+  size_t W{1};
+  size_t C{1};
+  size_t Slice{};
+  size_t width{};
+  size_t height{};
+  size_t FLT_size{4};
+  size_t FLT4_size{16};
+  size_t ElementsNum{};
+  size_t Image2DSize{};
+};
+}  // namespace
+
+class CustomAddKernel : public kernel::Kernel {
+ public:
+  CustomAddKernel(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs,
+                  const schema::Primitive *primitive, const mindspore::Context *ctx, const std::string &build_options,
+                  bool fp16_enable)
+      : Kernel(inputs, outputs, primitive, ctx), build_options_(build_options), fp16_enable_(fp16_enable) {
+    opencl_runtime_ = new registry::opencl::OpenCLRuntimeWrapper();
+  }
+  ~CustomAddKernel() override {
+    FreeWeight();
+    delete opencl_runtime_;  // the wrapper is owned by this kernel
+  }
+  // Prepare will be called during graph compilation
+  int Prepare() override {
+    const std::string kernel_name = "ElementAdd";
+    const std::string program_name = "Arithmetic";
+    std::string source = arithmetic_source;
+    if (opencl_runtime_->LoadSource(program_name, source) != kSuccess) {
+      std::cerr << "Load source failed.";
+      return lite::RET_ERROR;
+    }
+    std::vector<std::string> build_options_ext = {"-cl-mad-enable -cl-fast-relaxed-math -Werror"};
+    build_options_ext.push_back(build_options_);
+    if (opencl_runtime_->BuildKernel(&kernel_, program_name, kernel_name, build_options_ext) != kSuccess) {
+      std::cerr << "Build kernel failed.";
+      return lite::RET_ERROR;
+    }
+
+    auto out_shape = GpuTensorInfo(&outputs_[0], opencl_runtime_);
+    local_range_ = cl::NullRange;
+    global_range_ = cl::NDRange(out_shape.width, out_shape.height);
+    for (size_t i = 0; i < inputs_.size(); ++i) {
+      auto &in_tensor = inputs_.at(i);
+      GpuTensorInfo in_shape = GpuTensorInfo(&in_tensor, opencl_runtime_);
+      if (in_tensor.IsConst()) {
+        std::vector<char> weight(in_shape.Image2DSize, 0);
+        bool src_is_fp16 = in_tensor.DataType() == mindspore::DataType::kNumberTypeFloat16;
+        PackNHWCToNHWC4(in_tensor.MutableData(), weight.data(), src_is_fp16, fp16_enable_, in_shape,
+                        in_tensor.DataType());
+        DataType dtype =
+          fp16_enable_ ? mindspore::DataType::kNumberTypeFloat16 : mindspore::DataType::kNumberTypeFloat32;
+        auto allocator = opencl_runtime_->GetAllocator();
+        if (allocator == nullptr) {
+          std::cerr << "GetAllocator fail.";
+          FreeWeight();
+          return lite::RET_ERROR;
+        }
+        auto weight_ptr = allocator->Malloc(in_shape.width, in_shape.height, dtype);
+        if (weight_ptr == nullptr) {
+          std::cerr << "Malloc fail.";
+          FreeWeight();
+          return lite::RET_ERROR;
+        }
+        weight_ptrs_.push_back(weight_ptr);
+        if (opencl_runtime_->WriteImage(weight_ptr, weight.data()) != kSuccess) {
+          std::cerr << "WriteImage fail.";
+          FreeWeight();
+          return lite::RET_ERROR;
+        }
+      } else {
+        weight_ptrs_.push_back(nullptr);
+      }
+    }
+
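+    // Kernel args 0-2 (both input images and the output image) are rebound on
+    // every Execute(); only the fixed output_shape is set once here.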
+    int arg_idx = 3;
+    cl_int2 output_shape{static_cast<int>(global_range_[0]), static_cast<int>(global_range_[1])};
+    if (opencl_runtime_->SetKernelArg(kernel_, arg_idx, output_shape) != kSuccess) {
+      std::cerr << "Set kernel arg " << arg_idx << " failed.";
+      FreeWeight();
+      return lite::RET_ERROR;
+    }
+
+    std::cout << kernel_name << " Init Done!" << std::endl;
+    return lite::RET_OK;
+  }
+
+  // Execute is called to compute.
+  int Execute() override {
+    if (inputs_.size() != 2) {
+      return lite::RET_PARAM_INVALID;
+    }
+    if (PreProcess() != lite::RET_OK) {
+      std::cerr << "PreProcess failed.";
+      return lite::RET_ERROR;
+    }
+    std::cout << this->name() << " Running!" << std::endl;
+    auto input_0_ptr = weight_ptrs_[0] == nullptr ? inputs_[0].MutableData() : weight_ptrs_[0];
+    auto input_1_ptr = weight_ptrs_[1] == nullptr ? inputs_[1].MutableData() : weight_ptrs_[1];
+    int arg_idx = 0;
+    if (opencl_runtime_->SetKernelArg(kernel_, arg_idx++, input_0_ptr) != kSuccess) {
+      std::cerr << "Set kernel arg " << arg_idx - 1 << " failed.";
+      return lite::RET_ERROR;
+    }
+    if (opencl_runtime_->SetKernelArg(kernel_, arg_idx++, input_1_ptr) != kSuccess) {
+      std::cerr << "Set kernel arg " << arg_idx - 1 << " failed.";
+      return lite::RET_ERROR;
+    }
+    if (opencl_runtime_->SetKernelArg(kernel_, arg_idx++, outputs_[0].MutableData()) != kSuccess) {
+      std::cerr << "Set kernel arg " << arg_idx - 1 << " failed.";
+      return lite::RET_ERROR;
+    }
+    if (opencl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_) != kSuccess) {
+      std::cerr << "Run kernel failed.";
+      return lite::RET_ERROR;
+    }
+
+    return lite::RET_OK;
+  }
+
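+  // CheckSpecs mirrors the built-in OpenCL arithmetic kernel's validation:
+  // exactly two fp32/fp16 inputs and one fp32/fp16 output are accepted.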
+  int CheckSpecs() {
+    for (auto &tensor : inputs_) {
+      if (tensor.DataType() != DataType::kNumberTypeFloat32 && tensor.DataType() != DataType::kNumberTypeFloat16) {
+        std::cerr << "ArithmeticOpenCLKernel only supports fp32/fp16 input";
+        return lite::RET_ERROR;
+      }
+    }
+    for (auto &tensor : outputs_) {
+      if (tensor.DataType() != DataType::kNumberTypeFloat32 && tensor.DataType() != DataType::kNumberTypeFloat16) {
+        std::cerr << "ArithmeticOpenCLKernel only supports fp32/fp16 output";
+        return lite::RET_ERROR;
+      }
+    }
+
+    if (inputs_.size() != 2 || outputs_.size() != 1) {
+      std::cerr << "in size: " << inputs_.size() << ", out size: " << outputs_.size();
+      return lite::RET_ERROR;
+    }
+
+    return lite::RET_OK;
+  }
+
+  // ReSize is used to update some parameters if the current node can change along with its inputs.
+  int ReSize() override {
+    if (CheckOutputs(outputs_) == lite::RET_OK) {
+      return lite::RET_OK;
+    }
+    auto kernel_interface = registry::RegisterKernelInterface::GetKernelInterface({}, primitive_);
+    if (kernel_interface == nullptr) {
+      std::cerr << "get custom kernel interface failed." << std::endl;
+      return lite::RET_ERROR;
+    }
+    auto status = kernel_interface->Infer(&inputs_, &outputs_, primitive_);
+    if (status != kSuccess) {
+      std::cerr << "infer failed." << std::endl;
+      return lite::RET_ERROR;
+    }
+    auto ret = CheckSpecs();
+    if (ret != lite::RET_OK) {
+      std::cerr << "ReSize failed for check kernel specs!";
+      return ret;
+    }
+    ret = Prepare();
+    if (ret != lite::RET_OK) {
+      std::cerr << "ReSize failed for kernel prepare!";
+      return ret;
+    }
+    return lite::RET_OK;
+  }
+
+ private:
+  std::string build_options_;
+  bool fp16_enable_;
+  cl::Kernel kernel_;
+  cl::Event event_;
+  cl::NDRange global_range_{cl::NullRange};
+  cl::NDRange local_range_{cl::NullRange};
+  std::vector<void *> weight_ptrs_;
+  registry::opencl::OpenCLRuntimeWrapper *opencl_runtime_;
+
+  int PreProcess() {
+    int ret;
+    ret = ReSize();
+    if (ret != lite::RET_OK) {
+      return ret;
+    }
+    for (size_t i = 0; i < outputs_.size(); ++i) {
+      auto *output = &outputs_.at(i);
+      auto img_info = GpuTensorInfo(output, opencl_runtime_);
+      auto allocator = output->allocator();
+      if (allocator == nullptr) {
+        std::cerr << "The output tensor of an OpenCL kernel must have an allocator.";
+        return lite::RET_ERROR;
+      }
+      auto data_ptr = allocator->Malloc(img_info.width, img_info.height, output->DataType());
+      if (data_ptr == nullptr) {
+        std::cerr << "Malloc data failed";
+        return lite::RET_ERROR;
+      }
+      output->SetData(data_ptr);
+    }
+    return lite::RET_OK;
+  }
+
+  int CheckOutputs(const std::vector<MSTensor> &outputs) {
+    for (auto &output : outputs) {
+      auto output_shape = output.Shape();
+      if (std::find(output_shape.begin(), output_shape.end(), -1) != output_shape.end()) {
+        return lite::RET_INFER_INVALID;
+      }
+    }
+    return lite::RET_OK;
+  }
+
+  void PackNHWCToNHWC4(void *src, void *dst, bool src_is_fp16, bool dst_is_fp16, const GpuTensorInfo &tensor,
+                       mindspore::DataType data_type) {
+    auto src_fp16 = reinterpret_cast<float16_t *>(src);
+    auto src_fp32 = reinterpret_cast<float *>(src);
+    auto src_int32 = reinterpret_cast<int32_t *>(src);
+    auto dst_fp16 = reinterpret_cast<float16_t *>(dst);
+    auto dst_fp32 = reinterpret_cast<float *>(dst);
+    auto dst_int32 = reinterpret_cast<int32_t *>(dst);
+    for (size_t n = 0, src_idx = 0; n < tensor.N; n++) {
+      for (size_t h = 0; h < tensor.H; ++h) {
+        for (size_t w = 0; w < tensor.W; ++w) {
+          for (size_t c = 0; c < tensor.C; ++c, ++src_idx) {
+            size_t dst_idx = ((n * tensor.H + h) * tensor.W + w) * tensor.Slice * C4NUM + c;
+            if (data_type == mindspore::DataType::kNumberTypeInt32) {
+              dst_int32[dst_idx] = src_int32[src_idx];
+            } else if (dst_is_fp16) {
+              dst_fp16[dst_idx] = src_is_fp16 ? src_fp16[src_idx] : static_cast<float16_t>(src_fp32[src_idx]);
+            } else {
+              dst_fp32[dst_idx] = src_is_fp16 ? static_cast<float>(src_fp16[src_idx]) : src_fp32[src_idx];
+            }
+          }
+        }
+      }
+    }
+    // a scalar still occupies one FLT4 texel: replicate its value into all four channels
+    if (tensor.ElementsNum == 1) {
+      if (dst_is_fp16) {
+        dst_fp16[3] = dst_fp16[2] = dst_fp16[1] = dst_fp16[0];
+      } else {
+        dst_fp32[3] = dst_fp32[2] = dst_fp32[1] = dst_fp32[0];
+      }
+    }
+  }
+
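+  // The packed const-input copies live in OpenCL image memory owned by this
+  // kernel, so they are released through the runtime allocator, not free().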
+  void FreeWeight() {
+    auto allocator = opencl_runtime_->GetAllocator();
+    if (allocator == nullptr) {
+      std::cerr << "GetAllocator fail.";
+      return;
+    }
+    for (auto &weight_ptr : weight_ptrs_) {
+      if (weight_ptr != nullptr) {
+        allocator->Free(weight_ptr);
+        weight_ptr = nullptr;
+      }
+    }
+  }
+};
+
+class CustomAddInfer : public kernel::KernelInterface {
+ public:
+  CustomAddInfer() = default;
+  ~CustomAddInfer() = default;
+
+  Status Infer(std::vector<MSTensor> *inputs, std::vector<MSTensor> *outputs,
+               const schema::Primitive *primitive) override {
+    (*outputs)[0].SetFormat((*inputs)[0].format());
+    (*outputs)[0].SetDataType((*inputs)[0].DataType());
+    (*outputs)[0].SetShape((*inputs)[0].Shape());
+    return kSuccess;
+  }
+};
+
+namespace {
+std::shared_ptr<kernel::Kernel> CustomAddCreator(const std::vector<MSTensor> &inputs,
+                                                 const std::vector<MSTensor> &outputs,
+                                                 const schema::Primitive *primitive, const mindspore::Context *ctx) {
+  const std::string build_options = " -DFLT4=float4 -DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef ";
+  bool fp16_enable = false;
+
+  std::cout << "using fp32 add." << std::endl;
+  return std::make_shared<CustomAddKernel>(inputs, outputs, primitive, ctx, build_options, fp16_enable);
+}
+
+std::shared_ptr<kernel::KernelInterface> CustomAddInferCreator() { return std::make_shared<CustomAddInfer>(); }
+}  // namespace
+
+REGISTER_CUSTOM_KERNEL_INTERFACE(BuiltInTest, Custom_Add, CustomAddInferCreator)
+// Register the custom "Custom_Add" operator
+REGISTER_CUSTOM_KERNEL(GPU, BuiltInTest, kFloat32, Custom_Add, CustomAddCreator)
+
+class TestGPURegistryCustomOp : public mindspore::CommonTest {
+ public:
+  TestGPURegistryCustomOp() = default;
+};
+
+TEST_F(TestGPURegistryCustomOp, TestGPUCustomAdd) {
+  auto meta_graph = std::make_shared<schema::MetaGraphT>();
+  meta_graph->name = "graph";
+
+  auto node = std::make_unique<schema::CNodeT>();
+  node->inputIndex = {0, 1};
+  node->outputIndex = {2};
+  node->primitive = std::make_unique<schema::PrimitiveT>();
+  node->primitive->value.type = schema::PrimitiveType_Custom;
+  auto primitive = new schema::CustomT;
+  primitive->type = "Custom_Add";
+  node->primitive->value.value = primitive;
+  node->name = "Add";
+  meta_graph->nodes.emplace_back(std::move(node));
+  meta_graph->inputIndex = {0, 1};
+  meta_graph->outputIndex = {2};
+
+  auto input0 = std::make_unique<schema::TensorT>();
+  input0->nodeType = lite::NodeType_ValueNode;
+  input0->format = schema::Format_NHWC;
+  input0->dataType = TypeId::kNumberTypeFloat32;
+  input0->dims = {1, 28, 28, 3};
+  input0->offset = -1;
+  meta_graph->allTensors.emplace_back(std::move(input0));
+
+  auto weight = std::make_unique<schema::TensorT>();
+  weight->nodeType = lite::NodeType_ValueNode;
+  weight->format = schema::Format_NHWC;
+  weight->dataType = TypeId::kNumberTypeFloat32;
+  weight->dims = {1, 28, 28, 3};
+  weight->offset = -1;
+  meta_graph->allTensors.emplace_back(std::move(weight));
+
+  auto output = std::make_unique<schema::TensorT>();
+  output->nodeType = lite::NodeType_Parameter;
+  output->format = schema::Format_NHWC;
+  output->dataType = TypeId::kNumberTypeFloat32;
+  output->offset = -1;
+  meta_graph->allTensors.emplace_back(std::move(output));
+
+  flatbuffers::FlatBufferBuilder builder(1024);
+  auto offset = schema::MetaGraph::Pack(builder, meta_graph.get());
+  builder.Finish(offset);
+  schema::FinishMetaGraphBuffer(builder, offset);
+  size_t size = builder.GetSize();
+  const char *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
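+
+  // Create a context that lists a CPU device plus a provider-tagged GPU
+  // device; the scheduler matches provider "BuiltInTest" / device "GPU"
+  // against REGISTER_CUSTOM_KERNEL and dispatches Custom_Add to the kernel
+  // registered above.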
+  auto context = std::make_shared<Context>();
+  context->SetThreadNum(1);
+  context->SetEnableParallel(false);
+  context->SetThreadAffinity(lite::HIGHER_CPU);
+  auto &device_list = context->MutableDeviceInfo();
+
+  std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
+  device_info->SetEnableFP16(false);
+  device_list.push_back(device_info);
+
+  std::shared_ptr<GPUDeviceInfo> provider_gpu_device_info = std::make_shared<GPUDeviceInfo>();
+  provider_gpu_device_info->SetEnableFP16(false);
+  provider_gpu_device_info->SetProviderDevice("GPU");
+  provider_gpu_device_info->SetProvider("BuiltInTest");
+  device_list.push_back(provider_gpu_device_info);
+
+  // build a model
+  auto model = std::make_shared<Model>();
+  auto ret = model->Build(content, size, kFlatBuffer, context);
+  ASSERT_EQ(kSuccess, ret.StatusCode());
+  auto inputs = model->GetInputs();
+  ASSERT_EQ(inputs.size(), 2);
+  auto inTensor = inputs.front();
+  auto impl = inTensor.impl();
+  ASSERT_NE(nullptr, impl);
+  float *in0_data = static_cast<float *>(inTensor.MutableData());
+  in0_data[0] = 10.0f;
+  auto inTensor1 = inputs.back();
+  impl = inTensor1.impl();
+  ASSERT_NE(nullptr, impl);
+  float *in1_data = static_cast<float *>(inTensor1.MutableData());
+  in1_data[0] = 20.0f;
+  std::vector<MSTensor> outputs;
+  ret = model->Predict(inputs, &outputs);
+  ASSERT_EQ(kSuccess, ret.StatusCode());
+  ASSERT_EQ(outputs.size(), 1);
+  impl = outputs.front().impl();
+  ASSERT_NE(nullptr, impl);
+  ASSERT_EQ(28 * 28 * 3, outputs.front().ElementNum());
+  ASSERT_EQ(DataType::kNumberTypeFloat32, outputs.front().DataType());
+  auto *outData = reinterpret_cast<const float *>(outputs.front().Data().get());
+  ASSERT_NE(nullptr, outData);
+  ASSERT_EQ(30.0f, outData[0]);
+  MS_LOG(INFO) << "Register add op test pass.";
+}
+}  // namespace mindspore
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc
index 159b76d7ff0..f607b93234a 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc
@@ -39,7 +39,7 @@ void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bou
 
 TEST_F(TestCastSelfOpenCL, Castfp32tofp16) {
   MS_LOG(INFO) << " begin test ";
-  auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
+  auto ocl_runtime = lite::opencl::OpenCLRuntimeInnerWrapper().GetInstance();
   ocl_runtime->Init();
   auto allocator = ocl_runtime->GetAllocator();
 
@@ -149,7 +149,7 @@ TEST_F(TestCastSelfOpenCL, Castfp32tofp16) {
 
 TEST_F(TestCastSelfOpenCL, Castfp16tofp32) {
   MS_LOG(INFO) << " begin test ";
-  auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
+  auto ocl_runtime = lite::opencl::OpenCLRuntimeInnerWrapper().GetInstance();
   ocl_runtime->Init();
   auto allocator = ocl_runtime->GetAllocator();
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/common.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/common.cc
index f6cdbd15381..0687d3192a0 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/common.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/common.cc
@@ -51,7 +51,7 @@ void TestMain(const std::vector &input_infos, const std::vec
 
   // simulating benchmark: session::LiteSession::CreateSession() -> session->Init()
   MS_LOG(DEBUG) << "initialize OpenCLRuntime and OpenCLAllocator";
-  auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
+  auto runtime_wrapper = lite::opencl::OpenCLRuntimeInnerWrapper();
   auto ocl_runtime = runtime_wrapper.GetInstance();
   ocl_runtime->SetFp16Enable(fp16_enable);
   EXPECT_TRUE(ocl_runtime->Init() == RET_OK);
@@ -222,7 +222,7 @@ void TestMain(const std::vector &input_infos, std::tuple
 
   // simulating benchmark: session::LiteSession::CreateSession() -> session->Init()
   MS_LOG(DEBUG) << "initialize OpenCLRuntime and OpenCLAllocator";
-  auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
+  auto runtime_wrapper = lite::opencl::OpenCLRuntimeInnerWrapper();
   auto ocl_runtime = runtime_wrapper.GetInstance();
   ocl_runtime->SetFp16Enable(fp16_enable);
   EXPECT_TRUE(ocl_runtime->Init() == RET_OK);
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/fill_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/fill_tests.cc
index 10a153bb594..af907c015cf 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/fill_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/fill_tests.cc
@@ -33,7 +33,7 @@ class TestFillOpenCLCI : public mindspore::CommonTest {
 
 TEST_F(TestFillOpenCLCI, Fp32testfill) {
   MS_LOG(INFO) << " begin test ";
-  auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
+  auto runtime_wrapper = lite::opencl::OpenCLRuntimeInnerWrapper();
   auto runtime = runtime_wrapper.GetInstance();
   runtime->Init();
   auto allocator = runtime->GetAllocator();
@@ -104,7 +104,7 @@ TEST_F(TestFillOpenCLCI, Fp32testfill) {
 
 TEST_F(TestFillOpenCLCI, Fp32testshape) {
   MS_LOG(INFO) << " begin test ";
-  auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
+  auto runtime_wrapper = lite::opencl::OpenCLRuntimeInnerWrapper();
   auto runtime = runtime_wrapper.GetInstance();
   runtime->Init();
   auto allocator = runtime->GetAllocator();