This commit is contained in:
chaijun 2021-07-19 17:23:23 +08:00 committed by gongdaguo
parent 166ad9e3a7
commit c515fcd361
65 changed files with 1969 additions and 238 deletions

View File

@ -58,3 +58,6 @@
"mindspore/mindspore/lite/src/runtime/thread_pool.c" "runtime/arrays"
"mindspore/mindspore/lite/src/runtime/thread_pool.c" "runtime/int"
"mindspore/mindspore/lite/src/ops/ops_def.cc" "runtime/int"
"mindspore/mindspore/lite/examples/runtime_gpu_extend/src/cl" "legal/copyright"
"mindspore/mindspore/lite/examples/runtime_gpu_extend/src/cl" "readability/casting"
"mindspore/mindspore/lite/examples/runtime_gpu_extend/src/cl" "readability/fn_size"

View File

@ -16,13 +16,12 @@ else()
__download_pkg(OpenCL-CLHPP ${REQ_URL} ${MD5})
endif()
function(gene_opencl BASEPATH)
string(CONCAT CL_SRC_DIR "${BASEPATH}" "/src/runtime/kernel/opencl/cl")
message(STATUS "**********gene opencl*********base path: " "${BASEPATH}" ", cl path: " "${CL_SRC_DIR}")
function(gene_opencl CL_SRC_DIR)
message(STATUS "**********gene opencl********* cl path: " "${CL_SRC_DIR}")
if(NOT EXISTS ${CL_SRC_DIR})
return()
endif()
file(GLOB_RECURSE CL_LIST ${CL_SRC_DIR}/*.cl ${CL_SRC_DIR}/int8/*.cl)
file(GLOB_RECURSE CL_LIST ${CL_SRC_DIR}/*.cl)
foreach(file_path ${CL_LIST})
file(REMOVE ${file_path}.inc)
string(REGEX REPLACE ".+/(.+)\\..*" "\\1" kernel_name "${file_path}")

View File

@ -32,6 +32,15 @@ class MS_API Allocator {
/// \param[in] size Define the memory size to request.
virtual void *Malloc(size_t size) = 0;
/// \brief Method to request image memory.
///
/// \param[in] weight Defines the width of the image memory to request.
/// \param[in] height Defines the height of the image memory to request.
/// \param[in] type Defines the data type of the image memory to request.
virtual void *Malloc(size_t weight, size_t height, DataType type) {
return nullptr;
}
/// \brief Method to free memory.
///
/// \param[in] ptr Define the pointer of a certain memory.
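A minimal usage sketch of the new overload (not part of the commit): the helper name AllocImage and the chosen data type are illustrative, and the allocator is assumed to come from a GPU runtime, e.g. OpenCLRuntimeWrapper::GetAllocator() as in the custom GPU kernel added later in this commit.

#include <memory>
#include "include/api/allocator.h"
#include "include/api/data_type.h"

// Sketch only: request 2D (image) memory through the new Malloc overload.
// Allocators that do not support image memory keep the base-class default
// and simply return nullptr.
void *AllocImage(const std::shared_ptr<mindspore::Allocator> &allocator, size_t width, size_t height) {
  if (allocator == nullptr) {
    return nullptr;
  }
  return allocator->Malloc(width, height, mindspore::DataType::kNumberTypeFloat16);
}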

View File

@ -169,6 +169,11 @@ class MS_API MSTensor {
/// \return The length of the data of the MSTensor, in bytes.
size_t DataSize() const;
/// \brief Get whether the MSTensor data is constant data.
///
/// \return The boolean value that indicates whether the MSTensor data is constant.
bool IsConst() const;
/// \brief Gets the boolean value that indicates whether the memory of MSTensor is on device.
///
/// \return The boolean value that indicates whether the memory of MSTensor is on device.
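As a small, illustrative sketch (not part of the commit), IsConst() lets a kernel tell which inputs hold constant weight data that can be packed ahead of time, which is how the custom GPU kernel later in this commit uses it:

#include <vector>
#include "include/api/types.h"

// Count the inputs that carry constant (weight) data.
size_t CountConstInputs(const std::vector<mindspore::MSTensor> &inputs) {
  size_t num_const = 0;
  for (auto &tensor : inputs) {
    if (tensor.IsConst()) {
      ++num_const;
    }
  }
  return num_const;
}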

View File

@ -358,7 +358,8 @@ if(MSLITE_ENABLE_FP16)
endif()
if(MSLITE_GPU_BACKEND STREQUAL opencl)
add_definitions(-DGPU_OPENCL)
gene_opencl(${CMAKE_CURRENT_SOURCE_DIR})
gene_opencl(${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/kernel/opencl/cl)
gene_opencl(${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/kernel/opencl/cl/int8)
add_definitions(-DUSE_OPENCL_WRAPPER)
add_definitions(-DMS_OPENCL_PROFILE=false)
add_definitions(-DCL_TARGET_OPENCL_VERSION=200)

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_EXAMPLES_RUNTIME_REGISTRY_SRC_CUSTOM_COMMON_H
#define MINDSPORE_LITE_EXAMPLES_RUNTIME_REGISTRY_SRC_CUSTOM_COMMON_H
#ifndef MINDSPORE_LITE_EXAMPLES_RUNTIME_EXTEND_SRC_CUSTOM_COMMON_H
#define MINDSPORE_LITE_EXAMPLES_RUNTIME_EXTEND_SRC_CUSTOM_COMMON_H
#include <vector>
#include "include/api/types.h"

View File

@ -0,0 +1,45 @@
cmake_minimum_required(VERSION 3.14)
project(RuntimeGPUExtendTutorial)
message(STATUS "Using toolchain file: ${CMAKE_TOOLCHAIN_FILE}.")
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0)
message(FATAL_ERROR "GCC version ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0")
endif()
add_definitions(-DCL_TARGET_OPENCL_VERSION=200)
add_definitions(-DCL_HPP_TARGET_OPENCL_VERSION=120)
add_definitions(-DCL_HPP_MINIMUM_OPENCL_VERSION=120)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
include(${CMAKE_CURRENT_SOURCE_DIR}/../../../../cmake/utils.cmake)
include(${CMAKE_CURRENT_SOURCE_DIR}/../../../../cmake/external_libs/opencl.cmake)
gene_opencl(${CMAKE_CURRENT_SOURCE_DIR}/src/cl)
# Add directory to include search path
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/runtime/)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/runtime/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/runtime/include/third_party)
include_directories(${CMAKE_BINARY_DIR}/_deps/opencl-headers-src/)
include_directories(${CMAKE_BINARY_DIR}/_deps/opencl-clhpp-src/include)
# Add directory to linker search path
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/runtime/lib)
file(GLOB_RECURSE RUNTIME_REGISTRY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
add_executable(runtime_extend_tutorial ${RUNTIME_REGISTRY_SRC})
target_link_libraries(
runtime_extend_tutorial
mindspore-lite
log
)
add_executable(runtime_extend_tutorial_static ${RUNTIME_REGISTRY_SRC})
target_link_libraries(
runtime_extend_tutorial_static
-Wl,--whole-archive libmindspore-lite.a -Wl,--no-whole-archive
log
)

View File

@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
BASEPATH=$(cd "$(dirname $0)" || exit; pwd)
get_version() {
VERSION_MAJOR=$(grep "const int ms_version_major =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
VERSION_MINOR=$(grep "const int ms_version_minor =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
VERSION_REVISION=$(grep "const int ms_version_revision =" ${BASEPATH}/../../include/version.h | tr -dc "[0-9]")
VERSION_STR=${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION}
}
get_version
MODEL_DOWNLOAD_URL="https://download.mindspore.cn/model_zoo/official/lite/quick_start/add_extend.ms"
MODEL_DOWNLOAD_URL2="https://download.mindspore.cn/model_zoo/official/lite/quick_start/add.ms"
MINDSPORE_FILE_NAME="mindspore-lite-${VERSION_STR}-linux-x64"
MINDSPORE_FILE="${MINDSPORE_FILE_NAME}.tar.gz"
MINDSPORE_LITE_DOWNLOAD_URL="https://ms-release.obs.cn-north-4.myhuaweicloud.com/${VERSION_STR}/MindSpore/lite/release/linux/${MINDSPORE_FILE}"
mkdir -p ${BASEPATH}/build
mkdir -p ${BASEPATH}/model
if [ ! -e ${BASEPATH}/model/add_extend.ms ]; then
wget -c -O ${BASEPATH}/model/add_extend.ms --no-check-certificate ${MODEL_DOWNLOAD_URL}
fi
if [ ! -e ${BASEPATH}/model/add.ms ]; then
wget -c -O ${BASEPATH}/model/add.ms --no-check-certificate ${MODEL_DOWNLOAD_URL2}
fi
if [ ! -e ${BASEPATH}/build/${MINDSPORE_FILE} ]; then
wget -c -O ${BASEPATH}/build/${MINDSPORE_FILE} --no-check-certificate ${MINDSPORE_LITE_DOWNLOAD_URL}
fi
tar -xzf ${BASEPATH}/build/${MINDSPORE_FILE} -C ${BASEPATH}/build/
cp -r ${BASEPATH}/build/${MINDSPORE_FILE_NAME}/runtime ${BASEPATH}/
cd ${BASEPATH}/build || exit
cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \
-DANDROID_ABI="arm64-v8a" -DCMAKE_BUILD_TYPE="Release" ${BASEPATH}
make

View File

@ -0,0 +1,200 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <random>
#include <iostream>
#include <fstream>
#include <cstring>
#include <cmath>
#include <vector>
#include <memory>
#include "include/errorcode.h"
#include "include/context.h"
#include "include/api/types.h"
#include "include/api/model.h"
namespace mindspore {
namespace lite {
namespace {
constexpr int kNumPrintOfOutData = 20;
std::string RealPath(const char *path) {
const size_t max = 4096;
if (path == nullptr) {
std::cerr << "path is nullptr" << std::endl;
return "";
}
if ((strlen(path)) >= max) {
std::cerr << "path is too long" << std::endl;
return "";
}
auto resolved_path = std::make_unique<char[]>(max);
if (resolved_path == nullptr) {
std::cerr << "new resolved_path failed" << std::endl;
return "";
}
char *real_path = realpath(path, resolved_path.get());
if (real_path == nullptr || strlen(real_path) == 0) {
std::cerr << "file path is not valid : " << path << std::endl;
return "";
}
std::string res = resolved_path.get();
return res;
}
char *ReadFile(const char *file, size_t *size) {
if (file == nullptr) {
std::cerr << "file is nullptr." << std::endl;
return nullptr;
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cerr << "file: " << file << " is not exist." << std::endl;
return nullptr;
}
if (!ifs.is_open()) {
std::cerr << "file: " << file << " open failed." << std::endl;
return nullptr;
}
ifs.seekg(0, std::ios::end);
*size = ifs.tellg();
std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
if (buf == nullptr) {
std::cerr << "malloc buf failed, file: " << file << std::endl;
ifs.close();
return nullptr;
}
ifs.seekg(0, std::ios::beg);
ifs.read(buf.get(), *size);
ifs.close();
return buf.release();
}
} // namespace
template <typename T, typename Distribution>
void GenerateRandomData(int size, void *data, Distribution distribution) {
std::mt19937 random_engine;
int elements_num = size / sizeof(T);
(void)std::generate_n(static_cast<T *>(data), elements_num,
[&distribution, &random_engine]() { return static_cast<T>(distribution(random_engine)); });
}
void InitMSContext(const std::shared_ptr<mindspore::Context> &context) {
context->SetThreadNum(1);
context->SetEnableParallel(false);
context->SetThreadAffinity(HIGHER_CPU);
auto &device_list = context->MutableDeviceInfo();
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
device_info->SetEnableFP16(false);
device_list.push_back(device_info);
std::shared_ptr<GPUDeviceInfo> provider_gpu_device_info = std::make_shared<GPUDeviceInfo>();
provider_gpu_device_info->SetEnableFP16(false);
provider_gpu_device_info->SetProviderDevice("GPU");
provider_gpu_device_info->SetProvider("Tutorial");
device_list.push_back(provider_gpu_device_info);
}
int CompileAndRun(int argc, const char **argv) {
if (argc < 2) {
std::cerr << "Model file must be provided.\n";
return RET_ERROR;
}
// Read model file.
auto model_path = RealPath(argv[1]);
if (model_path.empty()) {
std::cerr << "model path " << argv[1] << " is invalid.";
return RET_ERROR;
}
auto context = std::make_shared<mindspore::Context>();
if (context == nullptr) {
std::cerr << "New context failed." << std::endl;
return RET_ERROR;
}
(void)InitMSContext(context);
mindspore::Model ms_model;
size_t size = 0;
char *model_buf = ReadFile(model_path.c_str(), &size);
if (model_buf == nullptr) {
std::cerr << "Read model file failed." << std::endl;
return RET_ERROR;
}
auto ret = ms_model.Build(model_buf, size, kMindIR, context);
delete[](model_buf);
if (ret != kSuccess) {
std::cerr << "ms_model.Build failed." << std::endl;
return RET_ERROR;
}
std::vector<mindspore::MSTensor> ms_inputs_for_api = ms_model.GetInputs();
for (auto tensor : ms_inputs_for_api) {
auto input_data = tensor.MutableData();
if (input_data == nullptr) {
std::cerr << "MallocData for inTensor failed." << std::endl;
return RET_ERROR;
}
GenerateRandomData<float>(tensor.DataSize(), input_data, std::uniform_real_distribution<float>(1.0f, 1.0f));
}
std::cout << "\n------- print inputs ----------" << std::endl;
for (auto tensor : ms_inputs_for_api) {
std::cout << "in tensor name is:" << tensor.Name() << "\nin tensor size is:" << tensor.DataSize()
<< "\nin tensor elements num is:" << tensor.ElementNum() << std::endl;
auto out_data = reinterpret_cast<float *>(tensor.MutableData());
std::cout << "input data is:";
for (int i = 0; i < tensor.ElementNum() && i <= kNumPrintOfOutData; i++) {
std::cout << out_data[i] << " ";
}
std::cout << std::endl;
}
std::cout << "------- print end ----------\n" << std::endl;
std::vector<MSTensor> outputs;
auto status = ms_model.Predict(ms_inputs_for_api, &outputs);
if (status != kSuccess) {
std::cerr << "Inference error." << std::endl;
return RET_ERROR;
}
// Get Output Tensor Data.
auto out_tensors = ms_model.GetOutputs();
std::cout << "\n------- print outputs ----------" << std::endl;
for (auto tensor : out_tensors) {
std::cout << "out tensor name is:" << tensor.Name() << "\nout tensor size is:" << tensor.DataSize()
<< "\nout tensor elements num is:" << tensor.ElementNum() << std::endl;
auto out_data = reinterpret_cast<float *>(tensor.MutableData());
std::cout << "output data is:";
for (int i = 0; i < tensor.ElementNum() && i <= kNumPrintOfOutData; i++) {
std::cout << out_data[i] << " ";
}
std::cout << std::endl;
}
std::cout << "------- print end ----------\n" << std::endl;
return RET_OK;
}
} // namespace lite
} // namespace mindspore
int main(int argc, const char **argv) { return mindspore::lite::CompileAndRun(argc, argv); }
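Note that InitMSContext adds a GPUDeviceInfo whose provider is set to "Tutorial" and whose provider device is "GPU"; this is what allows the scheduler to match the kernels registered under the Tutorial provider in custom_add_kernel_gpu.cc (see also the kArchGPU handling added to KernelRegistry later in this commit).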

View File

@ -0,0 +1,17 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
__kernel void ElementAdd(__read_only image2d_t input_a, __read_only image2d_t input_b, __write_only image2d_t output,
const int2 output_shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
if (X >= output_shape.x || Y >= output_shape.y) {
return;
}
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));
FLT4 result = a + b;
WRITE_IMAGE(output, (int2)(X, Y), result);
}
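Note that FLT4, READ_IMAGE and WRITE_IMAGE are not defined in the .cl source itself; they are supplied by the host through kernel build options (for example "-DFLT4=float4 -DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef" in the fp32 creator, and the half4/write_imageh/read_imageh variants in the fp16 creator), so the same source serves both precisions.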

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/custom_common.h"
#include "include/errorcode.h"
#include "include/registry/register_kernel_interface.h"
namespace mindspore {
/**
* CustomAddInfer is a subclass that infers the current node's output information, including format, data type and shape.
* If an input shape contains -1, don't worry: the shape will be inferred at runtime.
*/
class CustomAddInfer : public kernel::KernelInterface {
public:
CustomAddInfer() = default;
~CustomAddInfer() = default;
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) override {
(*outputs)[0].SetFormat((*inputs)[0].format());
(*outputs)[0].SetDataType((*inputs)[0].DataType());
auto ret = custom_common::CheckInputs(*inputs);
if (ret != lite::RET_OK) {
if (ret == lite::RET_INFER_INVALID) {
(*outputs)[0].SetShape({-1});  // shape{-1} indicates that the shape needs to be inferred at runtime.
return kLiteInferInvalid;
} else {
return kLiteError;
}
}
(*outputs)[0].SetShape((*inputs)[0].Shape());
return kSuccess;
}
};
std::shared_ptr<kernel::KernelInterface> CustomAddInferCreator() { return std::make_shared<CustomAddInfer>(); }
REGISTER_CUSTOM_KERNEL_INTERFACE(Tutorial, Custom_Add, CustomAddInferCreator)
} // namespace mindspore
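For context, a rough sketch of how the registered interface can be looked up and invoked; the helper name and error handling are assumptions, while the lookup call mirrors the one used in ReSize() of the GPU kernel below.

#include <vector>
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/registry/register_kernel_interface.h"
#include "include/schema/ops_generated.h"

// Sketch: fetch the infer interface registered for this primitive and run it.
mindspore::Status InferCustomNode(const mindspore::schema::Primitive *primitive,
                                  std::vector<mindspore::MSTensor> *inputs,
                                  std::vector<mindspore::MSTensor> *outputs) {
  auto infer = mindspore::registry::RegisterKernelInterface::GetKernelInterface({}, primitive);
  if (infer == nullptr) {
    return mindspore::kLiteError;
  }
  return infer->Infer(inputs, outputs, primitive);
}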

View File

@ -0,0 +1,267 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <arm_neon.h>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "src/custom_common.h"
#include "include/errorcode.h"
#include "include/registry/register_kernel_interface.h"
#include "include/registry/register_kernel.h"
#include "include/registry/opencl_runtime_wrapper.h"
#include "src/cl/arithmetic.cl.inc"
#include "include/api/data_type.h"
#include "include/schema/ops_generated.h"
#define UP_ROUND(x, y) (((x) + (y) - (1)) / (y) * (y))
namespace mindspore {
namespace custom_gpu_demo {
class CustomAddKernel : public kernel::Kernel {
public:
CustomAddKernel(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs,
const schema::Primitive *primitive, const mindspore::Context *ctx, const std::string &build_options,
bool fp16_enable)
: Kernel(inputs, outputs, primitive, ctx), build_options_(build_options), fp16_enable_(fp16_enable) {
opencl_runtime_ = new registry::opencl::OpenCLRuntimeWrapper();
}
~CustomAddKernel() override { FreeWeight(); }
// Prepare will be called during graph compilation
int Prepare() override {
const std::string kernel_name_ = "ElementAdd";
const std::string program_name = "Arithmetic";
std::string source = arithmetic_source;
if (opencl_runtime_->LoadSource(program_name, source) != kSuccess) {
std::cerr << "Load source failed.";
return lite::RET_ERROR;
}
std::vector<std::string> build_options_ext = {"-cl-mad-enable -cl-fast-relaxed-math -Werror"};
build_options_ext.push_back(build_options_);
if (opencl_runtime_->BuildKernel(&kernel_, program_name, kernel_name_, build_options_ext) != kSuccess) {
std::cerr << "Build kernel failed.";
return lite::RET_ERROR;
}
auto out_shape = custom_common::GpuTensorInfo(&outputs_[0], opencl_runtime_);
local_range_ = cl::NullRange;
global_range_ = cl::NDRange(out_shape.width, out_shape.height);
for (int i = 0; i < inputs_.size(); ++i) {
auto &in_tensor = inputs_.at(i);
custom_common::GpuTensorInfo in_shape = custom_common::GpuTensorInfo(&in_tensor, opencl_runtime_);
if (in_tensor.IsConst()) {
std::vector<char> weight(in_shape.Image2DSize, 0);
bool src_is_fp16 = in_tensor.DataType() == mindspore::DataType::kNumberTypeFloat16;
PackNHWCToNHWC4(in_tensor.MutableData(), weight.data(), src_is_fp16, fp16_enable_, in_shape,
in_tensor.DataType());
DataType dtype =
fp16_enable_ ? mindspore::DataType::kNumberTypeFloat16 : mindspore::DataType::kNumberTypeFloat32;
auto allocator = opencl_runtime_->GetAllocator();
if (allocator == nullptr) {
std::cerr << "GetAllocator fail.";
FreeWeight();
return lite::RET_ERROR;
}
auto weight_ptr = allocator->Malloc(in_shape.width, in_shape.height, dtype);
if (weight_ptr == nullptr) {
std::cerr << "Malloc fail.";
FreeWeight();
return lite::RET_ERROR;
}
weight_ptrs_.push_back(weight_ptr);
// Use API to write GPU memory
if (opencl_runtime_->WriteImage(weight_ptr, weight.data()) != kSuccess) {
std::cerr << "WriteImage fail.";
FreeWeight();
return lite::RET_ERROR;
}
} else {
weight_ptrs_.push_back(nullptr);
}
}
int arg_idx = 3;
cl_int2 output_shape{static_cast<int>(global_range_[0]), static_cast<int>(global_range_[1])};
if (opencl_runtime_->SetKernelArg(kernel_, arg_idx, output_shape) != kSuccess) {
std::cerr << "Set kernel arg" << arg_idx << "failed.";
FreeWeight();
return lite::RET_ERROR;
}
std::cout << kernel_name_ << " Init Done!" << std::endl;
return lite::RET_OK;
}
// Execute is called to compute.
int Execute() override {
if (inputs_.size() != 2) {
return lite::RET_PARAM_INVALID;
}
auto ret = PreProcess();
if (ret != lite::RET_OK) {
return ret;
}
std::cout << this->name() << " Running!" << std::endl;
auto input_0_ptr = weight_ptrs_[0] == nullptr ? inputs_[0].MutableData() : weight_ptrs_[0];
auto input_1_ptr = weight_ptrs_[1] == nullptr ? inputs_[1].MutableData() : weight_ptrs_[1];
int arg_idx = 0;
if (opencl_runtime_->SetKernelArg(kernel_, arg_idx++, input_0_ptr) != kSuccess) {
std::cerr << "Set kernel arg" << arg_idx - 1 << "failed.";
return lite::RET_ERROR;
}
if (opencl_runtime_->SetKernelArg(kernel_, arg_idx++, input_1_ptr) != kSuccess) {
std::cerr << "Set kernel arg" << arg_idx - 1 << "failed.";
return lite::RET_ERROR;
}
if (opencl_runtime_->SetKernelArg(kernel_, arg_idx++, outputs_[0].MutableData()) != kSuccess) {
std::cerr << "Set kernel arg" << arg_idx - 1 << "failed.";
return lite::RET_ERROR;
}
if (opencl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_) != kSuccess) {
std::cerr << "Run kernel failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
int CheckSpecs() {
for (auto &tensor : inputs_) {
if (tensor.DataType() != DataType::kNumberTypeFloat32 && tensor.DataType() != DataType::kNumberTypeFloat16) {
std::cerr << "ArithmeticOpenCLKernel only support fp32/fp16 input";
return lite::RET_ERROR;
}
}
for (auto &tensor : outputs_) {
if (tensor.DataType() != DataType::kNumberTypeFloat32 && tensor.DataType() != DataType::kNumberTypeFloat16) {
std::cerr << "ArithmeticOpenCLKernel only support fp32/fp16 output";
return lite::RET_ERROR;
}
}
if (inputs_.size() != 2 || outputs_.size() != 1) {
std::cerr << "in size: " << inputs_.size() << ", out size: " << outputs_.size();
return lite::RET_ERROR;
}
return lite::RET_OK;
}
// ReSize is used to update parameters when the current node's shape can change along with its inputs.
int ReSize() override {
if (custom_common::CheckOutputs(outputs_) == lite::RET_OK) {
return lite::RET_OK;
}
auto status =
registry::RegisterKernelInterface::GetKernelInterface({}, primitive_)->Infer(&inputs_, &outputs_, primitive_);
if (status != kSuccess) {
std::cerr << "infer failed." << std::endl;
return lite::RET_ERROR;
}
auto ret = CheckSpecs();
if (ret != lite::RET_OK) {
std::cerr << "ReSize failed for check kernel specs!";
return ret;
}
ret = Prepare();
if (ret != lite::RET_OK) {
std::cerr << "ReSize failed for kernel prepare!";
return ret;
}
return lite::RET_OK;
}
private:
std::string build_options_;
bool fp16_enable_;
cl::Kernel kernel_;
cl::Event event_;
cl::NDRange global_range_{cl::NullRange};
cl::NDRange local_range_{cl::NullRange};
std::vector<void *> weight_ptrs_;
registry::opencl::OpenCLRuntimeWrapper *opencl_runtime_;
int PreProcess() {
int ret;
ret = ReSize();
if (ret != lite::RET_OK) {
return ret;
}
for (auto i = 0; i < outputs_.size(); ++i) {
auto *output = &outputs_.at(i);
auto img_info = custom_common::GpuTensorInfo(output, opencl_runtime_);
auto allocator = output->allocator();
if (allocator == nullptr) {
std::cerr << "The output tensor of OpenCL kernel must have an allocator.";
return lite::RET_ERROR;
}
auto data_ptr = allocator->Malloc(img_info.width, img_info.height, output->DataType());
if (data_ptr == nullptr) {
std::cerr << "Malloc data failed";
return lite::RET_ERROR;
}
output->SetData(data_ptr);
}
return lite::RET_OK;
}
void FreeWeight() {
auto allocator = opencl_runtime_->GetAllocator();
if (allocator == nullptr) {
std::cerr << "GetAllocator fail.";
return;
}
for (auto &weight_ptr : weight_ptrs_) {
if (weight_ptr != nullptr) {
allocator->Free(weight_ptr);
weight_ptr = nullptr;
}
}
}
};
std::shared_ptr<kernel::Kernel> CustomAddCreator(const std::vector<MSTensor> &inputs,
const std::vector<MSTensor> &outputs,
const schema::Primitive *primitive, const mindspore::Context *ctx) {
const std::string build_options = " -DFLT4=float4 -DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef ";
bool fp16_enable = false;
std::cout << "using fp32 add.\n" << std::endl;
return std::make_shared<CustomAddKernel>(inputs, outputs, primitive, ctx, build_options, fp16_enable);
}
std::shared_ptr<kernel::Kernel> CustomAddFP16Creator(const std::vector<MSTensor> &inputs,
const std::vector<MSTensor> &outputs,
const schema::Primitive *primitive,
const mindspore::Context *ctx) {
const std::string build_options = " -DFLT4=half4 -DWRITE_IMAGE=write_imageh -DREAD_IMAGE=read_imageh";
bool fp16_enable = true;
std::cout << "using fp16 add." << std::endl;
return std::make_shared<CustomAddKernel>(inputs, outputs, primitive, ctx, build_options, fp16_enable);
}
} // namespace custom_gpu_demo
const auto kFloat32 = DataType::kNumberTypeFloat32;
const auto kFloat16 = DataType::kNumberTypeFloat16;
// Register the custom "Custom_Add" operator
REGISTER_CUSTOM_KERNEL(GPU, Tutorial, kFloat32, Custom_Add, custom_gpu_demo::CustomAddCreator)
REGISTER_CUSTOM_KERNEL(GPU, Tutorial, kFloat16, Custom_Add, custom_gpu_demo::CustomAddFP16Creator)
using schema::PrimitiveType_AddFusion;
// Register the add operator to replace the internal add operator of MindSpore Lite
REGISTER_KERNEL(GPU, Tutorial, kFloat32, PrimitiveType_AddFusion, custom_gpu_demo::CustomAddCreator)
REGISTER_KERNEL(GPU, Tutorial, kFloat16, PrimitiveType_AddFusion, custom_gpu_demo::CustomAddFP16Creator)
} // namespace mindspore

View File

@ -0,0 +1,76 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/custom_common.h"
namespace mindspore {
namespace custom_common {
int CheckInputs(const std::vector<mindspore::MSTensor> &inputs) {
for (auto &input : inputs) {
auto input_shape = input.Shape();
if (std::find(input_shape.begin(), input_shape.end(), -1) != input_shape.end()) {
return lite::RET_INFER_INVALID;
}
}
return lite::RET_OK;
}
int CheckOutputs(const std::vector<mindspore::MSTensor> &outputs) {
for (auto &output : outputs) {
auto output_shape = output.Shape();
if (std::find(output_shape.begin(), output_shape.end(), -1) != output_shape.end()) {
return lite::RET_INFER_INVALID;
}
}
return lite::RET_OK;
}
void PackNHWCToNHWC4(void *src, void *dst, bool src_is_fp16, bool dst_is_fp16, const GpuTensorInfo &tensor,
mindspore::DataType data_type) {
auto src_fp16 = reinterpret_cast<float16_t *>(src);
auto src_fp32 = reinterpret_cast<float32_t *>(src);
auto src_int32 = reinterpret_cast<int32_t *>(src);
auto dst_fp16 = reinterpret_cast<float16_t *>(dst);
auto dst_fp32 = reinterpret_cast<float32_t *>(dst);
auto dst_int32 = reinterpret_cast<int32_t *>(dst);
for (int n = 0, src_idx = 0; n < tensor.N; n++) {
for (int h = 0; h < tensor.H; ++h) {
for (int w = 0; w < tensor.W; ++w) {
for (int c = 0; c < tensor.C; ++c, ++src_idx) {
int dst_idx = ((n * tensor.H + h) * tensor.W + w) * tensor.Slice * C4NUM + c;
if (data_type == mindspore::DataType::kNumberTypeInt32) {
dst_int32[dst_idx] = src_int32[src_idx];
} else if (dst_is_fp16) {
dst_fp16[dst_idx] = src_is_fp16 ? src_fp16[src_idx] : static_cast<float16_t>(src_fp32[src_idx]);
} else {
dst_fp32[dst_idx] = src_is_fp16 ? static_cast<float32_t>(src_fp16[src_idx]) : src_fp32[src_idx];
}
}
}
}
}
// scalar
if (tensor.ElementsNum == 1) {
if (dst_is_fp16) {
dst_fp16[3] = dst_fp16[2] = dst_fp16[1] = dst_fp16[0];
} else {
dst_fp32[3] = dst_fp32[2] = dst_fp32[1] = dst_fp32[0];
}
}
}
} // namespace custom_common
} // namespace mindspore

View File

@ -0,0 +1,130 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_EXAMPLES_RUNTIME_GPU_EXTEND_SRC_CUSTOM_COMMON_H
#define MINDSPORE_LITE_EXAMPLES_RUNTIME_GPU_EXTEND_SRC_CUSTOM_COMMON_H
#include <arm_neon.h>
#include <vector>
#include <iostream>
#include "include/api/types.h"
#include "include/errorcode.h"
#include "include/ms_tensor.h"
#include "include/api/data_type.h"
#include "include/registry/opencl_runtime_wrapper.h"
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
#define C4NUM 4
namespace mindspore {
namespace custom_common {
template <typename SrcT, typename DstT>
void Broadcast2GpuShape(DstT *dst, const SrcT *src, int src_num) {
if (src == nullptr || src_num <= 0) {
return;
}
auto *N = dst;
auto *H = dst + 1;
auto *W = dst + 2;
auto *C = dst + 3;
if (src_num == 1) { // 1 1 1 C
*C = src[0];
} else if (src_num == 2) { // N 1 1 C
*N = src[0];
*C = src[1];
} else if (src_num == 3) { // N 1 W C
*N = src[0];
*W = src[1];
*C = src[2];
} else if (src_num == 4) { // N H W C
*N = src[0];
*H = src[1];
*W = src[2];
*C = src[3];
} else if (src_num > 4) {
std::cerr << "GPU doesn't support ndim>=" << src_num;
}
}
template <typename SrcT, typename DstT>
void Broadcast2GpuShape(DstT *dst, const SrcT *src, int src_num, DstT default_value) {
for (int i = 0; i < 4; ++i) {
dst[i] = default_value;
}
if (src == nullptr || src_num <= 0) {
return;
}
Broadcast2GpuShape(dst, src, src_num);
}
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
#define C4NUM 4
struct GpuTensorInfo {
GpuTensorInfo() = default;
explicit GpuTensorInfo(const MSTensor *tensor, registry::opencl::OpenCLRuntimeWrapper *opencl_run) {
if (tensor == nullptr) {
return;
}
auto shape_ori = tensor->Shape();
int64_t shape[4];
Broadcast2GpuShape(shape, shape_ori.data(), shape_ori.size(), 1l);
N = shape[0];
H = shape[1];
W = shape[2];
C = shape[3];
Slice = UP_DIV(C, C4NUM);
if (tensor->DataType() == mindspore::DataType::kNumberTypeFloat16) {
FLT_size = sizeof(cl_half);
} else {
FLT_size = sizeof(cl_float);
}
FLT4_size = FLT_size * C4NUM;
if (W * Slice <= opencl_run->GetMaxImage2DWidth()) {
height = N * H;
width = W * Slice;
} else {
height = N * H * W;
width = Slice;
if (height > opencl_run->GetMaxImage2DHeight()) {
height = -1;
width = -1;
}
}
ElementsNum = N * H * W * C;
Image2DSize = height * width * FLT4_size;
}
size_t N{1};
size_t H{1};
size_t W{1};
size_t C{1};
size_t Slice{};
size_t width{};
size_t height{};
size_t FLT_size{4};
size_t FLT4_size{16};
size_t ElementsNum{};
size_t Image2DSize{};
};
// verify that the inputs' shape is inferred successfully when inferring current node.
int CheckInputs(const std::vector<mindspore::MSTensor> &inputs);
// verify that the outputs' shape is inferred successfully when running current node.
int CheckOutputs(const std::vector<mindspore::MSTensor> &outputs);
void PackNHWCToNHWC4(void *src, void *dst, bool src_is_fp16, bool dst_is_fp16, const GpuTensorInfo &tensor,
mindspore::DataType data_type = mindspore::DataType::kNumberTypeFloat32);
} // namespace custom_common
} // namespace mindspore
#endif // MINDSPORE_LITE_EXAMPLES_RUNTIME_GPU_EXTEND_SRC_CUSTOM_COMMON_H
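As a worked example of GpuTensorInfo (numbers chosen purely for illustration): an fp32 tensor with NHWC shape {1, 2, 3, 5} gives Slice = UP_DIV(5, 4) = 2, so width = W * Slice = 6 and height = N * H = 2 (assuming the width fits within GetMaxImage2DWidth()), FLT4_size = 4 * sizeof(cl_float) = 16, and Image2DSize = height * width * FLT4_size = 2 * 6 * 16 = 192 bytes, which is the buffer size the custom kernel allocates when packing a constant input.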

View File

@ -123,6 +123,11 @@ class MS_API MSTensor {
virtual Vector<lite::LiteQuantParam> quant_params() const = 0;
virtual void set_quant_params(Vector<lite::LiteQuantParam>) = 0;
/// \brief Get whether the MSTensor data is constant data.
///
/// \return The boolean value that indicates whether the MSTensor data is constant.
virtual bool IsConst() const = 0;
};
} // namespace tensor
} // namespace mindspore

View File

@ -0,0 +1,119 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
#define MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
#include <vector>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <type_traits>
#include "CL/cl2.hpp"
#include "include/api/allocator.h"
#include "include/api/status.h"
namespace mindspore::registry::opencl {
class OpenCLRuntimeWrapper {
public:
OpenCLRuntimeWrapper() = default;
~OpenCLRuntimeWrapper() = default;
/// \brief Load the OpenCL source code and bind it to the program name.
///
/// \param[in] program_name Define the OpenCL program name.
/// \param[in] source Define the OpenCL source code.
///
/// \return Status indicating whether the source was loaded successfully.
Status LoadSource(const std::string &program_name, const std::string &source);
/// \brief Build the OpenCL code.
///
/// \param[out] kernel Used to return the compiled kernel.
/// \param[in] program_name Define the OpenCL program name.
/// \param[in] kernel_name Define the OpenCL kernel name.
/// \param[in] build_options_ext Define the OpenCL kernel build options.
///
/// \return Status indicating whether the kernel was built successfully.
Status BuildKernel(cl::Kernel *kernel, const std::string &program_name, const std::string &kernel_name,
const std::vector<std::string> &build_options_ext = {});
/// \brief Set a kernel argument from a pointer value.
///
/// \param[in] kernel Define the OpenCL kernel.
/// \param[in] index Define the OpenCL kernel argument index.
/// \param[in] value Define the OpenCL kernel argument value pointer.
///
/// \return Status indicating whether the kernel argument was set successfully.
Status SetKernelArg(const cl::Kernel &kernel, uint32_t index, void *const value);
/// \brief Set a kernel argument from a non-pointer value.
///
/// \param[in] kernel Define the OpenCL kernel.
/// \param[in] index Define the OpenCL kernel argument index.
/// \param[in] value Define the OpenCL kernel argument value.
///
/// \return Status indicating whether the kernel argument was set successfully.
template <typename T>
typename std::enable_if<!std::is_pointer<T>::value, Status>::type SetKernelArg(const cl::Kernel &kernel,
uint32_t index, const T value) {
if (const_cast<cl::Kernel &>(kernel).setArg(index, value) != CL_SUCCESS) {
return kLiteError;
} else {
return kSuccess;
}
}
/// \brief Run the OpenCL kernel.
///
/// \param[in] kernel Define the OpenCL kernel.
/// \param[in] global Define the number of work items.
/// \param[in] local Define the number of work items in a work group.
/// \param[in] command_queue Define the command queue.
/// \param[in] event Define the event of the kernel run.
///
/// \return Status indicating whether the kernel ran successfully.
Status RunKernel(const cl::Kernel &kernel, const cl::NDRange &global, const cl::NDRange &local,
cl::CommandQueue *command_queue = nullptr, cl::Event *event = nullptr);
/// \brief Synchronize the command queue.
///
/// \return Status indicating whether the command queue was synchronized successfully.
Status SyncCommandQueue();
void *MapBuffer(void *host_ptr, int flags, bool sync = true);
Status UnmapBuffer(void *host_ptr);
Status ReadImage(void *buffer, void *dst_data);
Status WriteImage(void *buffer, void *src_data);
std::shared_ptr<Allocator> GetAllocator();
uint64_t DeviceMaxWorkGroupSize();
uint64_t GetMaxImage2DWidth();
uint64_t GetMaxImage2DHeight();
uint64_t GetImagePitchAlignment();
};
} // namespace mindspore::registry::opencl
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
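To tie this public API together, here is a condensed, illustrative sketch of the load/build/set-args/run flow used by the custom kernel earlier in this commit; the function name, build options and ND ranges are assumptions for demonstration and are not part of the header.

#include <string>
#include "CL/cl2.hpp"
#include "include/api/status.h"
#include "include/registry/opencl_runtime_wrapper.h"

// Sketch: compile and launch the ElementAdd kernel through the wrapper.
mindspore::Status RunElementAdd(mindspore::registry::opencl::OpenCLRuntimeWrapper *runtime,
                                const std::string &source, void *in_a, void *in_b, void *out,
                                cl_int2 output_shape) {
  if (runtime->LoadSource("Arithmetic", source) != mindspore::kSuccess) {
    return mindspore::kLiteError;
  }
  cl::Kernel kernel;
  if (runtime->BuildKernel(&kernel, "Arithmetic", "ElementAdd",
                           {"-DFLT4=float4 -DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef"}) !=
      mindspore::kSuccess) {
    return mindspore::kLiteError;
  }
  if (runtime->SetKernelArg(kernel, 0, in_a) != mindspore::kSuccess ||
      runtime->SetKernelArg(kernel, 1, in_b) != mindspore::kSuccess ||
      runtime->SetKernelArg(kernel, 2, out) != mindspore::kSuccess ||
      runtime->SetKernelArg(kernel, 3, output_shape) != mindspore::kSuccess) {
    return mindspore::kLiteError;
  }
  cl::NDRange global(static_cast<size_t>(output_shape.s[0]), static_cast<size_t>(output_shape.s[1]));
  return runtime->RunKernel(kernel, global, cl::NullRange);
}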

View File

@ -78,6 +78,7 @@ class MTensor : public mindspore::tensor::MSTensor {
void set_data(void *data) override { data_ = data; }
Vector<LiteQuantParam> quant_params() const override { return this->quant_params_; }
void set_quant_params(const Vector<LiteQuantParam> quant_params) override { this->quant_params_ = quant_params; }
bool IsConst() const override { return this->data_ != nullptr; }
private:
String tensor_name_;

View File

@ -181,6 +181,13 @@ class MSTensor::Impl {
}
return lite_tensor_->MutableData();
}
virtual bool IsConst() const {
if (lite_tensor_ == nullptr) {
MS_LOG(ERROR) << "Invalid tensor.";
return false;
}
return lite_tensor_->IsConst();
}
virtual size_t DataSize() const {
if (lite_tensor_ == nullptr) {

View File

@ -259,6 +259,14 @@ void *MSTensor::MutableData() {
return impl_->MutableData();
}
bool MSTensor::IsConst() const {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Invalid tensor implement.";
return false;
}
return impl_->IsConst();
}
size_t MSTensor::DataSize() const {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Invalid tensor implement.";

View File

@ -215,7 +215,7 @@ bool InnerContext::IsGpuFloat16Enabled() const {
if (!IsGpuEnabled()) {
return false;
}
opencl::OpenCLRuntimeWrapper wrapper;
opencl::OpenCLRuntimeInnerWrapper wrapper;
if (!wrapper.GetInstance()->GetFp16Enable()) {
return false;
}

View File

@ -47,6 +47,7 @@ namespace mindspore::lite {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
namespace {
const char *const kArchCPU = "CPU";
const char *const kArchGPU = "GPU";
void KernelKeyToKernelDesc(const KernelKey &key, KernelDesc *desc) {
MS_ASSERT(desc != nullptr);
desc->data_type = static_cast<DataType>(key.data_type);
@ -159,6 +160,8 @@ int KernelRegistry::GetCustomKernel(const std::vector<Tensor *> &in_tensors, con
kernel::KernelKey tmp_key = key;
if (desc.arch == kArchCPU) {
tmp_key.arch = kernel::kCPU;
} else if (desc.arch == kArchGPU) {
tmp_key.arch = kernel::kGPU;
} else {
tmp_key.arch = kernel::kCustom;
}

View File

@ -133,7 +133,7 @@ class LiteKernel {
}
return mindspore::lite::RET_OK;
}
bool IsBuiltin() { return desc_.provider == kBuiltin; }
virtual int ReSize() {
MS_ASSERT(kernel_ != nullptr);
return kernel_->ReSize();

View File

@ -962,9 +962,9 @@ int LiteSession::InitGPURuntime() {
}
#if GPU_OPENCL
if (this->context_->IsGpuEnabled()) {
opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper();
opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeInnerWrapper();
if (opencl_runtime_wrapper_ == nullptr) {
MS_LOG(ERROR) << "create OpenCLRuntimeWrapper failed";
MS_LOG(ERROR) << "create OpenCLRuntimeInnerWrapper failed";
return RET_ERROR;
}
auto gpu_device_info = this->context_->GetGpuInfo();

View File

@ -155,7 +155,7 @@ class LiteSession : public session::LiteSession {
bool is_train_session_ = false;
friend class TransferSession;
#if GPU_OPENCL
opencl::OpenCLRuntimeWrapper *opencl_runtime_wrapper_{nullptr};
opencl::OpenCLRuntimeInnerWrapper *opencl_runtime_wrapper_{nullptr};
#endif
std::unique_ptr<SchedulerCb> sched_cb_;
std::shared_ptr<Delegate> delegate_ = nullptr;

View File

@ -50,6 +50,7 @@ class RegistryKernelImpl {
protected:
std::map<std::string, std::unordered_map<std::string, registry::CreateKernel *>> kernel_creators_;
// keys:provider, arch, type
std::map<std::string, std::map<std::string, std::unordered_map<std::string, registry::CreateKernel *>>>
custom_kernel_creators_;

View File

@ -94,8 +94,8 @@ void *OpenCLAllocator::CreateBuffer(size_t size, void *data, size_t flags, cl::B
return host_ptr;
}
void *OpenCLAllocator::CreateImage2D(size_t size, const ImageSize &img_size, void *data, size_t flags, bool is_map,
cl::Buffer **buffer, cl::Image2D **image) {
int OpenCLAllocator::CreateImage2D(size_t size, const ImageSize &img_size, void *data, size_t flags, bool is_map,
cl::Buffer **buffer, cl::Image2D **image, void **host_ptr) {
cl_int ret = CL_SUCCESS;
MS_ASSERT(buffer);
MS_ASSERT(image);
@ -114,7 +114,7 @@ void *OpenCLAllocator::CreateImage2D(size_t size, const ImageSize &img_size, voi
delete *buffer;
*buffer = nullptr;
MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << mindspore::kernel::CLErrorCode(ret) << ")";
return nullptr;
return RET_ERROR;
}
if (ret != CL_SUCCESS) {
delete *buffer;
@ -122,28 +122,28 @@ void *OpenCLAllocator::CreateImage2D(size_t size, const ImageSize &img_size, voi
*buffer = nullptr;
*image = nullptr;
MS_LOG(ERROR) << "Create OpenCL Image2D (ERROR CODE: " << mindspore::kernel::CLErrorCode(ret) << ")";
return nullptr;
return RET_ERROR;
}
MS_LOG(DEBUG) << "Malloc a new Image2D, width=" << img_size.width << ", height=" << img_size.height;
void *host_ptr = nullptr;
if (is_map) {
std::vector<size_t> region{img_size.width, img_size.height, 1};
host_ptr = ocl_runtime_->MapBuffer(**image, true, CL_MAP_READ | CL_MAP_WRITE, region);
if (host_ptr == nullptr) {
*host_ptr = ocl_runtime_->MapBuffer(**image, true, CL_MAP_READ | CL_MAP_WRITE, region);
if (*host_ptr == nullptr) {
delete *buffer;
delete *image;
*buffer = nullptr;
*image = nullptr;
MS_LOG(ERROR) << "Map image failed, can not found image :" << *image << ", host_ptr=" << host_ptr;
return nullptr;
MS_LOG(ERROR) << "Map image failed, can not found image :" << *image << ", host_ptr=" << *host_ptr;
return RET_ERROR;
}
cl::Memory *mem = *image;
ret = ocl_runtime_->UnmapBuffer(*mem, host_ptr);
ret = ocl_runtime_->UnmapBuffer(*mem, *host_ptr);
if (ret != CL_SUCCESS) {
MS_LOG(WARNING) << "UnmapBuffer failed.";
}
}
return host_ptr;
return RET_OK;
}
int OpenCLAllocator::GetImgDtypeSize(const ImageSize &img_size) {
@ -165,6 +165,34 @@ int OpenCLAllocator::GetImgDtypeSize(const ImageSize &img_size) {
return size;
}
void *OpenCLAllocator::Malloc(size_t weight, size_t height, DataType type) {
ImageSize img_size = {weight, height};
switch (type) {
case DataType::kNumberTypeFloat32:
img_size.dtype = CL_FLOAT;
break;
case DataType::kNumberTypeFloat16:
img_size.dtype = CL_HALF_FLOAT;
break;
case DataType::kNumberTypeInt8:
img_size.dtype = CL_SIGNED_INT8;
break;
case DataType::kNumberTypeUInt8:
img_size.dtype = CL_UNSIGNED_INT8;
break;
case DataType::kNumberTypeInt32:
img_size.dtype = CL_SIGNED_INT32;
break;
case DataType::kNumberTypeUInt32:
img_size.dtype = CL_UNSIGNED_INT32;
break;
default:
MS_LOG(ERROR) << "Unsupported type " << static_cast<TypeId>(type);
return nullptr;
}
return _Malloc(MemType::IMG, nullptr, 0, img_size);
}
void *OpenCLAllocator::_Malloc(MemType mem_type, void *data, size_t size, const ImageSize &img_size) {
auto svm_capabilities = ocl_runtime_->GetSVMCapabilities();
auto enable_arm_import_memory = ocl_runtime_->isExtensionEnable(EXT_ARM_IMPORT_MEMORY_HOST);
@ -208,9 +236,8 @@ void *OpenCLAllocator::_Malloc(MemType mem_type, void *data, size_t size, const
UNLOCK_AND_RETURN_NULL(host_ptr == nullptr, nullptr);
}
if (mem_type == MemType::IMG) {
void *host_ptr_im = CreateImage2D(size, img_size, data, flags, data != nullptr, &buffer, &image);
UNLOCK_AND_RETURN_NULL(data != nullptr && host_ptr_im == nullptr, nullptr);
host_ptr = (data != nullptr) ? host_ptr_im : host_ptr;
auto ret = CreateImage2D(size, img_size, data, flags, data != nullptr, &buffer, &image, &host_ptr);
UNLOCK_AND_RETURN_NULL(ret != RET_OK, nullptr);
}
}
}
@ -345,17 +372,25 @@ size_t OpenCLAllocator::total_size() {
return totalSize;
}
void *OpenCLAllocator::GetImage(void *buffer) {
cl::Image2D *OpenCLAllocator::GetImage(void *buffer) {
auto it = allocated_list_.find(buffer);
if (it != allocated_list_.end()) {
return it->second->image_ptr_;
if (it->second->mem_type_ != MemType::IMG) {
return nullptr;
}
return reinterpret_cast<cl::Image2D *>(it->second->image_ptr_);
}
return nullptr;
}
void *OpenCLAllocator::GetBuffer(void *buffer) {
void *OpenCLAllocator::GetOpenclMemPtr(void *buffer, MemType *type, bool force_buffer) {
auto it = allocated_list_.find(buffer);
if (it != allocated_list_.end()) {
if ((it->second->mem_type_ == MemType::IMG) && !force_buffer) {
*type = MemType::IMG;
return it->second->image_ptr_;
}
*type = MemType::BUF;
return it->second->device_ptr_;
}
return nullptr;

View File

@ -28,6 +28,8 @@
#include "CL/cl2.hpp"
namespace mindspore::lite::opencl {
// OpenCL memory type, SHARED only valid on Mali devices.
enum class MemType : char { BUF, IMG, SHARED };
#define UNLOCK_AND_RETURN_NULL(condition, ptr) \
do { \
if (condition) { \
@ -37,7 +39,6 @@ namespace mindspore::lite::opencl {
} while (0)
class OpenCLRuntime;
enum class MemType : char { BUF, IMG, SHARED };
struct ImageSize {
size_t width = 0;
@ -57,6 +58,7 @@ class OpenCLAllocator : public mindspore::Allocator {
// malloc shared
void *Malloc(size_t size) override { return _Malloc(MemType::SHARED, nullptr, size); }
void *Malloc(size_t weight, size_t height, DataType type) override;
// malloc buffer
void *Malloc(size_t size, void *data) { return _Malloc(MemType::BUF, data, size); }
// malloc image
@ -69,8 +71,8 @@ class OpenCLAllocator : public mindspore::Allocator {
size_t total_size();
void Clear();
void *GetImage(void *host_ptr);
void *GetBuffer(void *host_ptr);
cl::Image2D *GetImage(void *host_ptr);
void *GetOpenclMemPtr(void *buffer, MemType *type, bool force_buffer = false);
void *MapBuffer(void *host_ptr, int flags, void *command_queue = nullptr, bool sync = true);
int UnmapBuffer(void *host_ptr, void *command_queue = nullptr);
MemType GetMemType(void *host_ptr);
@ -88,8 +90,8 @@ class OpenCLAllocator : public mindspore::Allocator {
void *MinimumFit(MemType mem_type, size_t size, const ImageSize &img_size);
void *_Malloc(MemType mem_type, void *data, size_t size = 0, const ImageSize &img_size = ImageSize());
void *CreateBuffer(size_t size, void *data, size_t flags, cl::Buffer **buffer);
void *CreateImage2D(size_t size, const ImageSize &img_size, void *data, size_t flags, bool is_map,
cl::Buffer **buffer, cl::Image2D **image);
int CreateImage2D(size_t size, const ImageSize &img_size, void *data, size_t flags, bool is_map, cl::Buffer **buffer,
cl::Image2D **image, void **host_ptr);
int GetImgDtypeSize(const ImageSize &img_size);
template <typename T>
void ClearMemList(T *list);

View File

@ -23,6 +23,9 @@ namespace mindspore::lite::opencl {
int OpenCLExecutor::Run(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const std::vector<kernel::LiteKernel *> &kernels, const KernelCallBack &before,
const KernelCallBack &after) {
if (before != nullptr && after != nullptr) {
ocl_runtime_.GetInstance()->SetProfiling(true);
}
return RunOrTune(inputs, outputs, kernels, before, after, false);
}
@ -30,10 +33,7 @@ int OpenCLExecutor::RunOrTune(const std::vector<Tensor *> &inputs, const std::ve
const std::vector<kernel::LiteKernel *> &kernels, const KernelCallBack &before,
const KernelCallBack &after, bool is_tune) {
int ret{RET_OK};
auto opencl_runtime_ins = ocl_runtime.GetInstance();
if (before != nullptr && after != nullptr) {
opencl_runtime_ins->SetProfiling(true);
}
auto opencl_runtime_ins = ocl_runtime_.GetInstance();
auto profiling_tmp = opencl_runtime_ins->isProfiling();
if (is_tune) {
opencl_runtime_ins->SetProfiling(true);
@ -43,12 +43,10 @@ int OpenCLExecutor::RunOrTune(const std::vector<Tensor *> &inputs, const std::ve
GPUCallBackParam callbackParam;
callbackParam.node_name = kernel->name();
callbackParam.node_type = kernel->type_str();
if (before != nullptr) {
if (!before(TensorVectorCast(kernel->in_tensors()), TensorVectorCast(kernel->out_tensors()), callbackParam)) {
MS_LOG(ERROR) << "run kernel before_callback failed, name: " << kernel->name();
}
if ((before != nullptr) &&
!before(TensorVectorCast(kernel->in_tensors()), TensorVectorCast(kernel->out_tensors()), callbackParam)) {
MS_LOG(ERROR) << "run kernel before_callback failed, name: " << kernel->name();
}
auto *op_kernel = reinterpret_cast<kernel::OpenCLKernel *>(kernel->kernel());
// Don't support ZeroShape
for (auto tensor : kernel->out_tensors()) {
for (size_t i = 0; i < tensor->shape().size(); i++) {
@ -58,38 +56,58 @@ int OpenCLExecutor::RunOrTune(const std::vector<Tensor *> &inputs, const std::ve
}
}
}
if (is_tune) {
ret = op_kernel->PreProcess();
if (RET_OK != ret) {
MS_LOG(WARNING) << "PreProcess kernel failed, name: " << kernel->name() << " in tuning";
opencl_runtime_ins->SetProfiling(profiling_tmp);
return RET_OK;
}
ret = op_kernel->Tune();
if (ret != RET_OK) {
MS_LOG(ERROR) << "tuning kernel failed, name: " << kernel->name();
return ret;
if (kernel->IsBuiltin()) {
auto *op_kernel = reinterpret_cast<kernel::OpenCLKernel *>(kernel->kernel());
if (is_tune) {
ret = Tune(op_kernel);
if (ret != RET_OK) {
opencl_runtime_ins->SetProfiling(profiling_tmp);
return RET_OK;
}
} else {
ret = kernel->Execute();
if (ret != RET_OK) {
MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
return ret;
}
if (profiling_tmp) {
auto execute_time = op_kernel->GetProfilingTimeMs();
MS_LOG(INFO) << "OpenCl kernel " << kernel->name() << "(" << kernel->type_str()
<< ") execute time is: " << op_kernel->GetProfilingTimeMs() << "ms";
callbackParam.execute_time = execute_time;
}
}
} else {
ret = kernel->Execute();
if (ret != RET_OK) {
MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
return ret;
}
if (profiling_tmp) {
auto execute_time = op_kernel->GetProfilingTimeMs();
MS_LOG(INFO) << "OpenCl kernel " << kernel->name() << "(" << kernel->type_str()
<< ") execute time is: " << op_kernel->GetProfilingTimeMs() << "ms";
callbackParam.execute_time = execute_time;
if (!is_tune) {
ret = kernel->Execute();
if (ret != RET_OK) {
MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
return ret;
}
}
}
if (after != nullptr) {
if (!after(TensorVectorCast(kernel->in_tensors()), TensorVectorCast(kernel->out_tensors()), callbackParam)) {
MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->name();
}
if ((after != nullptr) &&
!after(TensorVectorCast(kernel->in_tensors()), TensorVectorCast(kernel->out_tensors()), callbackParam)) {
MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->name();
}
}
opencl_runtime_ins->SetProfiling(profiling_tmp);
return ret;
}
int OpenCLExecutor::Tune(kernel::OpenCLKernel *op_kernel) {
auto ret = op_kernel->PreProcess();
if (ret != RET_OK) {
MS_LOG(WARNING) << "PreProcess kernel failed, name: " << op_kernel->name() << " in tuning";
return ret;
}
ret = op_kernel->Tune();
if (ret != RET_OK) {
MS_LOG(ERROR) << "tuning kernel failed, name: " << op_kernel->name();
return ret;
}
return RET_OK;
}
} // namespace mindspore::lite::opencl

View File

@ -27,7 +27,7 @@
namespace mindspore::lite::opencl {
class OpenCLExecutor : public Executor {
public:
OpenCLExecutor() : Executor() { allocator_ = ocl_runtime.GetInstance()->GetAllocator().get(); }
OpenCLExecutor() : Executor() { allocator_ = ocl_runtime_.GetInstance()->GetAllocator().get(); }
~OpenCLExecutor() override = default;
@ -43,10 +43,10 @@ class OpenCLExecutor : public Executor {
const std::vector<kernel::LiteKernel *> &kernels, const KernelCallBack &before = nullptr,
const KernelCallBack &after = nullptr, bool is_tune = false);
protected:
InnerContext *context = nullptr;
private:
int Tune(kernel::OpenCLKernel *op_kernel);
OpenCLAllocator *allocator_ = nullptr;
OpenCLRuntimeWrapper ocl_runtime;
OpenCLRuntimeInnerWrapper ocl_runtime_;
};
} // namespace mindspore::lite::opencl
#endif

View File

@ -204,6 +204,9 @@ int OpenCLRuntime::InitQueue(std::vector<cl::Platform> *platforms) {
0};
context_ =
new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, ctx_properties.data(), nullptr, nullptr, &ret);
if (context_ == nullptr || ret != CL_SUCCESS) {
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
}
#else
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
#endif
@ -334,7 +337,7 @@ cl::Device *OpenCLRuntime::Device() { return device_; }
uint64_t OpenCLRuntime::DeviceGlobalMemoryCacheSize() const { return global_memery_cachesize_; }
int OpenCLRuntime::DeviceMaxWorkGroupSize() const { return max_work_group_size_; }
uint64_t OpenCLRuntime::DeviceMaxWorkGroupSize() const { return max_work_group_size_; }
uint32_t OpenCLRuntime::DeviceComputeUnits() const { return compute_units_; }
@ -382,18 +385,24 @@ bool OpenCLRuntime::SetFp16Enable(bool enable) {
}
int OpenCLRuntime::BuildKernel(const cl::Kernel &kernel, const std::string &program_name,
const std::string &kernel_name, const std::vector<std::string> &build_options_ext) {
std::string build_option = default_build_option_;
if (fp16_enable_) {
build_option +=
" -DFP16_ENABLE=1 -DFLT=half -DFLT4=half4 -DFLT16=half16 -DAS_FLT4=as_half4 -DAS_UINT4=as_ushort4 -DUINT4=ushort4"
" -DTO_FLT=convert_half -DTO_FLT4=convert_half4";
} else {
build_option +=
" -DFP16_ENABLE=0 -DFLT=float -DFLT4=float4 -DFLT16=float16 -DAS_FLT4=as_float4 -DAS_UINT4=as_uint4 -DUINT4=uint4"
" -DTO_FLT=convert_float -DTO_FLT4=convert_float4";
const std::string &kernel_name, const std::vector<std::string> &build_options_ext,
const bool is_builtin) {
std::string build_option;
if (is_builtin) {
build_option = default_build_option_;
if (fp16_enable_) {
build_option +=
" -DFP16_ENABLE=1 -DFLT=half -DFLT4=half4 -DFLT16=half16 -DAS_FLT4=as_half4 -DAS_UINT4=as_ushort4 "
"-DUINT4=ushort4"
" -DTO_FLT=convert_half -DTO_FLT4=convert_half4";
} else {
build_option +=
" -DFP16_ENABLE=0 -DFLT=float -DFLT4=float4 -DFLT16=float16 -DAS_FLT4=as_float4 -DAS_UINT4=as_uint4 "
"-DUINT4=uint4"
" -DTO_FLT=convert_float -DTO_FLT4=convert_float4";
}
build_option += " -DMAX_IMAGE2D_WIDTH=" + std::to_string(max_image2d_width_);
}
build_option += " -DMAX_IMAGE2D_WIDTH=" + std::to_string(max_image2d_width_);
build_option =
std::accumulate(build_options_ext.begin(), build_options_ext.end(), build_option,
[](const std::string &options, const std::string &option) { return options + " " + option; });
@ -515,7 +524,7 @@ bool OpenCLRuntime::BuildProgram(const std::string &build_option, const cl::Prog
int OpenCLRuntime::ReadOrWriteImage(void *buffer, void *data, bool is_read) {
cl::CommandQueue *command_queue = profiling_ ? profiling_command_queue_ : default_command_queue_;
auto *image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(buffer));
auto *image = allocator_->GetImage(buffer);
if (image == nullptr) {
MS_LOG(WARNING) << "Can't get Image2D for " << buffer;
return RET_ERROR;

View File

@ -38,11 +38,12 @@ enum InitState { UnInit = 0, InitSuccess = 1, InitFailed = 2 };
struct GpuInfo {
GpuType type = OTHER;
};
class OpenCLRuntimeInnerWrapper;
class OpenCLRuntimeWrapper;
class OpenCLRuntime {
public:
friend OpenCLRuntimeInnerWrapper;
friend OpenCLRuntimeWrapper;
~OpenCLRuntime();
OpenCLRuntime(const OpenCLRuntime &) = delete;
OpenCLRuntime &operator=(const OpenCLRuntime &) = delete;
@ -55,7 +56,7 @@ class OpenCLRuntime {
std::shared_ptr<OpenCLAllocator> GetAllocator() { return allocator_; }
cl::CommandQueue *GetDefaultCommandQueue() { return profiling_ ? profiling_command_queue_ : default_command_queue_; }
uint64_t DeviceGlobalMemoryCacheSize() const;
int DeviceMaxWorkGroupSize() const;
uint64_t DeviceMaxWorkGroupSize() const;
uint32_t DeviceComputeUnits() const;
uint32_t DeviceMaxFreq() const;
uint64_t GetMaxWorkGroupSize(const cl::Kernel &kernel);
@ -76,50 +77,35 @@ class OpenCLRuntime {
template <typename T>
typename std::enable_if<std::is_pointer<T>::value, cl_int>::type SetKernelArg(const cl::Kernel &kernel,
uint32_t index, const T value,
const MemType mem_type = MemType::IMG) {
bool force_buffer = false) {
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr.";
return CL_INVALID_VALUE;
}
switch (mem_type) {
case MemType::BUF: {
auto svm_capabilities = GetSVMCapabilities();
if (svm_capabilities) {
MS_LOG(DEBUG) << "Set kernel arg[" << index << "] SVM pointer " << value;
return clSetKernelArgSVMPointer(kernel.get(), index, value);
}
cl::Buffer *buffer = reinterpret_cast<cl::Buffer *>(allocator_->GetBuffer(value));
if (buffer == nullptr) {
MS_LOG(ERROR) << "buffer is nullptr.";
return CL_INVALID_VALUE;
}
MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Buffer " << buffer << ", host_ptr: " << value;
return const_cast<cl::Kernel &>(kernel).setArg(index, *buffer);
}
case MemType::IMG: {
cl::Image2D *image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(value));
if (image == nullptr) {
MS_LOG(WARNING) << "Can't get Image2D, try to use Buffer. Please confirm the buffer type.";
cl::Buffer *buffer = reinterpret_cast<cl::Buffer *>(allocator_->GetBuffer(value));
if (buffer == nullptr) {
MS_LOG(ERROR) << "buffer is nullptr.";
return CL_INVALID_VALUE;
}
MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Buffer " << buffer << ", host_ptr: " << value;
return const_cast<cl::Kernel &>(kernel).setArg(index, *buffer);
}
MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Image2D " << image << ", host_ptr: " << value;
return const_cast<cl::Kernel &>(kernel).setArg(index, *image);
}
default:
MS_LOG(ERROR) << "Unsupported opencl memory type: " << static_cast<int>(mem_type);
return CL_IMAGE_FORMAT_NOT_SUPPORTED;
auto svm_capabilities = GetSVMCapabilities();
if (svm_capabilities) {
MS_LOG(DEBUG) << "Set kernel arg[" << index << "] SVM pointer " << value;
return clSetKernelArgSVMPointer(kernel.get(), index, value);
}
lite::opencl::MemType mem_type;
void *buffer = allocator_->GetOpenclMemPtr(value, &mem_type, force_buffer);
if (buffer == nullptr) {
MS_LOG(ERROR) << "buffer is nullptr.";
return CL_INVALID_VALUE;
}
MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL "
<< (mem_type == lite::opencl::MemType::IMG ? "Image " : "Buffer ") << buffer
<< ", host_ptr: " << value;
if (mem_type == lite::opencl::MemType::IMG) {
return const_cast<cl::Kernel &>(kernel).setArg(index, *reinterpret_cast<cl::Image2D *>(buffer));
} else {
return const_cast<cl::Kernel &>(kernel).setArg(index, *reinterpret_cast<cl::Buffer *>(buffer));
}
}
template <typename T>
typename std::enable_if<!std::is_pointer<T>::value, cl_int>::type SetKernelArg(
const cl::Kernel &kernel, uint32_t index, const T value, const MemType mem_type = MemType::IMG) {
typename std::enable_if<!std::is_pointer<T>::value, cl_int>::type SetKernelArg(const cl::Kernel &kernel,
uint32_t index, const T value) {
return const_cast<cl::Kernel &>(kernel).setArg(index, value);
}
@@ -129,7 +115,7 @@ class OpenCLRuntime {
std::vector<unsigned char> GetProgramBinary(const cl::Program &program);
bool LoadSource(const std::string &program_name, const std::string &source);
int BuildKernel(const cl::Kernel &kernel, const std::string &program_name, const std::string &kernel_name,
const std::vector<std::string> &build_options_ext = {});
const std::vector<std::string> &build_options_ext = {}, const bool is_builtin = true);
int RunKernel(const cl::Kernel &kernel, const cl::NDRange &global, const cl::NDRange &local,
cl::CommandQueue *command_queue = nullptr, cl::Event *event = nullptr);
int ReadOrWriteImage(void *buffer, void *data, bool is_read);
@@ -192,7 +178,7 @@ class OpenCLRuntime {
uint64_t max_alloc_size_{0};
uint64_t max_image2d_width_{0};
uint64_t max_image2d_height_{0};
int max_work_group_size_{1};
uint64_t max_work_group_size_{1};
uint32_t compute_units_{0};
uint32_t max_freq_{0};
std::string default_build_option_{"-cl-mad-enable -cl-fast-relaxed-math -Werror"};
@@ -226,12 +212,12 @@ class OpenCLRuntime {
const std::string cache_version_{"V0.1"};
};
class OpenCLRuntimeWrapper {
class OpenCLRuntimeInnerWrapper {
public:
OpenCLRuntimeWrapper() { ocl_runtime_ = OpenCLRuntime::GetInstance(); }
~OpenCLRuntimeWrapper() { OpenCLRuntime::DeleteInstance(); }
OpenCLRuntimeWrapper(const OpenCLRuntimeWrapper &) = delete;
OpenCLRuntimeWrapper &operator=(const OpenCLRuntimeWrapper &) = delete;
OpenCLRuntimeInnerWrapper() { ocl_runtime_ = OpenCLRuntime::GetInstance(); }
~OpenCLRuntimeInnerWrapper() { OpenCLRuntime::DeleteInstance(); }
OpenCLRuntimeInnerWrapper(const OpenCLRuntimeInnerWrapper &) = delete;
OpenCLRuntimeInnerWrapper &operator=(const OpenCLRuntimeInnerWrapper &) = delete;
OpenCLRuntime *GetInstance() { return ocl_runtime_; }
private:
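
The pointer overload of SetKernelArg above now resolves image-versus-buffer binding through the allocator, and the trailing bool only forces buffer binding. A minimal caller sketch under that reading; kernel_, ocl_runtime_, img_arg_ and buf_arg_ are illustrative names, not identifiers from this patch:
// Sketch: bind one image-backed and one buffer-backed pointer with the new overload.
int arg_cnt = 0;
if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, img_arg_) != CL_SUCCESS) {  // memory type resolved by allocator
  MS_LOG(ERROR) << "SetKernelArg failed.";
  return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, buf_arg_, true) != CL_SUCCESS) {  // force buffer binding
  MS_LOG(ERROR) << "SetKernelArg failed.";
  return RET_ERROR;
}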

View File

@@ -0,0 +1,155 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/registry/opencl_runtime_wrapper.h"
#include <dlfcn.h>
#ifdef SHARING_MEM_WITH_OPENGL
#include <EGL/egl.h>
#endif
#include <vector>
#include <numeric>
#include <utility>
#include "include/errorcode.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "src/runtime/gpu/opencl/opencl_allocator.h"
#include "src/common/file_utils.h"
#include "src/runtime/gpu/opencl/opencl_runtime.h"
using mindspore::kernel::CLErrorCode;
namespace mindspore::registry::opencl {
Status OpenCLRuntimeWrapper::LoadSource(const std::string &program_name, const std::string &source) {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
const std::string program_name_ext = "provider_" + program_name;
if (ocl_runtime->LoadSource(program_name_ext, source)) {
return kSuccess;
} else {
return kLiteError;
}
}
Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::string &program_name,
const std::string &kernel_name,
const std::vector<std::string> &build_options_ext) {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
const std::string program_name_ext = "provider_" + program_name;
if (ocl_runtime->BuildKernel(*kernel, program_name_ext, kernel_name, build_options_ext, false) == RET_OK) {
return kSuccess;
} else {
return kLiteError;
}
}
Status OpenCLRuntimeWrapper::SetKernelArg(const cl::Kernel &kernel, uint32_t index, void *const value) {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
if (ocl_runtime->SetKernelArg(kernel, index, value) != CL_SUCCESS) {
return kLiteError;
} else {
return kSuccess;
}
}
Status OpenCLRuntimeWrapper::RunKernel(const cl::Kernel &kernel, const cl::NDRange &global, const cl::NDRange &local,
cl::CommandQueue *command_queue, cl::Event *event) {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
if (ocl_runtime->RunKernel(kernel, global, local, command_queue, event) == RET_OK) {
return kSuccess;
} else {
return kLiteError;
}
}
Status OpenCLRuntimeWrapper::SyncCommandQueue() {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
if (ocl_runtime->SyncCommandQueue()) {
return kSuccess;
} else {
return kLiteError;
}
}
void *OpenCLRuntimeWrapper::MapBuffer(void *host_ptr, int flags, bool sync) {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
return ocl_runtime->GetAllocator()->MapBuffer(host_ptr, flags, nullptr, sync);
}
Status OpenCLRuntimeWrapper::UnmapBuffer(void *host_ptr) {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
if (ocl_runtime->GetAllocator()->UnmapBuffer(host_ptr, nullptr) == RET_OK) {
return kSuccess;
} else {
return kLiteError;
}
}
Status OpenCLRuntimeWrapper::ReadImage(void *buffer, void *dst_data) {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
if (ocl_runtime->ReadImage(buffer, dst_data) == RET_OK) {
return kSuccess;
} else {
return kLiteError;
}
}
Status OpenCLRuntimeWrapper::WriteImage(void *buffer, void *src_data) {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
if (ocl_runtime->WriteImage(buffer, src_data) == RET_OK) {
return kSuccess;
} else {
return kLiteError;
}
}
std::shared_ptr<Allocator> OpenCLRuntimeWrapper::GetAllocator() {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
return ocl_runtime->GetAllocator();
}
uint64_t OpenCLRuntimeWrapper::DeviceMaxWorkGroupSize() {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
return ocl_runtime->DeviceMaxWorkGroupSize();
}
uint64_t OpenCLRuntimeWrapper::GetMaxImage2DWidth() {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
return ocl_runtime->GetMaxImage2DWidth();
}
uint64_t OpenCLRuntimeWrapper::GetMaxImage2DHeight() {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
return ocl_runtime->GetMaxImage2DHeight();
}
uint64_t OpenCLRuntimeWrapper::GetImagePitchAlignment() {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
return ocl_runtime->GetImagePitchAlignment();
}
} // namespace mindspore::registry::opencl
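
The file above implements the public registry wrapper that third-party GPU kernels call instead of the inner runtime; each method wraps the corresponding OpenCLRuntime call in a Status, and program names are namespaced with a "provider_" prefix. A minimal usage sketch, assuming a custom kernel holds opencl_runtime_ as a registry::opencl::OpenCLRuntimeWrapper pointer and that source, input_ptr, output_ptr, width and height are supplied by the caller (illustrative only, not code from this patch):
// Sketch: compile and launch a provider kernel through the registry wrapper.
cl::Kernel kernel;
if (opencl_runtime_->LoadSource("Arithmetic", source) != kSuccess) {  // stored as "provider_Arithmetic"
  return lite::RET_ERROR;
}
if (opencl_runtime_->BuildKernel(&kernel, "Arithmetic", "ElementAdd", {"-cl-mad-enable"}) != kSuccess) {
  return lite::RET_ERROR;
}
if (opencl_runtime_->SetKernelArg(kernel, 0, input_ptr) != kSuccess ||
    opencl_runtime_->SetKernelArg(kernel, 1, output_ptr) != kSuccess) {
  return lite::RET_ERROR;
}
if (opencl_runtime_->RunKernel(kernel, cl::NDRange(width, height), cl::NullRange, nullptr, nullptr) != kSuccess) {
  return lite::RET_ERROR;
}
if (opencl_runtime_->SyncCommandQueue() != kSuccess) {
  return lite::RET_ERROR;
}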

View File

@@ -68,11 +68,11 @@ int ArgMinMaxOpenCLKernel::SetConstArgs() {
static_cast<int>(im_in_.C)};
cl_int4 flags = {param->out_value_, param->get_max_, param->axis_, param->topk_};
int arg_cnt = 2;
if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, buff_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, buff_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, ids_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, ids_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
@@ -228,11 +228,11 @@ int ArgMinMaxOpenCLKernel::Prepare() {
int ArgMinMaxOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
if (ocl_runtime_->SetKernelArg(kernel_, 0, in_tensors_[0]->data_c(), lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, 0, in_tensors_[0]->data_c(), true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(kernel_, 1, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, 1, out_tensors_[0]->data_c(), true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -266,19 +266,19 @@ int BatchNormOpenCLKernel::Run() {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
} // input tensor
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, scale_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, scale_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
} // scale
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, offset_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, offset_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
} // offset
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, mean_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, mean_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
} // mean
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, variance_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, variance_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
} // variance

View File

@@ -36,7 +36,7 @@ int ConcatOpenCLKernel::RunAxis0() {
auto dst_data = out_tensors_[0]->data_c();
MS_ASSERT(dst_data);
auto dst_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(dst_data));
auto *out_image = allocator_->GetImage(dst_data);
for (int i = 0; i < in_tensors_.size(); i++) {
auto src_data = weight_ptrs_.at(i) == nullptr ? in_tensors_[i]->data_c() : weight_ptrs_.at(i);
if (allocator_->GetImageSize(src_data, &img_size) != RET_OK) {
@@ -45,7 +45,7 @@ int ConcatOpenCLKernel::RunAxis0() {
}
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
auto *input_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
auto *input_image = allocator_->GetImage(src_data);
if (ocl_runtime_->GetDefaultCommandQueue()->enqueueCopyImage(*input_image, *out_image, src_origin, dst_origin,
region) != CL_SUCCESS) {
MS_LOG(WARNING) << "enqueueCopyImage failed.";
@@ -290,8 +290,7 @@ int ConcatOpenCLKernel::Run() {
}
}
if (axis_ == 3 && !Align_) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF) !=
CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -435,11 +435,12 @@ int Conv2DOpenCLKernel::SetConstArgs() {
cl_int2 dilation = {param_->dilation_h_, param_->dilation_w_};
int arg_cn = 2;
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, packed_filter_, filter_type_) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, packed_filter_, (filter_type_ == lite::opencl::MemType::BUF)) !=
CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, packed_bias_, MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, packed_bias_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -119,7 +119,7 @@ int Conv2dTransposeOpenCLKernel::SetConstArgs() {
cl_int2 padding = {pad_h, pad_w};
cl_int4 src_size = {h, w, UP_DIV(ci, C4NUM), n};
cl_int4 dst_size = {oh, ow, UP_DIV(co, C4NUM), n};
if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, padWeight_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, padWeight_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -231,11 +231,12 @@ int DepthwiseConv2dOpenCLKernel::SetConstArgs() {
cl_int4 dst_size = {(cl_int)out_info.W, (cl_int)out_info.H, (cl_int)CO4, (cl_int)out_info.N};
int arg_cnt = 2;
if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, packed_weight_, filter_type_) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, packed_weight_, (filter_type_ == lite::opencl::MemType::BUF)) !=
CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, bias_data_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, bias_data_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -43,7 +43,7 @@ int FillOpenCLKernel::RunFill() {
}
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
cl::Image2D *out_image = allocator_->GetImage(src_data);
if (ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region) !=
CL_SUCCESS) {
MS_LOG(ERROR) << "enqueueFillImage failed.";
@@ -66,7 +66,7 @@ int FillOpenCLKernel::RunShape() {
}
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{1, 1, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
cl::Image2D *out_image = allocator_->GetImage(src_data);
if (ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region) !=
CL_SUCCESS) {
MS_LOG(ERROR) << "enqueueFillImage failed.";

View File

@@ -260,7 +260,7 @@ void FullConnectionOpenCLKernel::SetGlobalLocal() {
int FullConnectionOpenCLKernel::SetConstArgs() {
if (!weight_var_) {
if (ocl_runtime_->SetKernelArg(kernel_, 2, padWeight_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, 2, padWeight_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -288,8 +288,7 @@ int FusionEltwiseOpenCLKernel::SetConstArgs() {
}
}
} else {
if (ocl_runtime_->SetKernelArg(kernel_, arg_idx, buffer_weights_[buffer_idx++], lite::opencl::MemType::BUF) !=
CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_idx, buffer_weights_[buffer_idx++], true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -258,7 +258,7 @@ int GatherOpenCLKernel::Run() {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(kernel_, 2, indices_data_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, 2, indices_data_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -254,11 +254,11 @@ int LayerNormOpenCLKernel::Run() {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
} // input tensor
if (ocl_runtime_->SetKernelArg(kernel_mean_var_, arg1_cn++, mean_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_mean_var_, arg1_cn++, mean_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(kernel_mean_var_, arg1_cn++, var_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_mean_var_, arg1_cn++, var_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
@@ -273,19 +273,19 @@ int LayerNormOpenCLKernel::Run() {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
} // out tensor
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, mean_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, mean_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
} // mean_
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, var_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, var_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
} // var_
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, gamma_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, gamma_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
} // gamma_
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, beta_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, beta_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
} // beta_

View File

@@ -268,7 +268,7 @@ int MatMulOpenCLKernel::SetConstArgs() {
if (act_weight_) {
arg_count++;
} else {
if (ocl_runtime_->SetKernelArg(kernel_, arg_count++, padWeight_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_count++, padWeight_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -184,7 +184,7 @@ int PReluOpenCLKernel::Run() {
return RET_ERROR;
}
} else {
if (ocl_runtime_->SetKernelArg(kernel_, arg_idx++, weight_vector_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_idx++, weight_vector_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -44,7 +44,7 @@ int SparseToDenseOpenCLKernel::InitOutputToDefault() {
}
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
cl::Image2D *out_image = allocator_->GetImage(src_data);
if (ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region) !=
CL_SUCCESS) {
MS_LOG(ERROR) << "enqueueFillImage failed.";
@@ -267,13 +267,12 @@ int SparseToDenseOpenCLKernel::Run() {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF) !=
CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (!weight_scalar_) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, weight_vector_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, weight_vector_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -33,7 +33,7 @@ int SplitOpenCLKernel::RunAxis0() {
auto allocator_ = ocl_runtime_->GetAllocator();
auto src_data = in_tensors_[0]->data_c();
CHECK_NULL_RETURN(src_data);
cl::Image2D *in_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
cl::Image2D *in_image = allocator_->GetImage(src_data);
if (in_image == nullptr) {
MS_LOG(ERROR) << "RunAxis0 in_image can not be nullptr";
return RET_ERROR;
@@ -49,7 +49,7 @@ int SplitOpenCLKernel::RunAxis0() {
}
auto dst_area = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(dst_data));
cl::Image2D *out_image = allocator_->GetImage(dst_data);
if (out_image == nullptr) {
MS_LOG(ERROR) << "RunAxis0 out_image can not be nullptr";
return RET_ERROR;
@@ -252,8 +252,7 @@ int SplitOpenCLKernel::Run() {
return RET_ERROR;
}
} else {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c(), lite::opencl::MemType::BUF) !=
CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c(), true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
@@ -264,7 +263,7 @@ int SplitOpenCLKernel::Run() {
return RET_ERROR;
}
}
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, split_sizes_, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, split_sizes_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -34,7 +34,7 @@ int StackOpenCLKernel::RunAxis0() {
auto dst_data = out_tensors_[0]->data_c();
MS_ASSERT(dst_data);
auto dst_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(dst_data));
cl::Image2D *out_image = allocator_->GetImage(dst_data);
for (int i = 0; i < in_tensors_.size(); i++) {
auto src_data = in_tensors_[i]->data_c();
MS_ASSERT(src_data);
@@ -44,7 +44,7 @@ int StackOpenCLKernel::RunAxis0() {
}
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
cl::Image2D *input_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
cl::Image2D *input_image = allocator_->GetImage(src_data);
if (ocl_runtime_->GetDefaultCommandQueue()->enqueueCopyImage(*input_image, *out_image, src_origin, dst_origin,
region) != CL_SUCCESS) {
MS_LOG(WARNING) << "enqueueCopyImage failed.";
@@ -209,14 +209,12 @@ int StackOpenCLKernel::Run() {
int arg_cn = 0;
if (buffer_button_) {
for (int i = 0; i < in_tensors_.size(); ++i) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[i]->data_c(), lite::opencl::MemType::BUF) !=
CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[i]->data_c(), true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
}
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF) !=
CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -249,11 +249,11 @@ int StrassenOpenCLKernel::StrassenDataFilled(cl::Kernel *kernel, void *input, vo
return RET_ERROR;
}
} else {
if (ocl_runtime_->SetKernelArg(*kernel, 0, input, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(*kernel, 0, input, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(*kernel, 1, output, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(*kernel, 1, output, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
@@ -277,20 +277,20 @@ int StrassenOpenCLKernel::StrassenAddSub(cl::Kernel *kernel, void *input, void *
return RET_ERROR;
}
if (mem_type == lite::opencl::MemType::IMG) {
if (ocl_runtime_->SetKernelArg(*kernel, 0, input, lite::opencl::MemType::IMG) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(*kernel, 0, input) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(*kernel, 1, output, lite::opencl::MemType::IMG) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(*kernel, 1, output) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
} else {
if (ocl_runtime_->SetKernelArg(*kernel, 0, input, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(*kernel, 0, input, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(*kernel, 1, output, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(*kernel, 1, output, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
@@ -371,7 +371,7 @@ int StrassenOpenCLKernel::StrassenRunMmatmul(void *input, void *weight, void *ou
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(kernel_, 2, weight, lite::opencl::MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, 2, weight, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -108,11 +108,13 @@ int ToFormatOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running!";
auto src_mem_type = (out_mem_type_ == MemType::IMG) ? lite::opencl::MemType::BUF : lite::opencl::MemType::IMG;
auto dst_mem_type = out_mem_type_;
if (ocl_runtime_->SetKernelArg(kernel_, 0, in_tensors_.front()->data_c(), src_mem_type) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, 0, in_tensors_.front()->data_c(),
(src_mem_type == lite::opencl::MemType::BUF)) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(kernel_, 1, out_tensors_.front()->data_c(), dst_mem_type) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, 1, out_tensors_.front()->data_c(),
(dst_mem_type == lite::opencl::MemType::BUF)) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -240,7 +240,8 @@ int WinogradOpenCLKernel::SetConstArgs() {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, packed_filter_, filter_type_) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, packed_filter_, (filter_type_ == lite::opencl::MemType::BUF)) !=
CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
@@ -263,7 +264,7 @@ int WinogradOpenCLKernel::SetConstArgs() {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(kernel_36to4x4_, arg_cn++, packed_bias_, MemType::BUF) != CL_SUCCESS) {
if (ocl_runtime_->SetKernelArg(kernel_36to4x4_, arg_cn++, packed_bias_, true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}

View File

@@ -59,7 +59,7 @@ inline bool PredIs(const LiteKernel *node, PrimitiveType type, std::vector<LiteK
if (node->in_kernels().size() == 1) {
LiteKernel *pred = node->in_kernels().front();
MS_ASSERT(pred);
if (AIsInB(pred, nodes) && pred->type() == type && pred->out_kernels().size() == 1) {
if (AIsInB(pred, nodes) && pred->type() == type && pred->out_kernels().size() == 1 && pred->IsBuiltin()) {
MS_ASSERT(pred->out_kernels().front() == node);
return true;
}
@@ -578,7 +578,7 @@ void CreateEltwiseKernelReplaceOld(FusionEltwiseParameter *param, LiteKernel *ol
// Eltwise + Eltwise
int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
if (!node->InferShapeDone()) {
if (!node->InferShapeDone() || !node->IsBuiltin()) {
return RET_ERROR;
}
MS_ASSERT(node);
@@ -598,6 +598,9 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set
if (!pred->InferShapeDone()) {
continue;
}
if (!pred->IsBuiltin()) {
return RET_ERROR;
}
if (AIsInB(pred, nodes) && IsEltwiseAndOperatorSupported(pred) && pred->out_kernels().size() == 1) {
auto *tensor = pred->out_tensors().front();
MS_ASSERT(pred->out_kernels().front() == node);
@@ -627,7 +630,7 @@
}
void DoSpecificFusion(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
if (!node->InferShapeDone()) {
if (!node->InferShapeDone() || !node->IsBuiltin()) {
return;
}
switch (node->type()) {

View File

@@ -105,7 +105,7 @@ void OpenCLKernel::PrintOutput(int print_num, const std::string &out_file) {
GpuTensorInfo img_info(tensor);
auto size = mem_type == lite::opencl::MemType::BUF ? img_info.OriginSize : img_info.Image2DSize;
std::vector<char> data(size);
auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
auto runtime_wrapper = lite::opencl::OpenCLRuntimeInnerWrapper();
auto runtime = runtime_wrapper.GetInstance();
auto allocator = runtime->GetAllocator();
if (!runtime->SyncCommandQueue()) {
@@ -158,10 +158,10 @@ int OpenCLKernel::PreProcess() {
if (ret != RET_OK) {
return ret;
}
auto allocator = ocl_runtime_->GetAllocator();
for (auto i = 0; i < out_tensors_.size(); ++i) {
auto *output = out_tensors_.at(i);
MS_ASSERT(output);
CHECK_NULL_RETURN(output);
CHECK_NULL_RETURN(output->allocator());
if (GetMemType() == lite::opencl::MemType::IMG) {
ImageSize img_size;
ret = GetImageSize(i, &img_size);
@@ -169,20 +169,20 @@
MS_LOG(ERROR) << "GetImageSize failed";
return ret;
}
auto data_ptr = allocator->Malloc(img_size);
auto data_ptr =
output->allocator()->Malloc(img_size.width, img_size.height, static_cast<enum DataType>(output->data_type()));
if (data_ptr == nullptr) {
MS_LOG(ERROR) << "Malloc data failed";
return RET_ERROR;
}
output->set_data(data_ptr);
} else {
ret = output->MallocData(allocator);
ret = output->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "MallocData failed";
return ret;
}
}
output->set_allocator(allocator);
output->ResetRefCount();
}
return RET_OK;

View File

@@ -92,7 +92,7 @@ void Broadcast2GpuShape(DstT *dst, const SrcT *src, int src_num, DstT default_va
struct GpuTensorInfo {
GpuTensorInfo() = default;
explicit GpuTensorInfo(const lite::Tensor *tensor) {
auto ocl_runtime_wrap_ = lite::opencl::OpenCLRuntimeWrapper();
auto ocl_runtime_wrap_ = lite::opencl::OpenCLRuntimeInnerWrapper();
if (tensor == nullptr) {
return;
}
@@ -131,7 +131,7 @@ struct GpuTensorInfo {
}
size_t RowPitch() const {
auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
auto runtime_wrapper = lite::opencl::OpenCLRuntimeInnerWrapper();
int alignment = runtime_wrapper.GetInstance()->GetImagePitchAlignment();
MS_ASSERT(alignment);
size_t row_pitch = UP_ROUND(width, alignment) * FLT4_size;
@@ -238,7 +238,7 @@ class OpenCLKernel : public InnerKernel {
bool dequant_flag_{false};
private:
lite::opencl::OpenCLRuntimeWrapper ocl_runtime_wrap_;
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap_;
static inline std::map<std::string, BaseTuningParameter> tuned_param_cache_;
};
template <class T>

View File

@@ -316,16 +316,23 @@ int OpenCLSubGraph::Prepare() {
MS_LOG(ERROR) << "node in Subgraph is nullptr";
return mindspore::lite::RET_NULL_PTR;
}
auto opencl_kernel = reinterpret_cast<kernel::OpenCLKernel *>(node->kernel());
std::set<int> pre_init_weight_list = {schema::PrimitiveType_MatMul, schema::PrimitiveType_BiasAdd};
if (pre_init_weight_list.find(opencl_kernel->type()) != pre_init_weight_list.end()) {
auto ret = opencl_kernel->InitWeights();
if (ret != RET_OK) {
MS_LOG(ERROR) << "init weights " << node->name() << " failed";
return ret;
for (const auto tensor : node->out_tensors()) {
CHECK_NULL_RETURN(tensor);
MS_CHECK_TRUE_RET(tensor->data_c() == nullptr, RET_ERROR);
tensor->set_allocator(allocator_);
}
if (desc_.provider == kBuiltin) {
auto opencl_kernel = reinterpret_cast<kernel::OpenCLKernel *>(node->kernel());
std::set<int> pre_init_weight_list = {schema::PrimitiveType_MatMul, schema::PrimitiveType_BiasAdd};
if (pre_init_weight_list.find(opencl_kernel->type()) != pre_init_weight_list.end()) {
auto ret = opencl_kernel->InitWeights();
if (ret != RET_OK) {
MS_LOG(ERROR) << "init weights " << node->name() << " failed";
return ret;
}
}
}
if (opencl_kernel->InferShapeDone()) {
if (node->InferShapeDone()) {
auto ret = node->Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "prepare node " << node->name() << " failed";
@@ -382,10 +389,9 @@ int OpenCLSubGraph::ReSize(bool interrupt) {
}
}
for (auto kernel : nodes_) {
auto opencl_kernel = reinterpret_cast<kernel::OpenCLKernel *>(kernel->kernel());
auto ret = opencl_kernel->ReSize();
auto ret = kernel->ReSize();
if (ret != RET_OK) {
MS_LOG(WARNING) << "ReSize " << opencl_kernel->name() << "failed!";
MS_LOG(WARNING) << "ReSize " << kernel->name() << "failed!";
if (interrupt) {
return ret;
} else {

View File

@@ -81,7 +81,7 @@ class OpenCLSubGraph : public SubGraphKernel {
std::vector<LiteKernel *> in_convert_ops_;
std::vector<LiteKernel *> out_convert_ops_;
std::set<LiteKernel *> nodes_set_;
lite::opencl::OpenCLRuntimeWrapper ocl_runtime_wrap_;
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap_;
lite::opencl::OpenCLRuntime *ocl_runtime_{nullptr};
bool all_kernels_infer_done_ = false;
};

View File

@@ -1163,6 +1163,9 @@ kernel::SubGraphType GetKernelSubGraphType(const kernel::LiteKernel *kernel, con
auto desc = kernel->desc();
if (desc.provider != kernel::kBuiltin) {
if (desc.arch == kernel::KERNEL_ARCH::kGPU) {
return kernel::kGpuSubGraph;
}
return kernel::kCustomSubGraph;
}
if (desc.arch == kernel::KERNEL_ARCH::kGPU) {

View File

@@ -77,14 +77,8 @@ Tensor *Tensor::CopyTensor(const Tensor &src_tensor, bool copy_data, AllocatorPt
}
Tensor::~Tensor() {
if (this->data_ != nullptr && this->own_data_) {
if (this->allocator_ != nullptr) {
this->allocator_->Free(this->data_);
} else {
free(this->data_);
}
this->data_ = nullptr;
}
FreeData();
this->data_ = nullptr;
}
bool Tensor::operator==(const Tensor &tensor) {
@@ -304,18 +298,14 @@ int Tensor::MallocData(const AllocatorPtr allocator) {
}
void Tensor::FreeData() {
if (this->data_ == nullptr) {
return;
}
if (!this->own_data_) {
return;
}
if (allocator_ == nullptr) {
free(this->data_);
this->data_ = nullptr;
} else {
allocator_->Free(this->data_);
if (!IS_STATIC_ALLOCATOR(allocator_) || (allocator_->RefCount(this->data_) != 0)) {
if (this->data_ != nullptr && this->own_data_) {
if (this->allocator_ != nullptr) {
this->allocator_->Free(this->data_);
if (!IS_STATIC_ALLOCATOR(allocator_) || (allocator_->RefCount(this->data_) != 0)) {
this->data_ = nullptr;
}
} else {
free(this->data_);
this->data_ = nullptr;
}
}

View File

@@ -168,7 +168,7 @@ class Tensor : public mindspore::tensor::MSTensor {
void set_quant_clusters(const std::vector<float> &clusters);
virtual bool IsConst() const {
bool IsConst() const override {
return (this->category_ == CONST_TENSOR || this->category_ == CONST_SCALAR) && this->data_ != nullptr;
}

View File

@@ -42,6 +42,7 @@ endif()
if(MSLITE_GPU_BACKEND STREQUAL opencl)
file(GLOB_RECURSE TEST_GPU_UT_SRC
${TEST_DIR}/ut/src/runtime/kernel/opencl/*.cc
${TEST_DIR}/ut/src/registry/registry_gpu_custom_op_test.cc
)
list(APPEND TEST_UT_SRC ${TEST_GPU_UT_SRC})
endif()

View File

@@ -146,4 +146,5 @@ MindrtRuntimeTest.Runtime
MindrtRuntimeTest.RuntimeFp16
MixDataTypeTest.mix1
SchedulerTest.TestScheduleInt32OpToFp16Subgraph
TestGPURegistryCustomOp.TestGPUCustomAdd

View File

@@ -0,0 +1,530 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <cstring>
#include <memory>
#include "schema/inner/model_generated.h"
#include "common/common_test.h"
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/lite_session.h"
#include "include/registry/register_kernel_interface.h"
#include "include/registry/register_kernel.h"
#include "include/registry/opencl_runtime_wrapper.h"
#include "include/api/data_type.h"
using mindspore::kernel::Kernel;
using mindspore::kernel::KernelInterface;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::RET_PARAM_INVALID;
using mindspore::schema::PrimitiveType_AddFusion;
#define UP_ROUND(x, y) (((x) + (y) - (1)) / (y) * (y))
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
#define C4NUM 4
namespace mindspore {
namespace {
constexpr auto kFloat32 = DataType::kNumberTypeFloat32;
static const char *arithmetic_source =
"\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n"
"\n"
"__kernel void ElementAdd(__read_only image2d_t input_a, __read_only image2d_t input_b, __write_only image2d_t "
"output,\n"
" const int2 output_shape) {\n"
" int X = get_global_id(0);\n"
" int Y = get_global_id(1);\n"
" if (X >= output_shape.x || Y >= output_shape.y) {\n"
" return;\n"
" }\n"
"\n"
" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n"
" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n"
" FLT4 result = a + b;\n"
"\n"
" WRITE_IMAGE(output, (int2)(X, Y), result);\n"
"}\n";
template <typename SrcT, typename DstT>
void Broadcast2GpuShape(DstT *dst, const SrcT *src, int src_num) {
if (src == nullptr || src_num <= 0) {
return;
}
auto *N = dst;
auto *H = dst + 1;
auto *W = dst + 2;
auto *C = dst + 3;
if (src_num == 1) { // 1 1 1 C
*C = src[0];
} else if (src_num == 2) { // N 1 1 C
*N = src[0];
*C = src[1];
} else if (src_num == 3) { // N 1 W C
*N = src[0];
*W = src[1];
*C = src[2];
} else if (src_num == 4) { // N H W C
*N = src[0];
*H = src[1];
*W = src[2];
*C = src[3];
} else if (src_num > 4) {
std::cerr << "GPU doesn't support ndim>=" << src_num;
}
}
template <typename SrcT, typename DstT>
void Broadcast2GpuShape(DstT *dst, const SrcT *src, int src_num, DstT default_value) {
for (int i = 0; i < 4; ++i) {
dst[i] = default_value;
}
if (src == nullptr || src_num <= 0) {
return;
}
Broadcast2GpuShape(dst, src, src_num);
}
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
#define C4NUM 4
struct GpuTensorInfo {
GpuTensorInfo() = default;
explicit GpuTensorInfo(const MSTensor *tensor, registry::opencl::OpenCLRuntimeWrapper *opencl_run) {
if (tensor == nullptr) {
return;
}
auto shape_ori = tensor->Shape();
int64_t shape[4];
Broadcast2GpuShape(shape, shape_ori.data(), shape_ori.size(), 1l);
N = shape[0];
H = shape[1];
W = shape[2];
C = shape[3];
Slice = UP_DIV(C, C4NUM);
if (tensor->DataType() == mindspore::DataType::kNumberTypeFloat16) {
FLT_size = sizeof(cl_half);
} else {
FLT_size = sizeof(cl_float);
}
FLT4_size = FLT_size * 4;
if (W * Slice <= opencl_run->GetMaxImage2DWidth()) {
height = N * H;
width = W * Slice;
} else {
height = N * H * W;
width = Slice;
if (height > opencl_run->GetMaxImage2DHeight()) {
height = -1;
width = -1;
}
}
ElementsNum = N * H * W * C;
Image2DSize = height * width * FLT4_size;
}
size_t N{1};
size_t H{1};
size_t W{1};
size_t C{1};
size_t Slice{};
size_t width{};
size_t height{};
size_t FLT_size{4};
size_t FLT4_size{16};
size_t ElementsNum{};
size_t Image2DSize{};
};
} // namespace
class CustomAddKernel : public kernel::Kernel {
public:
CustomAddKernel(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs,
const schema::Primitive *primitive, const mindspore::Context *ctx, const std::string &build_options,
bool fp16_enable)
: Kernel(inputs, outputs, primitive, ctx), build_options_(build_options), fp16_enable_(fp16_enable) {
opencl_runtime_ = new registry::opencl::OpenCLRuntimeWrapper();
}
~CustomAddKernel() override { FreeWeight(); }
// Prepare will be called during graph compilation
int Prepare() override {
const std::string kernel_name_ = "ElementAdd";
const std::string program_name = "Arithmetic";
std::string source = arithmetic_source;
if (opencl_runtime_->LoadSource(program_name, source) != kSuccess) {
std::cerr << "Load source failed.";
return lite::RET_ERROR;
}
std::vector<std::string> build_options_ext = {"-cl-mad-enable -cl-fast-relaxed-math -Werror"};
build_options_ext.push_back(build_options_);
if (opencl_runtime_->BuildKernel(&kernel_, program_name, kernel_name_, build_options_ext) != kSuccess) {
std::cerr << "Build kernel failed.";
return lite::RET_ERROR;
}
auto out_shape = GpuTensorInfo(&outputs_[0], opencl_runtime_);
local_range_ = cl::NullRange;
global_range_ = cl::NDRange(out_shape.width, out_shape.height);
for (int i = 0; i < inputs_.size(); ++i) {
auto &in_tensor = inputs_.at(i);
GpuTensorInfo in_shape = GpuTensorInfo(&in_tensor, opencl_runtime_);
if (in_tensor.IsConst()) {
std::vector<char> weight(in_shape.Image2DSize, 0);
bool src_is_fp16 = in_tensor.DataType() == mindspore::DataType::kNumberTypeFloat16;
PackNHWCToNHWC4(in_tensor.MutableData(), weight.data(), src_is_fp16, fp16_enable_, in_shape,
in_tensor.DataType());
DataType dtype =
fp16_enable_ ? mindspore::DataType::kNumberTypeFloat16 : mindspore::DataType::kNumberTypeFloat32;
auto allocator = opencl_runtime_->GetAllocator();
if (allocator == nullptr) {
std::cerr << "GetAllocator fail.";
FreeWeight();
return lite::RET_ERROR;
}
auto weight_ptr = allocator->Malloc(in_shape.width, in_shape.height, dtype);
if (weight_ptr == nullptr) {
std::cerr << "Malloc fail.";
FreeWeight();
return lite::RET_ERROR;
}
weight_ptrs_.push_back(weight_ptr);
if (opencl_runtime_->WriteImage(weight_ptr, weight.data()) != kSuccess) {
std::cerr << "WriteImage fail.";
FreeWeight();
return lite::RET_ERROR;
}
} else {
weight_ptrs_.push_back(nullptr);
}
}
int arg_idx = 3;
cl_int2 output_shape{static_cast<int>(global_range_[0]), static_cast<int>(global_range_[1])};
if (opencl_runtime_->SetKernelArg(kernel_, arg_idx, output_shape) != kSuccess) {
std::cerr << "Set kernel arg" << arg_idx << "failed.";
FreeWeight();
return lite::RET_ERROR;
}
std::cout << kernel_name_ << " Init Done!" << std::endl;
return lite::RET_OK;
}
// Execute is called to compute.
int Execute() override {
if (inputs_.size() != 2) {
return lite::RET_PARAM_INVALID;
}
PreProcess();
std::cout << this->name() << " Running!" << std::endl;
auto input_0_ptr = weight_ptrs_[0] == nullptr ? inputs_[0].MutableData() : weight_ptrs_[0];
auto input_1_ptr = weight_ptrs_[1] == nullptr ? inputs_[1].MutableData() : weight_ptrs_[1];
int arg_idx = 0;
if (opencl_runtime_->SetKernelArg(kernel_, arg_idx++, input_0_ptr) != kSuccess) {
std::cerr << "Set kernel arg" << arg_idx - 1 << "failed.";
return lite::RET_ERROR;
}
if (opencl_runtime_->SetKernelArg(kernel_, arg_idx++, input_1_ptr) != kSuccess) {
std::cerr << "Set kernel arg" << arg_idx - 1 << "failed.";
return lite::RET_ERROR;
}
if (opencl_runtime_->SetKernelArg(kernel_, arg_idx++, outputs_[0].MutableData()) != kSuccess) {
std::cerr << "Set kernel arg" << arg_idx - 1 << "failed.";
return lite::RET_ERROR;
}
if (opencl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_) != kSuccess) {
std::cerr << "Run kernel failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
int CheckSpecs() {
for (auto &tensor : inputs_) {
if (tensor.DataType() != DataType::kNumberTypeFloat32 && tensor.DataType() != DataType::kNumberTypeFloat16) {
std::cerr << "ArithmeticOpenCLKernel only support fp32/fp16 input";
return lite::RET_ERROR;
}
}
for (auto &tensor : outputs_) {
if (tensor.DataType() != DataType::kNumberTypeFloat32 && tensor.DataType() != DataType::kNumberTypeFloat16) {
std::cerr << "ArithmeticOpenCLKernel only support fp32/fp16 output";
return lite::RET_ERROR;
}
}
if (inputs_.size() != 2 || outputs_.size() != 1) {
std::cerr << "in size: " << inputs_.size() << ", out size: " << outputs_.size();
return lite::RET_ERROR;
}
return lite::RET_OK;
}
// Resize is used to update some parameters if current node can change along with inputs.
int ReSize() override {
if (CheckOutputs(outputs_) == lite::RET_OK) {
return lite::RET_OK;
}
auto status =
registry::RegisterKernelInterface::GetKernelInterface({}, primitive_)->Infer(&inputs_, &outputs_, primitive_);
if (status != kSuccess) {
std::cerr << "infer failed." << std::endl;
return lite::RET_ERROR;
}
auto ret = CheckSpecs();
if (ret != lite::RET_OK) {
std::cerr << "ReSize failed for check kernel specs!";
return ret;
}
ret = Prepare();
if (ret != lite::RET_OK) {
std::cerr << "ReSize failed for kernel prepare!";
return ret;
}
return lite::RET_OK;
}
private:
std::string build_options_;
bool fp16_enable_;
cl::Kernel kernel_;
cl::Event event_;
cl::NDRange global_range_{cl::NullRange};
cl::NDRange local_range_{cl::NullRange};
std::vector<void *> weight_ptrs_;
registry::opencl::OpenCLRuntimeWrapper *opencl_runtime_;
int PreProcess() {
int ret;
ret = ReSize();
if (ret != lite::RET_OK) {
return ret;
}
for (auto i = 0; i < outputs_.size(); ++i) {
auto *output = &outputs_.at(i);
auto img_info = GpuTensorInfo(output, opencl_runtime_);
auto allocator = output->allocator();
if (allocator == nullptr) {
std::cerr << "The output tensor of OpenCL kernel must have an allocator.";
return lite::RET_ERROR;
}
auto data_ptr = allocator->Malloc(img_info.width, img_info.height, output->DataType());
if (data_ptr == nullptr) {
std::cerr << "Malloc data failed";
return lite::RET_ERROR;
}
output->SetData(data_ptr);
}
return lite::RET_OK;
}
int CheckOutputs(const std::vector<mindspore::MSTensor> &outputs) {
for (auto &output : outputs) {
auto output_shape = output.Shape();
if (std::find(output_shape.begin(), output_shape.end(), -1) != output_shape.end()) {
return lite::RET_INFER_INVALID;
}
}
return lite::RET_OK;
}
void PackNHWCToNHWC4(void *src, void *dst, bool src_is_fp16, bool dst_is_fp16, const GpuTensorInfo &tensor,
mindspore::DataType data_type) {
auto src_fp16 = reinterpret_cast<float16_t *>(src);
auto src_fp32 = reinterpret_cast<float32_t *>(src);
auto src_int32 = reinterpret_cast<int32_t *>(src);
auto dst_fp16 = reinterpret_cast<float16_t *>(dst);
auto dst_fp32 = reinterpret_cast<float32_t *>(dst);
auto dst_int32 = reinterpret_cast<int32_t *>(dst);
for (int n = 0, src_idx = 0; n < tensor.N; n++) {
for (int h = 0; h < tensor.H; ++h) {
for (int w = 0; w < tensor.W; ++w) {
for (int c = 0; c < tensor.C; ++c, ++src_idx) {
int dst_idx = ((n * tensor.H + h) * tensor.W + w) * tensor.Slice * C4NUM + c;
if (data_type == mindspore::DataType::kNumberTypeInt32) {
dst_int32[dst_idx] = src_int32[src_idx];
} else if (dst_is_fp16) {
dst_fp16[dst_idx] = src_is_fp16 ? src_fp16[src_idx] : static_cast<float16_t>(src_fp32[src_idx]);
} else {
dst_fp32[dst_idx] = src_is_fp16 ? static_cast<float32_t>(src_fp16[src_idx]) : src_fp32[src_idx];
}
}
}
}
}
// scalar
if (tensor.ElementsNum == 1) {
if (dst_is_fp16) {
dst_fp16[3] = dst_fp16[2] = dst_fp16[1] = dst_fp16[0];
} else {
dst_fp32[3] = dst_fp32[2] = dst_fp32[1] = dst_fp32[0];
}
}
}
void FreeWeight() {
auto allocator = opencl_runtime_->GetAllocator();
if (allocator == nullptr) {
std::cerr << "GetAllocator fail.";
return;
}
for (auto &weight_ptr : weight_ptrs_) {
if (weight_ptr != nullptr) {
allocator->Free(weight_ptr);
weight_ptr = nullptr;
}
}
}
};
class CustomAddInfer : public kernel::KernelInterface {
public:
CustomAddInfer() = default;
~CustomAddInfer() = default;
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const schema::Primitive *primitive) override {
(*outputs)[0].SetFormat((*inputs)[0].format());
(*outputs)[0].SetDataType((*inputs)[0].DataType());
(*outputs)[0].SetShape((*inputs)[0].Shape());
return kSuccess;
}
};
namespace {
std::shared_ptr<kernel::Kernel> CustomAddCreator(const std::vector<MSTensor> &inputs,
const std::vector<MSTensor> &outputs,
const schema::Primitive *primitive, const mindspore::Context *ctx) {
const std::string build_options = " -DFLT4=float4 -DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef ";
bool fp16_enable = false;
std::cout << "using fp32 add.\n" << std::endl;
return std::make_shared<CustomAddKernel>(inputs, outputs, primitive, ctx, build_options, fp16_enable);
}
std::shared_ptr<kernel::KernelInterface> CustomAddInferCreator() { return std::make_shared<CustomAddInfer>(); }
} // namespace
REGISTER_CUSTOM_KERNEL_INTERFACE(BuiltInTest, Custom_Add, CustomAddInferCreator)
// Register custom "Custom_Add" operator
REGISTER_CUSTOM_KERNEL(GPU, BuiltInTest, kFloat32, Custom_Add, CustomAddCreator)
class TestGPURegistryCustomOp : public mindspore::CommonTest {
public:
TestGPURegistryCustomOp() = default;
};
TEST_F(TestGPURegistryCustomOp, TestGPUCustomAdd) {
auto meta_graph = std::make_shared<schema::MetaGraphT>();
meta_graph->name = "graph";
auto node = std::make_unique<schema::CNodeT>();
node->inputIndex = {0, 1};
node->outputIndex = {2};
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Custom;
auto primitive = new schema::CustomT;
primitive->type = "Custom_Add";
node->primitive->value.value = primitive;
node->name = "Add";
meta_graph->nodes.emplace_back(std::move(node));
meta_graph->inputIndex = {0, 1};
meta_graph->outputIndex = {2};
auto input0 = std::make_unique<schema::TensorT>();
input0->nodeType = lite::NodeType_ValueNode;
input0->format = schema::Format_NHWC;
input0->dataType = TypeId::kNumberTypeFloat32;
input0->dims = {1, 28, 28, 3};
input0->offset = -1;
meta_graph->allTensors.emplace_back(std::move(input0));
auto weight = std::make_unique<schema::TensorT>();
weight->nodeType = lite::NodeType_ValueNode;
weight->format = schema::Format_NHWC;
weight->dataType = TypeId::kNumberTypeFloat32;
weight->dims = {1, 28, 28, 3};
weight->offset = -1;
meta_graph->allTensors.emplace_back(std::move(weight));
auto output = std::make_unique<schema::TensorT>();
output->nodeType = lite::NodeType_Parameter;
output->format = schema::Format_NHWC;
output->dataType = TypeId::kNumberTypeFloat32;
output->offset = -1;
meta_graph->allTensors.emplace_back(std::move(output));
flatbuffers::FlatBufferBuilder builder(1024);
auto offset = schema::MetaGraph::Pack(builder, meta_graph.get());
builder.Finish(offset);
schema::FinishMetaGraphBuffer(builder, offset);
size_t size = builder.GetSize();
const char *content = reinterpret_cast<char *>(builder.GetBufferPointer());
// create a context
auto context = std::make_shared<mindspore::Context>();
context->SetThreadNum(1);
context->SetEnableParallel(false);
context->SetThreadAffinity(lite::HIGHER_CPU);
auto &device_list = context->MutableDeviceInfo();
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
device_info->SetEnableFP16(false);
device_list.push_back(device_info);
std::shared_ptr<GPUDeviceInfo> provider_gpu_device_info = std::make_shared<GPUDeviceInfo>();
provider_gpu_device_info->SetEnableFP16(false);
provider_gpu_device_info->SetProviderDevice("GPU");
provider_gpu_device_info->SetProvider("BuiltInTest");
device_list.push_back(provider_gpu_device_info);
// build a model
auto model = std::make_shared<mindspore::Model>();
auto ret = model->Build(content, size, kFlatBuffer, context);
ASSERT_EQ(kSuccess, ret.StatusCode());
auto inputs = model->GetInputs();
ASSERT_EQ(inputs.size(), 2);
auto inTensor = inputs.front();
auto impl = inTensor.impl();
ASSERT_NE(nullptr, impl);
float *in0_data = static_cast<float *>(inTensor.MutableData());
in0_data[0] = 10.0f;
auto inTensor1 = inputs.back();
impl = inTensor1.impl();
ASSERT_NE(nullptr, impl);
float *in1_data = static_cast<float *>(inTensor1.MutableData());
in1_data[0] = 20.0f;
std::vector<mindspore::MSTensor> outputs;
ret = model->Predict(inputs, &outputs);
ASSERT_EQ(kSuccess, ret.StatusCode());
ASSERT_EQ(outputs.size(), 1);
impl = outputs.front().impl();
ASSERT_NE(nullptr, impl);
ASSERT_EQ(28 * 28 * 3, outputs.front().ElementNum());
ASSERT_EQ(DataType::kNumberTypeFloat32, outputs.front().DataType());
auto *outData = reinterpret_cast<const float *>(outputs.front().Data().get());
ASSERT_NE(nullptr, outData);
ASSERT_EQ(30.0f, outData[0]);
MS_LOG(INFO) << "Register add op test pass.";
}
} // namespace mindspore

View File

@@ -39,7 +39,7 @@ void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bou
TEST_F(TestCastSelfOpenCL, Castfp32tofp16) {
MS_LOG(INFO) << " begin test ";
auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
auto ocl_runtime = lite::opencl::OpenCLRuntimeInnerWrapper().GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
@@ -149,7 +149,7 @@ TEST_F(TestCastSelfOpenCL, Castfp32tofp16) {
TEST_F(TestCastSelfOpenCL, Castfp16tofp32) {
MS_LOG(INFO) << " begin test ";
auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
auto ocl_runtime = lite::opencl::OpenCLRuntimeInnerWrapper().GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();

View File

@@ -51,7 +51,7 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vec
// simulating benchmark: session::LiteSession::CreateSession() -> session->Init()
MS_LOG(DEBUG) << "initialize OpenCLRuntime and OpenCLAllocator";
auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
auto runtime_wrapper = lite::opencl::OpenCLRuntimeInnerWrapper();
auto ocl_runtime = runtime_wrapper.GetInstance();
ocl_runtime->SetFp16Enable(fp16_enable);
EXPECT_TRUE(ocl_runtime->Init() == RET_OK);
@@ -222,7 +222,7 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, std::tuple<std
// simulating benchmark: session::LiteSession::CreateSession() -> session->Init()
MS_LOG(DEBUG) << "initialize OpenCLRuntime and OpenCLAllocator";
auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
auto runtime_wrapper = lite::opencl::OpenCLRuntimeInnerWrapper();
auto ocl_runtime = runtime_wrapper.GetInstance();
ocl_runtime->SetFp16Enable(fp16_enable);
EXPECT_TRUE(ocl_runtime->Init() == RET_OK);

View File

@@ -33,7 +33,7 @@ class TestFillOpenCLCI : public mindspore::CommonTest {
TEST_F(TestFillOpenCLCI, Fp32testfill) {
MS_LOG(INFO) << " begin test ";
auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
auto runtime_wrapper = lite::opencl::OpenCLRuntimeInnerWrapper();
auto runtime = runtime_wrapper.GetInstance();
runtime->Init();
auto allocator = runtime->GetAllocator();
@@ -104,7 +104,7 @@ TEST_F(TestFillOpenCLCI, Fp32testfill) {
TEST_F(TestFillOpenCLCI, Fp32testshape) {
MS_LOG(INFO) << " begin test ";
auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
auto runtime_wrapper = lite::opencl::OpenCLRuntimeInnerWrapper();
auto runtime = runtime_wrapper.GetInstance();
runtime->Init();
auto allocator = runtime->GetAllocator();