add cuda backend for lite

This commit is contained in:
wandongdong 2021-01-05 20:15:07 -08:00
parent 8784aa3b2f
commit fa86129228
36 changed files with 665 additions and 111 deletions

View File

@ -341,6 +341,7 @@ checkopts()
# Parse device
# Process build option
if [[ "X$DEVICE" == "Xgpu" ]]; then
LITE_ENABLE_GPU="opencl"
ENABLE_GPU="on"
ENABLE_CPU="on"
ENABLE_MPI="on"
@ -378,6 +379,12 @@ checkopts()
ENABLE_CPU="on"
elif [[ "X$DEVICE" == "Xcpu" ]]; then
ENABLE_CPU="on"
elif [[ "X$DEVICE" == "Xopencl" ]]; then
LITE_ENABLE_GPU="opencl"
elif [[ "X$DEVICE" == "Xvulkan" ]]; then
LITE_ENABLE_GPU="vulkan"
elif [[ "X$DEVICE" == "Xcuda" ]]; then
LITE_ENABLE_GPU="cuda"
elif [[ "X$DEVICE" == "X" ]]; then
:
else
@ -520,18 +527,12 @@ build_lite()
get_version
echo "============ Start building MindSpore Lite ${VERSION_STR} ============"
LITE_ENABLE_GPU=${ENABLE_GPU}
LITE_ENABLE_NPU=${ENABLE_NPU}
if [[ "${DEVICE}" == "" && "${LITE_PLATFORM}" == "arm64" ]]; then
LITE_ENABLE_GPU="on"
LITE_ENABLE_GPU="opencl"
LITE_ENABLE_NPU="on"
fi
if [[ $1 == "arm64" && "X$DEVICE" != "Xcpu" ]]; then
LITE_ENABLE_GPU="on"
echo "start get opencl"
fi
if [ "${LITE_ENABLE_NPU}" == "on" ]; then
if [ "${LITE_PLATFORM}" == "arm64" ]; then
checkddk

View File

@ -0,0 +1,41 @@
if(ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/Vulkan-Headers/archive/v1.2.166.zip")
set(MD5 "8797a525aff953ea536ebe338a9f5ef6")
set(PKG_GIT_TAG "")
__download_pkg_with_git(Vulkan-Headers ${REQ_URL} ${PKG_GIT_TAG} ${MD5})
else()
set(REQ_URL "https://github.com/KhronosGroup/Vulkan-Headers/archive/v1.2.166.zip")
set(MD5 "91eae880a0ad9ad77c89d79b95b7399a")
__download_pkg(Vulkan-Headers ${REQ_URL} ${MD5})
endif()
function(gene_spirv BASEPATH)
string(CONCAT CL_SRC_DIR "${BASEPATH}" "/src/runtime/kernel/vulkan/glsl")
message(STATUS "**********gene spirv*********base path: " "${BASEPATH}" ", glsl path: " "${CL_SRC_DIR}")
if(NOT EXISTS ${CL_SRC_DIR})
return()
endif()
file(GLOB_RECURSE CL_LIST ${CL_SRC_DIR}/*.cl)
foreach(file_path ${CL_LIST})
file(REMOVE ${file_path}.inc)
string(REGEX REPLACE ".+/(.+)\\..*" "\\1" kernel_name "${file_path}")
set(inc_file_ex "${kernel_name}.cl.inc")
execute_process(
COMMAND bash -c "sed 's/\\\\/\\\\\\\\/g' "
COMMAND bash -c "sed 's/\\\"/\\\\\\\"/g' "
COMMAND bash -c "sed 's/$/\\\\n\\\" \\\\/' "
COMMAND bash -c "sed 's/^/\\\"/' "
WORKING_DIRECTORY ${CL_SRC_DIR}
INPUT_FILE ${file_path}
OUTPUT_FILE ${inc_file_ex}
RESULT_VARIABLE RESULT)
if(NOT RESULT EQUAL "0")
message(FATAL_ERROR "error! when generate ${inc_file_ex}")
endif()
__exec_cmd(COMMAND sed -i
"1i\\static const char *${kernel_name}_source =\\\"\\\\n\\\" \\\\"
${inc_file_ex} WORKING_DIRECTORY ${CL_SRC_DIR}
)
__exec_cmd(COMMAND sed -i "$a\\\\\;" ${inc_file_ex} WORKING_DIRECTORY ${CL_SRC_DIR})
endforeach()
endfunction()

View File

@ -17,6 +17,9 @@ option(ENABLE_FP16 "if build fp16 ops" off)
option(ENABLE_TOOLS "if build tools" on)
option(BUILD_TESTCASES "if build testcase" on)
option(SUPPORT_GPU "if support gpu" off)
option(GPU_OPENCL "if support gpu opencl" off)
option(GPU_VULKAN "if support gpu vulkan" off)
option(GPU_CUDA "if support gpu cuda" off)
option(SUPPORT_NPU "if support npu" off)
option(OFFLINE_COMPILE "if offline compile OpenCL kernel" off)
option(BUILD_MINDDATA_EXAMPLE "" on)
@ -43,6 +46,7 @@ if(PLATFORM_ARM64 OR PLATFORM_ARM32)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
endif()
#if(BUILD_MINDDATA STREQUAL "lite" OR BUILD_MINDDATA STREQUAL "full" OR BUILD_MINDDATA STREQUAL "wrapper")
if(SUPPORT_GPU)
set(PROCESS_UNIT gpu)
elseif(SUPPORT_NPU)
@ -114,9 +118,12 @@ include(${TOP_DIR}/cmake/utils.cmake)
include(${TOP_DIR}/cmake/dependency_utils.cmake)
include(${TOP_DIR}/cmake/dependency_securec.cmake)
include(${TOP_DIR}/cmake/external_libs/flatbuffers.cmake)
if(SUPPORT_GPU)
if(SUPPORT_GPU STREQUAL opencl)
include(${TOP_DIR}/cmake/external_libs/opencl.cmake)
endif()
if(SUPPORT_GPU STREQUAL vulkan)
include(${TOP_DIR}/cmake/external_libs/vulkan.cmake)
endif()
if(ENABLE_CONVERTER OR BUILD_MINDDATA STREQUAL "full" OR BUILD_MINDDATA STREQUAL "wrapper")
include(${TOP_DIR}/cmake/external_libs/json.cmake)
@ -157,7 +164,8 @@ endif()
if(ENABLE_FP16)
add_compile_definitions(ENABLE_FP16)
endif()
if(SUPPORT_GPU)
if(SUPPORT_GPU STREQUAL opencl)
add_definitions(-DGPU_OPENCL)
gene_opencl(${CMAKE_CURRENT_SOURCE_DIR})
add_definitions(-DUSE_OPENCL_WRAPPER)
add_definitions(-DMS_OPENCL_PROFILE=false)
@ -171,6 +179,16 @@ if(SUPPORT_GPU)
include_directories(${CMAKE_BINARY_DIR}/_deps/opencl-headers-src/)
include_directories(${CMAKE_BINARY_DIR}/_deps/opencl-clhpp-src/include)
endif()
if(SUPPORT_GPU STREQUAL vulkan)
add_definitions(-DGPU_VULKAN)
add_definitions(-DVK_NO_PROTOTYPES)
add_compile_definitions(SUPPORT_GPU)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/build/_deps/vulkan-headers-src/include)
endif()
if(SUPPORT_GPU STREQUAL cuda)
add_definitions(-DGPU_CUDA)
add_compile_definitions(SUPPORT_GPU)
endif()
if(WIN32)
add_compile_definitions(LITE_EXPORTS)

View File

@ -40,21 +40,37 @@ set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/huffman_decode.cc
)
if(SUPPORT_GPU)
if(SUPPORT_GPU STREQUAL opencl)
file(GLOB_RECURSE OPENCL_RUNTIME_SRC
${CMAKE_CURRENT_SOURCE_DIR}/runtime/gpu/opencl/*.cc
)
set(LITE_SRC
${LITE_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/opencl_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/opencl_subgraph.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/opencl_fusion.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_executor.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_allocator.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_runtime.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_wrapper.cc
${OPENCL_RUNTIME_SRC}
)
endif()
if(SUPPORT_GPU STREQUAL vulkan)
file(GLOB VULKAN_RUNTIME_SRC
${CMAKE_CURRENT_SOURCE_DIR}/runtime/gpu/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/gpu/vulkan/*.cc
)
set(LITE_SRC
${LITE_SRC}
${VULKAN_RUNTIME_SRC}
)
endif()
if(SUPPORT_GPU STREQUAL cuda)
file(GLOB CUDA_RUNTIME_SRC
${CMAKE_CURRENT_SOURCE_DIR}/runtime/gpu/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/gpu/cuda/*.cc
)
set(LITE_SRC
${LITE_SRC}
${CUDA_RUNTIME_SRC}
)
endif()
if(SUPPORT_TRAIN)
set(ANF_SRC
${ANF_SRC}
@ -86,10 +102,14 @@ set_target_properties(mindspore-lite_static PROPERTIES OUTPUT_NAME "mindspore-li
set_target_properties(mindspore-lite_static PROPERTIES CLEAN_DIRECT_OUTPUT 1)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-private-field")
if(SUPPORT_GPU)
if(SUPPORT_GPU STREQUAL opencl)
add_subdirectory(runtime/kernel/opencl)
target_link_libraries(mindspore-lite cpu_kernel_mid opencl_kernel_mid nnacl cpu_ops_mid)
target_link_libraries(mindspore-lite_static cpu_kernel_mid opencl_kernel_mid nnacl_mid cpu_ops_mid)
elseif(SUPPORT_GPU STREQUAL cuda)
add_subdirectory(runtime/kernel/cuda)
target_link_libraries(mindspore-lite cpu_kernel_mid cuda_kernel_mid nnacl cpu_ops_mid)
target_link_libraries(mindspore-lite_static cpu_kernel_mid cuda_kernel_mid nnacl_mid cpu_ops_mid)
else()
target_link_libraries(mindspore-lite cpu_kernel_mid nnacl cpu_ops_mid)
target_link_libraries(mindspore-lite_static cpu_kernel_mid nnacl_mid cpu_ops_mid)

View File

@ -32,7 +32,7 @@
#include "src/runtime/agent/npu/npu_manager.h"
#include "src/runtime/agent/npu/optimizer/npu_pass_manager.h"
#endif
#if SUPPORT_GPU
#if GPU_OPENCL
#include "src/runtime/kernel/opencl/opencl_subgraph.h"
#endif
@ -562,7 +562,7 @@ LiteSession::~LiteSession() {
mindspore::lite::NPUPassManager::GetInstance()->Clear();
mindspore::lite::NPUManager::GetInstance()->Reset();
#endif
#if SUPPORT_GPU && !SUPPORT_TRAIN
#if GPU_OPENCL && !SUPPORT_TRAIN
delete opencl_runtime_wrapper_;
#endif
delete (model_);
@ -646,7 +646,7 @@ int LiteSession::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels)
}
auto ret = RET_OK;
if (kernel->subgraph_type() == kernel::kGpuSubGraph) {
#if SUPPORT_GPU
#if GPU_OPENCL
auto sub_graph = reinterpret_cast<kernel::OpenCLSubGraph *>(kernel);
ret = sub_graph->ReSize(false);
#endif
@ -700,7 +700,7 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
}
int LiteSession::InitGPURuntime() {
#if SUPPORT_GPU && !SUPPORT_TRAIN
#if GPU_OPENCL && !SUPPORT_TRAIN
if (this->context_->IsGpuEnabled()) {
opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper();
if (opencl_runtime_wrapper_ == nullptr) {
@ -717,6 +717,23 @@ int LiteSession::InitGPURuntime() {
MS_LOG(INFO) << "Init OpenCL runtime success.";
}
}
#elif GPU_VULKAN && !SUPPORT_TRAIN
if (this->context_->IsGpuEnabled()) {
auto gpu_device_info = this->context_->GetGpuInfo();
vk_runtime_wrap_ = new (std::nothrow) gpu::GpuRuntimeWrapper<vulkan::VulkanRuntime>;
if (vk_runtime_wrap_ == nullptr) {
MS_LOG(ERROR) << "create vk_runtime failed";
return RET_ERROR;
}
auto vk_runtime = vk_runtime_wrap_->GetInstance();
vk_runtime->SetFp16Enable(gpu_device_info.enable_float16_);
if (vk_runtime->Init() != RET_OK) {
this->context_->device_list_ = {{DT_CPU, {gpu_device_info.enable_float16_, MID_CPU}}};
MS_LOG(WARNING) << "Init Vulkan runtime failed, change to CPU mode.";
} else {
MS_LOG(INFO) << "Init Vulkan runtime success.";
}
}
#endif
return RET_OK;
}

View File

@ -31,8 +31,10 @@
#include "src/executor.h"
#include "src/tensor.h"
#include "src/tensorlist.h"
#if SUPPORT_GPU
#include "src/runtime/opencl/opencl_runtime.h"
#if GPU_OPENCL
#include "src/runtime/gpu/opencl/opencl_runtime.h"
#elif GPU_VULKAN
#include "src/runtime/gpu/vulkan/vulkan_runtime.h"
#endif
namespace mindspore {
@ -127,8 +129,10 @@ class LiteSession : public session::LiteSession {
Executor *executor_ = nullptr;
Model *model_ = nullptr;
std::atomic<bool> is_running_ = false;
#if SUPPORT_GPU && !SUPPORT_TRAIN
#if GPU_OPENCL && !SUPPORT_TRAIN
opencl::OpenCLRuntimeWrapper *opencl_runtime_wrapper_{nullptr};
#elif GPU_VULKAN && !SUPPORT_TRAIN
gpu::GpuRuntimeWrapper<vulkan::VulkanRuntime> *vk_runtime_wrap_{nullptr};
#endif
};
} // namespace lite

View File

@ -0,0 +1,21 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/gpu/cuda/cuda_device.h"
#include <unordered_set>
namespace mindspore::lite::cuda {
CudaDevice::~CudaDevice() {}
} // namespace mindspore::lite::cuda

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_CUDA_DEVICE_H_
#define MINDSPORE_LITE_SRC_CUDA_DEVICE_H_
#include <vulkan/vulkan.h>
#include <assert.h>
#include <exception>
#include <algorithm>
#include "src/runtime/gpu/gpu_runtime.h"
namespace mindspore::lite::cuda {
class CudaDevice {
public:
CudaDevice() {}
virtual ~CudaDevice();
};
} // namespace mindspore::lite::cuda
#endif // MINDSPORE_LITE_SRC_CUDA_DEVICE_H_

View File

@ -0,0 +1,80 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/gpu/cuda/cuda_runtime.h"
#include <vector>
#include <mutex>
#include "include/errorcode.h"
#include "src/common/file_utils.h"
namespace mindspore::lite::cuda {
static std::mutex g_mtx;
bool CudaRuntime::initialized_ = false;
uint32_t CudaRuntime::instance_count_ = 0;
CudaRuntime *CudaRuntime::cuda_runtime_instance_ = nullptr;
CudaRuntime *CudaRuntime::GetInstance() {
std::unique_lock<std::mutex> lck(g_mtx);
static CudaRuntime vk_runtime;
if (instance_count_ == 0) {
cuda_runtime_instance_ = &vk_runtime;
cuda_runtime_instance_->Init();
}
instance_count_++;
return cuda_runtime_instance_;
}
void CudaRuntime::DeleteInstance() {
std::unique_lock<std::mutex> lck(g_mtx);
if (instance_count_ == 0) {
MS_LOG(ERROR) << "No VulkanRuntime instance could delete!";
}
instance_count_--;
if (instance_count_ == 0) {
cuda_runtime_instance_->Uninit();
}
}
CudaRuntime::CudaRuntime() {}
// Init will get platforms info, get devices info, create opencl context.
int CudaRuntime::Init() {
if (initialized_) {
return RET_OK;
}
initialized_ = true;
MS_LOG(INFO) << "CudaRuntime init done!";
return RET_OK;
}
int CudaRuntime::Uninit() {
if (!initialized_) {
return RET_OK;
}
initialized_ = false;
return RET_OK;
}
CudaRuntime::~CudaRuntime() { Uninit(); }
const GpuInfo &CudaRuntime::GetGpuInfo() { return gpu_info_; }
bool CudaRuntime::GetFp16Enable() const { return true; }
} // namespace mindspore::lite::cuda

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_CUDA_RUNTIME_H_
#define MINDSPORE_LITE_SRC_CUDA_RUNTIME_H_
#include <vector>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <type_traits>
#include "src/common/log_adapter.h"
#include "src/runtime/gpu/gpu_runtime.h"
#include "schema/gpu_cache_generated.h"
using mindspore::lite::gpu::GpuInfo;
using mindspore::lite::gpu::GpuRuntime;
using mindspore::lite::gpu::GpuRuntimeWrapper;
namespace mindspore::lite::cuda {
class CudaRuntime : public GpuRuntime {
public:
friend GpuRuntimeWrapper<CudaRuntime>;
~CudaRuntime() override;
CudaRuntime(const CudaRuntime &) = delete;
CudaRuntime &operator=(const CudaRuntime &) = delete;
int Init() override;
int Uninit() override;
const GpuInfo &GetGpuInfo() override;
bool GetFp16Enable() const override;
static CudaRuntime *GetInstance();
static void DeleteInstance();
private:
CudaRuntime();
private:
static bool initialized_;
static uint32_t instance_count_;
static CudaRuntime *cuda_runtime_instance_;
};
} // namespace mindspore::lite::cuda
#endif // MINDSPORE_LITE_SRC_CUDA_RUNTIME_H_

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/gpu/gpu_runtime.h"
#include <vector>
#include <numeric>
#include <utility>
#include <mutex>
#ifdef SHARING_MEM_WITH_OPENGL
#include <EGL/egl.h>
#endif
#include "include/errorcode.h"
#include "src/common/file_utils.h"
namespace mindspore::lite::gpu {
const GpuInfo &GpuRuntime::GetGpuInfo() { return gpu_info_; }
} // namespace mindspore::lite::gpu

View File

@ -0,0 +1,107 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_GPU_RUNTIME_H_
#define MINDSPORE_LITE_SRC_GPU_RUNTIME_H_
#include <vector>
#include <unordered_map>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <type_traits>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/runtime/allocator.h"
#include "schema/gpu_cache_generated.h"
namespace mindspore::lite::gpu {
enum GpuType { OTHER = 0, ADRENO = 1, MALI = 2, MALI_T = 3, MALI_G = 4 };
struct GpuInfo {
GpuType type = OTHER;
int model_num = 0;
float version = 0;
uint64_t global_memery_cachesize{0};
uint64_t global_memery_size{0};
uint64_t max_alloc_size{0};
uint32_t max_work_group_size{1};
uint32_t compute_units{0};
uint32_t max_freq{0};
uint32_t image_pitch_align{0};
std::vector<size_t> max_work_item_sizes;
bool support_fp16{false};
bool support_svm{false};
};
enum class GpuBackendType { OPENCL = 0, CUDA = 1, VULKAN = 2 };
class DevKey {
public:
std::string name{""};
};
class GpuContext {
public:
GpuBackendType type;
};
class GpuDevice {
public:
GpuDevice();
~GpuDevice();
};
class DevKernel {
public:
void *data{nullptr};
};
class GpuAllocator : public Allocator {};
class GpuRuntime {
public:
GpuRuntime() {}
virtual ~GpuRuntime() {}
GpuRuntime(const GpuRuntime &) = delete;
GpuRuntime &operator=(const GpuRuntime &) = delete;
virtual int Init() { return RET_ERROR; }
virtual int Uninit() { return RET_ERROR; }
virtual const GpuInfo &GetGpuInfo() = 0;
virtual bool GetFp16Enable() const = 0;
uint64_t GetGlobalMemSize() const { return gpu_info_.global_memery_size; }
uint64_t GetMaxAllocSize() const { return gpu_info_.max_alloc_size; }
const std::vector<size_t> &GetWorkItemSize() const { return gpu_info_.max_work_item_sizes; }
protected:
// gpu hal native defines
std::unordered_map<std::string, DevKernel *> dev_kernels_;
GpuContext *context_{nullptr};
GpuDevice *device_{nullptr};
GpuInfo gpu_info_;
private:
};
template <class T>
class GpuRuntimeWrapper {
public:
GpuRuntimeWrapper() { gpu_runtime_ = T::GetInstance(); }
~GpuRuntimeWrapper() { T::DeleteInstance(); }
GpuRuntimeWrapper(const GpuRuntimeWrapper &) = delete;
GpuRuntimeWrapper &operator=(const GpuRuntimeWrapper &) = delete;
T *GetInstance() { return gpu_runtime_; }
private:
T *gpu_runtime_{nullptr};
};
} // namespace mindspore::lite::gpu
#endif // MINDSPORE_LITE_SRC_GPU_RUNTIME_H_

View File

@ -14,9 +14,9 @@
* limitations under the License.
*/
#include "src/runtime/opencl/opencl_allocator.h"
#include "src/runtime/gpu/opencl/opencl_allocator.h"
#include <utility>
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/gpu/opencl/opencl_runtime.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "src/runtime/opencl/opencl_executor.h"
#include "src/runtime/gpu/opencl/opencl_executor.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "nnacl/pack.h"
#include "include/errorcode.h"
@ -27,8 +27,8 @@ int OpenCLExecutor::Run(std::vector<Tensor *> &inputs, std::vector<Tensor *> &ou
return RunOrTune(inputs, outputs, kernels, allocator, before, after, false);
}
int OpenCLExecutor::RunOrTune(std::vector<Tensor *> &inputs, std::vector<Tensor *> &outputs,
std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator,
int OpenCLExecutor::RunOrTune(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator,
const KernelCallBack &before, const KernelCallBack &after, bool is_tune) {
int ret{RET_OK};
auto opencl_runtime_ins = ocl_runtime.GetInstance();

View File

@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_RUNTIME_OPENCL_EXECUTOR_H_
#include <vector>
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/gpu/opencl/opencl_runtime.h"
#include "src/runtime/allocator.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/executor.h"
@ -34,8 +34,8 @@ class OpenCLExecutor : public Executor {
int Run(std::vector<Tensor *> &inputs, std::vector<Tensor *> &outputs, std::vector<kernel::LiteKernel *> &kernels,
Allocator *allocator = nullptr, const KernelCallBack &before = nullptr,
const KernelCallBack &after = nullptr) override;
int RunOrTune(std::vector<Tensor *> &inputs, std::vector<Tensor *> &outputs,
std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator = nullptr,
int RunOrTune(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator = nullptr,
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr, bool is_tune = false);
protected:

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/gpu/opencl/opencl_runtime.h"
#include <vector>
#include <numeric>
#include <utility>
@ -23,7 +23,7 @@
#endif
#include "include/errorcode.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "src/runtime/opencl/opencl_allocator.h"
#include "src/runtime/gpu/opencl/opencl_allocator.h"
#include "src/common/file_utils.h"
#ifdef PROGRAM_WITH_IL
#include "src/backend/opencl/cl/program.inc"
@ -72,11 +72,12 @@ void printf_callback(const char *buffer, size_t length, size_t final, void *user
fwrite(buffer, 1, length, stdout);
}
int OpenCLRuntime::InitGPUDevice(std::vector<cl::Platform> &platforms) {
int OpenCLRuntime::InitGPUDevice(std::vector<cl::Platform> *platforms) {
MS_ASSERT(platforms);
// search GPU
std::vector<cl::Device> devices;
int ret = RET_OK;
for (auto &platform : platforms) {
for (auto &platform : *platforms) {
std::string platform_name;
ret = platform.getInfo(CL_PLATFORM_NAME, &platform_name);
if (ret != CL_SUCCESS) {
@ -173,7 +174,8 @@ int OpenCLRuntime::InitGPUDevice(std::vector<cl::Platform> &platforms) {
return RET_OK;
}
int OpenCLRuntime::InitQueue(std::vector<cl::Platform> &platforms) {
int OpenCLRuntime::InitQueue(std::vector<cl::Platform> *platforms) {
MS_ASSERT(platforms);
cl_int ret;
#if defined(SHARING_MEM_WITH_OPENGL) && (CL_HPP_TARGET_OPENCL_VERSION >= 120)
// create context from glcontext
@ -195,7 +197,7 @@ int OpenCLRuntime::InitQueue(std::vector<cl::Platform> &platforms) {
MS_LOG(INFO) << "Create common opencl context";
#ifdef Debug
std::vector<cl_context_properties> ctx_properties = {CL_CONTEXT_PLATFORM,
(cl_context_properties)platforms[0](),
(cl_context_properties)(*platforms)[0](),
CL_PRINTF_CALLBACK_ARM,
(cl_context_properties)printf_callback,
CL_PRINTF_BUFFERSIZE_ARM,
@ -258,12 +260,12 @@ int OpenCLRuntime::Init() {
MS_LOG(ERROR) << "OpenCL Platform not found!" << CLErrorCode(ret);
return RET_ERROR;
}
auto ms_ret = InitGPUDevice(platforms);
auto ms_ret = InitGPUDevice(&platforms);
if (ms_ret != RET_OK) {
return ms_ret;
}
ms_ret = InitQueue(platforms);
ms_ret = InitQueue(&platforms);
if (ms_ret != RET_OK) {
return ms_ret;
}
@ -362,8 +364,9 @@ bool OpenCLRuntime::SetFp16Enable(bool enable) {
return fp16_enable_ == enable;
}
int OpenCLRuntime::BuildKernel(cl::Kernel &kernel, const std::string &program_name, const std::string &kernel_name,
const std::vector<std::string> &build_options_ext, TypeId data_type) {
int OpenCLRuntime::BuildKernel(const cl::Kernel &kernel, const std::string &program_name,
const std::string &kernel_name, const std::vector<std::string> &build_options_ext,
TypeId data_type) {
std::string build_option = default_build_option_;
if (fp16_enable_ && data_type != kNumberTypeInt32) {
build_option +=
@ -399,7 +402,7 @@ int OpenCLRuntime::BuildKernel(cl::Kernel &kernel, const std::string &program_na
}
cl_int ret;
kernel = cl::Kernel(program, kernel_name.c_str(), &ret);
const_cast<cl::Kernel &>(kernel) = cl::Kernel(program, kernel_name.c_str(), &ret);
if (ret != CL_SUCCESS) {
MS_LOG(ERROR) << kernel_name << " Kernel create failed:" << CLErrorCode(ret);
return RET_ERROR;

View File

@ -27,8 +27,8 @@ j* you may not use this file except in compliance with the License.
#include <type_traits>
#include "dtype/type_id.h"
#include "src/common/log_adapter.h"
#include "src/runtime/opencl/opencl_wrapper.h"
#include "src/runtime/opencl/opencl_allocator.h"
#include "src/runtime/gpu/opencl/opencl_wrapper.h"
#include "src/runtime/gpu/opencl/opencl_allocator.h"
#include "schema/gpu_cache_generated.h"
namespace mindspore::lite::opencl {
@ -76,8 +76,8 @@ class OpenCLRuntime {
cl_device_svm_capabilities GetSVMCapabilities() const { return svm_enable_ ? svm_capabilities_ : 0; }
template <typename T>
typename std::enable_if<std::is_pointer<T>::value, cl_int>::type SetKernelArg(cl::Kernel &kernel, uint32_t index,
const T value,
typename std::enable_if<std::is_pointer<T>::value, cl_int>::type SetKernelArg(const cl::Kernel &kernel,
uint32_t index, const T value,
const MemType mem_type = MemType::IMG) {
switch (mem_type) {
case MemType::BUF: {
@ -88,7 +88,7 @@ class OpenCLRuntime {
}
cl::Buffer *buffer = reinterpret_cast<cl::Buffer *>(allocator_->GetBuffer(value));
MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Buffer " << buffer << ", host_ptr: " << value;
return kernel.setArg(index, *buffer);
return const_cast<cl::Kernel &>(kernel).setArg(index, *buffer);
}
case MemType::IMG: {
cl::Image2D *image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(value));
@ -96,10 +96,10 @@ class OpenCLRuntime {
MS_LOG(WARNING) << "Can't get Image2D, try to use Buffer. Please confirm the buffer type.";
cl::Buffer *buffer = reinterpret_cast<cl::Buffer *>(allocator_->GetBuffer(value));
MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Buffer " << buffer << ", host_ptr: " << value;
return kernel.setArg(index, *buffer);
return const_cast<cl::Kernel &>(kernel).setArg(index, *buffer);
}
MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Image2D " << image << ", host_ptr: " << value;
return kernel.setArg(index, *image);
return const_cast<cl::Kernel &>(kernel).setArg(index, *image);
}
default:
MS_LOG(ERROR) << "Unsupported opencl memory type: " << static_cast<int>(mem_type);
@ -109,8 +109,8 @@ class OpenCLRuntime {
template <typename T>
typename std::enable_if<!std::is_pointer<T>::value, cl_int>::type SetKernelArg(
cl::Kernel &kernel, uint32_t index, const T value, const MemType mem_type = MemType::IMG) {
return kernel.setArg(index, value);
const cl::Kernel &kernel, uint32_t index, const T value, const MemType mem_type = MemType::IMG) {
return const_cast<cl::Kernel &>(kernel).setArg(index, value);
}
cl::Program CreateProgramFromIL(const std::vector<char> &binary, const std::string &flag);
@ -118,7 +118,7 @@ class OpenCLRuntime {
cl::Kernel GetKernelFromBinary(const std::string &kernel_name);
std::vector<unsigned char> GetProgramBinary(const cl::Program &program);
bool LoadSource(const std::string &program_name, const std::string &source);
int BuildKernel(cl::Kernel &kernel, const std::string &program_name, const std::string &kernel_name,
int BuildKernel(const cl::Kernel &kernel, const std::string &program_name, const std::string &kernel_name,
const std::vector<std::string> &build_options_ext = {}, TypeId data_type = kNumberTypeFloat32);
int RunKernel(const cl::Kernel &kernel, const cl::NDRange &global, const cl::NDRange &local,
cl::CommandQueue *command_queue = nullptr, cl::Event *event = nullptr);
@ -160,8 +160,8 @@ class OpenCLRuntime {
bool LoadProgram(const std::string &program_name, cl::Program *program);
bool BuildProgram(const std::string &build_options, const cl::Program &program);
int InitGPUDevice(std::vector<cl::Platform> &platforms);
int InitQueue(std::vector<cl::Platform> &platforms);
int InitGPUDevice(std::vector<cl::Platform> *platforms);
int InitQueue(std::vector<cl::Platform> *platforms);
private:
static InitState init_state_;

View File

@ -16,7 +16,7 @@
#ifdef USE_OPENCL_WRAPPER
#include "src/runtime/opencl/opencl_wrapper.h"
#include "src/runtime/gpu/opencl/opencl_wrapper.h"
#include <dlfcn.h>
#include <memory>
#include <string>

View File

@ -29,7 +29,7 @@ namespace mindspore::lite::opencl {
bool LoadOpenCLLibrary(void **handle_ptr);
bool UnLoadOpenCLLibrary(void *handle);
// get platfrom id
// get platform id
using clGetPlatformIDsFunc = cl_int (*)(cl_uint, cl_platform_id *, cl_uint *);
// get platform info
using clGetPlatformInfoFunc = cl_int (*)(cl_platform_id, cl_platform_info, size_t, void *, size_t *);
@ -74,8 +74,7 @@ using clEnqueueMapBufferFunc = void *(*)(cl_command_queue, cl_mem, cl_bool, cl_m
const cl_event *, cl_event *, cl_int *);
using clEnqueueMapImageFunc = void *(*)(cl_command_queue, cl_mem, cl_bool, cl_map_flags, const size_t *, const size_t *,
size_t *, size_t *, cl_uint, const cl_event *, cl_event *, cl_int *);
using clCreateCommandQueueFunc = cl_command_queue(CL_API_CALL *)(cl_context, cl_device_id, cl_command_queue_properties,
cl_int *);
using clCreateCommandQueueFunc = cl_command_queue (*)(cl_context, cl_device_id, cl_command_queue_properties, cl_int *);
using clGetCommandQueueInfoFunc = cl_int (*)(cl_command_queue, cl_command_queue_info, size_t, void *, size_t *);
using clReleaseCommandQueueFunc = cl_int (*)(cl_command_queue);
using clCreateProgramWithBinaryFunc = cl_program (*)(cl_context, cl_uint, const cl_device_id *, const size_t *,
@ -89,10 +88,10 @@ using clGetProgramInfoFunc = cl_int (*)(cl_program, cl_program_info, size_t, voi
using clCreateKernelFunc = cl_kernel (*)(cl_program, const char *, cl_int *);
using clRetainKernelFunc = cl_int (*)(cl_kernel kernel);
using clCreateBufferFunc = cl_mem (*)(cl_context, cl_mem_flags, size_t, void *, cl_int *);
using clCreateImage2DFunc = cl_mem(CL_API_CALL *)(cl_context, cl_mem_flags, const cl_image_format *, size_t, size_t,
size_t, void *, cl_int *);
using clCreateImage3DFunc = cl_mem(CL_API_CALL *)(cl_context, cl_mem_flags, const cl_image_format *, size_t, size_t,
size_t, size_t, size_t, void *, cl_int *);
using clCreateImage2DFunc = cl_mem (*)(cl_context, cl_mem_flags, const cl_image_format *, size_t, size_t, size_t,
void *, cl_int *);
using clCreateImage3DFunc = cl_mem (*)(cl_context, cl_mem_flags, const cl_image_format *, size_t, size_t, size_t,
size_t, size_t, void *, cl_int *);
using clCreateProgramWithSourceFunc = cl_program (*)(cl_context, cl_uint, const char **, const size_t *, cl_int *);
using clReleaseKernelFunc = cl_int (*)(cl_kernel kernel);
using clGetDeviceInfoFunc = cl_int (*)(cl_device_id, cl_device_info, size_t, void *, size_t *);
@ -105,11 +104,10 @@ using clGetEventInfoFunc = cl_int (*)(cl_event event, cl_event_info param_name,
using clGetEventProfilingInfoFunc = cl_int (*)(cl_event event, cl_profiling_info param_name, size_t param_value_size,
void *param_value, size_t *param_value_size_ret);
using clGetImageInfoFunc = cl_int (*)(cl_mem, cl_image_info, size_t, void *, size_t *);
using clEnqueueCopyBufferToImageFunc = cl_int(CL_API_CALL *)(cl_command_queue, cl_mem, cl_mem, size_t, const size_t *,
const size_t *, cl_uint, const cl_event *, cl_event *);
using clEnqueueCopyImageToBufferFunc = cl_int(CL_API_CALL *)(cl_command_queue, cl_mem, cl_mem, const size_t *,
const size_t *, size_t, cl_uint, const cl_event *,
cl_event *);
using clEnqueueCopyBufferToImageFunc = cl_int (*)(cl_command_queue, cl_mem, cl_mem, size_t, const size_t *,
const size_t *, cl_uint, const cl_event *, cl_event *);
using clEnqueueCopyImageToBufferFunc = cl_int (*)(cl_command_queue, cl_mem, cl_mem, const size_t *, const size_t *,
size_t, cl_uint, const cl_event *, cl_event *);
#if CL_TARGET_OPENCL_VERSION >= 120
using clRetainDeviceFunc = cl_int (*)(cl_device_id);
using clReleaseDeviceFunc = cl_int (*)(cl_device_id);
@ -127,11 +125,11 @@ using clEnqueueSVMMapFunc = cl_int (*)(cl_command_queue, cl_bool, cl_map_flags,
using clEnqueueSVMUnmapFunc = cl_int (*)(cl_command_queue, void *, cl_uint, const cl_event *, cl_event *);
using clSetKernelArgSVMPointerFunc = cl_int (*)(cl_kernel, cl_uint, const void *);
// opencl 2.0 can get sub group info and wave size.
using clGetKernelSubGroupInfoKHRFunc = cl_int(CL_API_CALL *)(cl_kernel, cl_device_id, cl_kernel_sub_group_info, size_t,
const void *, size_t, void *, size_t *);
using clCreateCommandQueueWithPropertiesFunc = cl_command_queue(CL_API_CALL *)(cl_context, cl_device_id,
const cl_queue_properties *, cl_int *);
using clGetExtensionFunctionAddressFunc = void *(CL_API_CALL *)(const char *);
using clGetKernelSubGroupInfoKHRFunc = cl_int (*)(cl_kernel, cl_device_id, cl_kernel_sub_group_info, size_t,
const void *, size_t, void *, size_t *);
using clCreateCommandQueueWithPropertiesFunc = cl_command_queue (*)(cl_context, cl_device_id,
const cl_queue_properties *, cl_int *);
using clGetExtensionFunctionAddressFunc = void *(*)(const char *);
#endif
#define CL_DECLARE_FUNC_PTR(func) extern func##Func func

View File

@ -0,0 +1,6 @@
file(GLOB_RECURSE CUDA_KERNEL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel/*.cc)
add_library(cuda_kernel_mid OBJECT ${CUDA_KERNEL_SRC})
add_dependencies(cuda_kernel_mid fbs_src)

View File

@ -0,0 +1,22 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/cuda/cuda_kernel.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {} // namespace mindspore::kernel

View File

@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_CUDA_KERNEL_H_
#define MINDSPORE_LITE_SRC_CUDA_KERNEL_H_
#include <vector>
#include <set>
#include <map>
#include <string>
#include "src/lite_kernel.h"
#include "include/errorcode.h"
#include "src/runtime/gpu/gpu_runtime.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_CUDA_KERNEL_H_

View File

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/cuda/cuda_subgraph.h"
#include <set>
#include "include/errorcode.h"
#include "src/common/utils.h"
namespace mindspore::kernel {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
} // namespace mindspore::kernel

View File

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUDA_KERNEL_CUDA_SUBGRAPH_KERNEL_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUDA_KERNEL_CUDA_SUBGRAPH_KERNEL_H_
#include <set>
#include <vector>
#include "src/sub_graph_kernel.h"
namespace mindspore::kernel {} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUDA_KERNEL_CUDA_SUBGRAPH_KERNEL_H_

View File

@ -1,4 +1,7 @@
file(GLOB_RECURSE OPENCL_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/kernel/*.cc)
add_library(opencl_kernel_mid OBJECT ${OPENCL_KERNEL_SRC})
add_dependencies(opencl_kernel_mid fbs_src)
if(${SUPPORT_GPU} STREQUAL opencl)
file(GLOB_RECURSE OPENCL_KERNEL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel/*.cc)
add_library(opencl_kernel_mid OBJECT ${OPENCL_KERNEL_SRC})
add_dependencies(opencl_kernel_mid fbs_src)
endif()

View File

@ -23,7 +23,7 @@
#include "src/runtime/kernel/opencl/kernel/conv2d.h"
#include "src/runtime/kernel/opencl/kernel/fusion_eltwise.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "src/runtime/opencl/opencl_executor.h"
#include "src/runtime/gpu/opencl/opencl_executor.h"
#include "include/errorcode.h"
#include "schema/ops_generated.h"
#include "src/common/utils.h"

View File

@ -24,7 +24,7 @@
#include <string>
#include "src/lite_kernel.h"
#include "include/errorcode.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/gpu/opencl/opencl_runtime.h"
#include "mindspore/lite/src/dequant.h"
#include "src/runtime/kernel/opencl/utils.h"

View File

@ -18,7 +18,7 @@
#include <set>
#include <map>
#include <string>
#include "src/runtime/opencl/opencl_executor.h"
#include "src/runtime/gpu/opencl/opencl_executor.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "include/errorcode.h"
#include "src/common/utils.h"

View File

@ -20,8 +20,8 @@
#include <set>
#include <vector>
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/opencl/opencl_allocator.h"
#include "src/runtime/opencl/opencl_executor.h"
#include "src/runtime/gpu/opencl/opencl_allocator.h"
#include "src/runtime/gpu/opencl/opencl_executor.h"
#include "src/sub_graph_kernel.h"
namespace mindspore::kernel {

View File

@ -1,11 +0,0 @@
set(OPENCL_RUNTIME_SRC
${CMAKE_CURRENT_SOURCE_DIR}/opencl_allocator.cc
${CMAKE_CURRENT_SOURCE_DIR}/opencl_executor.cc
${CMAKE_CURRENT_SOURCE_DIR}/opencl_allocator.h
${CMAKE_CURRENT_SOURCE_DIR}/opencl_kernel.h
${CMAKE_CURRENT_SOURCE_DIR}/opencl_runtime.cc
${CMAKE_CURRENT_SOURCE_DIR}/opencl_runtime.h
${CMAKE_CURRENT_SOURCE_DIR}/opencl_wrapper.cc
${CMAKE_CURRENT_SOURCE_DIR}/opencl_wrapper.h
)

View File

@ -28,9 +28,9 @@
#include "src/kernel_registry.h"
#include "src/sub_graph_kernel.h"
#include "src/dequant.h"
#if SUPPORT_GPU
#if GPU_OPENCL
#include "src/runtime/kernel/opencl/opencl_subgraph.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/gpu/opencl/opencl_runtime.h"
#endif
#if SUPPORT_NPU
#include "src/runtime/agent/npu/subgraph_npu_kernel.h"
@ -462,7 +462,7 @@ kernel::SubGraphKernel *Scheduler::CreateSubGraphKernel(const std::vector<kernel
std::vector<kernel::LiteKernel *> input_kernels = kernel::LiteKernelUtil::SubgraphInputNodes(kernels);
std::vector<kernel::LiteKernel *> output_kernels = kernel::LiteKernelUtil::SubgraphOutputNodes(kernels);
if (type == kernel::kGpuSubGraph) {
#if SUPPORT_GPU
#if GPU_OPENCL
auto sub_kernel = new (std::nothrow)
kernel::OpenCLSubGraph(input_tensors, output_tensors, input_kernels, output_kernels, kernels, context_);
if (sub_kernel == nullptr) {
@ -470,6 +470,8 @@ kernel::SubGraphKernel *Scheduler::CreateSubGraphKernel(const std::vector<kernel
return nullptr;
}
return sub_kernel;
#elif GPU_VULKAN
return nullptr;
#else
return nullptr;
#endif

View File

@ -89,7 +89,7 @@ if("${X86_64_SIMD}" STREQUAL "avx")
endif()
### gpu kernel
if(SUPPORT_GPU)
if(SUPPORT_GPU STREQUAL opencl)
file(GLOB GPU_KERNEL_OP_SRC
${LITE_DIR}/src/runtime/kernel/opencl/kernel/*.cc
)
@ -102,6 +102,15 @@ if(SUPPORT_GPU)
${LITE_DIR}/src/runtime/kernel/opencl/utils.cc
)
endif()
if(SUPPORT_GPU STREQUAL vulkan)
file(GLOB GPU_KERNEL_OP_SRC
${LITE_DIR}/src/runtime/kernel/vulkan/kernel/*.cc
)
set(KERNEL_OP_SRC
${KERNEL_OP_SRC}
${GPU_KERNEL_OP_SRC}
)
endif()
if(PLATFORM_ARM32 OR PLATFORM_ARM64)
if(ENABLE_CONVERTER)
@ -150,20 +159,28 @@ set(TEST_LITE_SRC
${LITE_DIR}/src/errorcode.cc
)
### gpu runtime
if(SUPPORT_GPU)
include_directories(${TOP_DIR}/third_party/OpenCL-Headers)
include_directories(${TOP_DIR}/third_party/OpenCL-CLHPP/include)
set(OPENCL_RUNTIME_SRC
${LITE_DIR}/src/runtime/opencl/opencl_allocator.cc
${LITE_DIR}/src/runtime/opencl/opencl_executor.cc
${LITE_DIR}/src/runtime/opencl/opencl_runtime.cc
${LITE_DIR}/src/runtime/opencl/opencl_wrapper.cc
if(SUPPORT_GPU STREQUAL opencl)
include_directories(${CMAKE_BINARY_DIR}/_deps/opencl-headers-src)
include_directories(${CMAKE_BINARY_DIR}/_deps/opencl-clhpp-src/include)
file(GLOB_RECURSE OPENCL_RUNTIME_SRC
${LITE_DIR}/src/runtime/gpu/opencl/*.cc
)
set(TEST_LITE_SRC
${TEST_LITE_SRC}
${OPENCL_RUNTIME_SRC}
)
endif()
if(SUPPORT_GPU STREQUAL vulkan)
include_directories(${LITE_DIR}/build/_deps/vulkan-headers-src/include)
file(GLOB VULKAN_RUNTIME_SRC
${LITE_DIR}/src/runtime/gpu/*.cc
${LITE_DIR}/src/runtime/vulkan/*.cc
)
set(TEST_LITE_SRC
${TEST_LITE_SRC}
${VULKAN_RUNTIME_SRC}
)
endif()
### converter
if(ENABLE_CONVERTER)
add_definitions(-DPRIMITIVE_WRITEABLE)
@ -286,7 +303,7 @@ else()
)
endif()
if(SUPPORT_GPU)
if(SUPPORT_GPU STREQUAL opencl)
file(GLOB_RECURSE TEST_CASE_KERNEL_GPU_SRC
${TEST_DIR}/ut/src/runtime/kernel/opencl/*.cc
)

View File

@ -17,7 +17,6 @@
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/cast.h"

View File

@ -17,7 +17,6 @@
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/fill.h"
using mindspore::lite::Tensor;

View File

@ -18,7 +18,6 @@
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h"