!15323 sub graph split: support GPU

From: @ling_qiao_min
Reviewed-by: @chenzupeng,@zhanghaibo5,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong
mindspore-ci-bot 2021-04-17 18:09:54 +08:00 committed by Gitee
commit 858924a528
32 changed files with 243 additions and 138 deletions

View File

@@ -35,7 +35,7 @@ option(ENABLE_VERBOSE "" off)
option(ENABLE_SSE "if x86_64 support SSE instruction set" off)
option(ENABLE_AVX "if x86_64 support SSE instruction set" off)
option(ENABLE_MINDRT "if support mindrt" on)
option(SUBGRAPH_SPLIT "if support sub graph split" off)
option(SUBGRAPH_SPLIT "if support sub graph split" on)
set(DIR_PREFIX mindspore-lite)
set(MS_VERSION ${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION})

View File

@@ -133,7 +133,6 @@ set(LITE_SRC
${LITE_DIR}/src/common/tensor_util.cc
${LITE_DIR}/src/runtime/infer_manager.cc
${LITE_DIR}/src/lite_model.cc
${LITE_DIR}/src/sub_graph_split.cc
${LITE_DIR}/src/tensorlist.cc
${LITE_DIR}/src/tensor.cc
${LITE_DIR}/src/weight_decoder.cc

View File

@@ -33,15 +33,7 @@
#include "include/context.h"
namespace mindspore::kernel {
enum KERNEL_ARCH {
kCPU,
kGPU,
kAPU,
kNPU,
kALL, /* Support GPU NPU CPU */
kKernelArch_MIN = kCPU,
kKernelArch_MAX = kALL
};
enum KERNEL_ARCH { kCPU, kGPU, kAPU, kNPU, kKernelArch_MIN = kCPU, kKernelArch_MAX = kNPU };
struct KernelKey {
KERNEL_ARCH arch;

View File

@@ -54,19 +54,18 @@ void LiteOpActor::AsyncOutput(OpContext<Tensor> *context) {
return;
}
void LiteOpActor::AddResultIndex(size_t index) {
results_index_.push_back(index);
return;
}
void LiteOpActor::SetOutputData(OpContext<Tensor> *context) {
auto size = context->outputData_->size();
MS_ASSERT(size == context->results_->size());
for (size_t i = 0; i < size; i++) {
auto outputData = context->outputData_->at(i);
if (GetAID() == outputData->op_id_) {
outputData->data_ = kernel_->out_tensors()[outputData->index_];
context->SetResult(i, RET_OK);
}
for (auto index : results_index_) {
context->SetResult(index, RET_OK);
}
}
int MindrtInit() { return mindspore::Initialize("tcp://127.0.0.1:8080", "", "", "", 1); }
int MindrtInit() { return mindspore::Initialize("tcp://127.0.0.1:8080", "", "", "", 2); }
void MindrtTerminate(std::vector<std::shared_ptr<LiteOpActor>> actor_list) {
for (auto actor : actor_list) {
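
Taken together, the two hunks above change SetOutputData from scanning every entry of outputData_ and comparing actor IDs to simply marking pre-recorded result slots: the indices are registered once via AddResultIndex while the graph is wired up (see the MindrtExecutor::Prepare hunk below). A minimal self-contained sketch of the pattern; Actor and the plain int result vector are hypothetical stand-ins, not the mindrt API:

    #include <cstddef>
    #include <vector>

    constexpr int RET_OK = 0;  // stand-in for the lite return code

    class Actor {
     public:
      // Wiring time: remember which slots of the shared result vector are ours.
      void AddResultIndex(size_t index) { results_index_.push_back(index); }

      // Run time: mark only the owned slots; no scan over every output.
      void SetResults(std::vector<int> *results) const {
        for (size_t index : results_index_) (*results)[index] = RET_OK;
      }

     private:
      std::vector<size_t> results_index_;
    };

    int main() {
      std::vector<int> results(4, -1);
      Actor actor;
      actor.AddResultIndex(1);
      actor.AddResultIndex(3);
      actor.SetResults(&results);  // results: {-1, RET_OK, -1, RET_OK}
      return 0;
    }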

View File

@@ -42,6 +42,10 @@ class LiteOpActor : public OpActor<lite::Tensor> {
if (input_op_datas_[op_uuid].size() < kernel_->in_tensors().size()) {
return;
}
Context *ctx = const_cast<Context *>(kernel_->context());
if (kernel_->desc().arch == kernel::kCPU) {
BindThreads(static_cast<lite::InnerContext *>(ctx)->thread_pool_, true, 2);
}
auto ret = RunKernel(*(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_before_)),
*(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_after_)));
if (ret != RET_OK) {
@@ -51,6 +55,9 @@ class LiteOpActor : public OpActor<lite::Tensor> {
}
input_op_datas_.erase(op_uuid);
AsyncOutput(context);
if (kernel_->desc().arch == kernel::kCPU) {
BindThreads(static_cast<lite::InnerContext *>(ctx)->thread_pool_, true, 2);
}
SetOutputData(context);
}
void Init() {
@@ -82,11 +89,15 @@ class LiteOpActor : public OpActor<lite::Tensor> {
return ret;
}
public:
void AddResultIndex(size_t index);
private:
void SetOutputData(OpContext<Tensor> *context);
void AsyncOutput(OpContext<Tensor> *context);
kernel::LiteKernel *kernel_;
std::vector<size_t> results_index_;
};
int MindrtInit();

View File

@@ -51,6 +51,7 @@ int MindrtExecutor::Prepare(const std::vector<kernel::LiteKernel *> &kernels) {
for (size_t j = 0; j < outTensorSize; j++) {
auto data =
std::make_shared<OpData<Tensor>>(opActors_[i]->GetAID(), kernels[i]->out_tensors()[j], static_cast<int>(j));
opActors_[i]->AddResultIndex(outputData_.size());
outputData_.emplace_back(data);
}
}

View File

@@ -156,26 +156,26 @@ int OpenCLAllocator::GetImgDtypeSize(const ImageSize &img_size) {
void *OpenCLAllocator::_Malloc(MemType mem_type, void *data, size_t size, const ImageSize &img_size) {
auto svm_capabilities = ocl_runtime_->GetSVMCapabilities();
auto enable_arm_import_memory = ocl_runtime_->isExtensionEnable(EXT_ARM_IMPORT_MEMORY_HOST);
if (mem_type == MemType::SHARED && !enable_arm_import_memory) {
mem_type = MemType::BUF;
}
if (mem_type == MemType::IMG) {
size = GetImgDtypeSize(img_size);
}
if (size > ocl_runtime_->GetMaxAllocSize()) {
MS_LOG(ERROR) << "MallocData out of max_size, size: " << size;
return nullptr;
}
Lock();
void *host_ptr = MinimumFit(mem_type, size, img_size);
if (host_ptr != nullptr && data == nullptr) {
UnLock();
return host_ptr;
}
UNLOCK_AND_RETURN_NULL(host_ptr != nullptr && data == nullptr, host_ptr);
total_size_ += size;
const uint64_t max_size = ocl_runtime_->GetGlobalMemSize() * 0.8;
if (total_size_ >= max_size) {
UnLock();
MS_LOG(ERROR) << "Mem pool out of max_size, total size: " << total_size_ << ", max size: " << max_size;
return nullptr;
}
UNLOCK_AND_RETURN_NULL(total_size_ >= max_size, nullptr);
cl::Buffer *buffer = nullptr;
cl::Image2D *image = nullptr;
cl_mem_flags flags = CL_MEM_READ_WRITE;
@@ -184,27 +184,32 @@ void *OpenCLAllocator::_Malloc(MemType mem_type, void *data, size_t size, const
flags |= (svm_capabilities & CL_DEVICE_SVM_ATOMICS) ? CL_MEM_SVM_ATOMICS : 0;
host_ptr = clSVMAlloc((*ocl_runtime_->Context())(), flags, size, 0);
} else {
flags |= (data == nullptr) ? CL_MEM_ALLOC_HOST_PTR : CL_MEM_COPY_HOST_PTR;
if (mem_type == MemType::BUF || data == nullptr) {
host_ptr = CreateBuffer(size, data, flags, &buffer);
if (host_ptr == nullptr) {
UnLock();
return nullptr;
if (mem_type == MemType::SHARED) {
size = UP_ROUND(size, ocl_runtime_->GetCacheLineSize());
host_ptr = malloc(size);
UNLOCK_AND_RETURN_NULL(host_ptr == nullptr, nullptr);
buffer = ocl_runtime_->CreateSharedMemoryBuffer(size, host_ptr);
} else {
flags |= (data == nullptr) ? CL_MEM_ALLOC_HOST_PTR : CL_MEM_COPY_HOST_PTR;
if (mem_type == MemType::BUF || data == nullptr) {
host_ptr = CreateBuffer(size, data, flags, &buffer);
UNLOCK_AND_RETURN_NULL(host_ptr == nullptr, nullptr);
}
}
if (mem_type == MemType::IMG) {
void *host_ptr_im = CreateImage2D(size, img_size, data, flags, data != nullptr, &buffer, &image);
if (data != nullptr && host_ptr_im == nullptr) {
UnLock();
return nullptr;
if (mem_type == MemType::IMG) {
void *host_ptr_im = CreateImage2D(size, img_size, data, flags, data != nullptr, &buffer, &image);
UNLOCK_AND_RETURN_NULL(data != nullptr && host_ptr_im == nullptr, nullptr);
host_ptr = (data != nullptr) ? host_ptr_im : host_ptr;
}
host_ptr = (data != nullptr) ? host_ptr_im : host_ptr;
}
}
MemBuf *mem_buf = new (std::nothrow) MemBuf;
if (mem_buf == nullptr) {
delete buffer;
delete image;
if (mem_type == MemType::SHARED) {
free(host_ptr);
}
UnLock();
return nullptr;
}
@@ -216,7 +221,9 @@ void *OpenCLAllocator::_Malloc(MemType mem_type, void *data, size_t size, const
mem_buf->img_size_ = img_size;
allocated_list_[host_ptr] = mem_buf;
UnLock();
std::string type_name = mem_type == MemType::BUF ? "buffer" : "Image2D";
std::string type_name = (mem_type == MemType::BUF) ? "buffer" : "Image2D";
type_name = (mem_type == MemType::SHARED) ? "shared" : type_name;
MS_LOG(DEBUG) << "Malloc a new " << type_name << ". size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_
<< ", device addr: " << mem_buf->device_ptr_ << ", image_addr: " << image
<< ", total size: " << total_size_;
@@ -306,6 +313,10 @@ void OpenCLAllocator::ClearMemList(T *list) {
delete image;
it->second->image_ptr_ = nullptr;
}
if (it->second->mem_type_ == MemType::SHARED) {
free(it->second->host_ptr_);
it->second->host_ptr_ = nullptr;
}
}
delete it->second;
}
@@ -351,12 +362,18 @@ void *OpenCLAllocator::MapBuffer(void *host_ptr, int flags, void *command_queue,
}
MemBuf *mem_buf = it->second;
MS_ASSERT(mem_buf);
if (mem_buf->mem_type_ == MemType::SHARED) {
UnLock();
MS_LOG(WARNING) << "Host ptr " << host_ptr << " no need map";
return host_ptr;
}
void *new_host_ptr{nullptr};
if (mem_buf->mem_type_ == MemType::BUF) {
cl::Buffer *buffer = static_cast<cl::Buffer *>(mem_buf->device_ptr_);
MS_ASSERT(buffer);
new_host_ptr = ocl_runtime_->MapBuffer(*buffer, flags, mem_buf->size_, nullptr, sync);
} else {
} else if (mem_buf->mem_type_ == MemType::IMG) {
std::vector<size_t> region{mem_buf->img_size_.width, mem_buf->img_size_.height, 1};
cl::Image2D *image = static_cast<cl::Image2D *>(mem_buf->image_ptr_);
MS_ASSERT(image);
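
Pieced together from this file, a MemType::SHARED allocation rounds the size up to the device cache line, mallocs ordinary host memory, wraps it in a cl::Buffer through CreateSharedMemoryBuffer (clImportMemoryARM, two files below), treats MapBuffer as a no-op, and free()s the host pointer in ClearMemList. A condensed sketch of that flow, assuming <CL/cl.h> plus the cl_arm_import_memory extension and omitting the pool bookkeeping; this is not the allocator's exact code:

    void *MallocShared(cl_context context, size_t size, size_t cache_line) {
      size = (size + cache_line - 1) / cache_line * cache_line;  // UP_ROUND to the cache line
      void *host_ptr = malloc(size);
      if (host_ptr == nullptr) return nullptr;
      cl_int error = CL_SUCCESS;
      cl_mem buffer = clImportMemoryARM(context, CL_MEM_READ_WRITE, nullptr, host_ptr, size, &error);
      if (error != CL_SUCCESS) {
        free(host_ptr);
        return nullptr;
      }
      // Record {host_ptr, buffer} in the allocated list; Map/Unmap become
      // no-ops for this block, and releasing it frees both halves.
      return host_ptr;
    }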

View File

@@ -28,9 +28,17 @@
#include "CL/cl2.hpp"
namespace mindspore::lite::opencl {
#define UNLOCK_AND_RETURN_NULL(condition, ptr) \
do { \
if (condition) { \
UnLock(); \
return (ptr); \
} \
} while (0)
class OpenCLRuntime;
enum class MemType : char { BUF, IMG };
enum class MemType : char { BUF, IMG, SHARED };
struct ImageSize {
size_t width = 0;
size_t height = 0;
@@ -45,8 +53,11 @@ class OpenCLAllocator : public mindspore::Allocator {
explicit OpenCLAllocator(OpenCLRuntime *ocl_runtime);
~OpenCLAllocator() override;
void SetContext(const AllocatorContext &ctx) override;
void *Malloc(size_t size, MemType type) { return _Malloc(type, nullptr, size); }
// malloc shared
void *Malloc(size_t size) override { return _Malloc(MemType::SHARED, nullptr, size); }
// malloc buffer
void *Malloc(size_t size) override { return _Malloc(MemType::BUF, nullptr, size); }
void *Malloc(size_t size, void *data) { return _Malloc(MemType::BUF, data, size); }
// malloc image
void *Malloc(const ImageSize &img_size, void *data = nullptr) { return _Malloc(MemType::IMG, data, 0, img_size); }
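
UNLOCK_AND_RETURN_NULL exists to collapse the repeated unlock-then-bail sequences in _Malloc above: the condition is evaluated while the pool lock is held, and the lock is released before returning. A self-contained illustration of the same guard shape, with a hypothetical mutex-backed Lock/UnLock pair rather than the allocator's own:

    #include <cstddef>
    #include <cstdlib>
    #include <mutex>

    std::mutex g_pool_mutex;  // hypothetical stand-in for the allocator's lock
    void Lock() { g_pool_mutex.lock(); }
    void UnLock() { g_pool_mutex.unlock(); }

    #define UNLOCK_AND_RETURN_NULL(condition, ptr) \
      do {                                         \
        if (condition) {                           \
          UnLock();                                \
          return (ptr);                            \
        }                                          \
      } while (0)

    void *AllocChecked(size_t size, size_t max_size) {
      Lock();
      UNLOCK_AND_RETURN_NULL(size > max_size, nullptr);  // releases the lock before bailing
      void *ptr = malloc(size);
      UnLock();
      return ptr;
    }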

View File

@@ -167,6 +167,8 @@ int OpenCLRuntime::InitGPUDevice(std::vector<cl::Platform> *platforms) {
max_alloc_size_ = device_->getInfo<CL_DEVICE_MAX_MEM_ALLOC_SIZE>();
max_image2d_width_ = device_->getInfo<CL_DEVICE_IMAGE2D_MAX_WIDTH>();
max_image2d_height_ = device_->getInfo<CL_DEVICE_IMAGE2D_MAX_HEIGHT>();
supported_extensions_ = std::string(device_->getInfo<CL_DEVICE_EXTENSIONS>());
cache_line_size_ = device_->getInfo<CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE>();
MS_LOG(INFO) << "Address space bits: " << device_->getInfo<CL_DEVICE_ADDRESS_BITS>();
MS_LOG(INFO) << "Global Mem Size: " << global_memery_size_;
MS_LOG(INFO) << "Global Mem Cache Size: " << global_memery_cachesize_;
@@ -757,4 +759,19 @@ void OpenCLRuntime::StoreCache() {
MS_LOG(INFO) << "store opencl cache ok, size=" << fbb->GetSize();
}
cl::Buffer *OpenCLRuntime::CreateSharedMemoryBuffer(size_t size, void *host_ptr) {
cl_int error = CL_SUCCESS;
cl_mem cl_buffer = clImportMemoryARM(context_->get(), CL_MEM_READ_WRITE, NULL, host_ptr, size, &error);
if (error != CL_SUCCESS) {
MS_LOG(ERROR) << "Create OpenCL shared memory failed for" << CLErrorCode(error);
return nullptr;
}
cl::Buffer *buffer = new (std::nothrow) cl::Buffer(cl_buffer, false);
if (buffer == nullptr) {
MS_LOG(ERROR) << "New OpenCL Buffer failed";
return nullptr;
}
return buffer;
}
} // namespace mindspore::lite::opencl

View File

@@ -28,6 +28,7 @@ * you may not use this file except in compliance with the License.
#include "src/runtime/gpu/opencl/opencl_wrapper.h"
#include "src/runtime/gpu/opencl/opencl_allocator.h"
#include "schema/gpu_cache_generated.h"
#define EXT_ARM_IMPORT_MEMORY_HOST "cl_arm_import_memory_host"
namespace mindspore::lite::opencl {
@@ -151,6 +152,9 @@ class OpenCLRuntime {
bool isProfiling() const { return profiling_; }
void SetProfiling(bool profiling) { profiling_ = profiling; }
bool isExtensionEnable(std::string ext) { return supported_extensions_.find(ext) != std::string::npos; }
cl::Buffer *CreateSharedMemoryBuffer(size_t size, void *host_ptr);
uint GetCacheLineSize() const { return cache_line_size_; }
private:
static OpenCLRuntime *GetInstance();
@@ -196,6 +200,8 @@ class OpenCLRuntime {
bool profiling_{true};
#else
bool profiling_{false};
std::string supported_extensions_{""};
uint cache_line_size_{1};
#endif
// for cache
private:

View File

@@ -102,6 +102,7 @@ bool LoadLibraryFromPath(const std::string &library_path, void **handle_ptr) {
LOAD_OPENCL_FUNCTION_PTR(clCreateProgramWithSource);
LOAD_OPENCL_FUNCTION_PTR(clCreateBuffer);
LOAD_OPENCL_FUNCTION_PTR(clCreateImage2D);
LOAD_OPENCL_FUNCTION_PTR(clImportMemoryARM);
LOAD_OPENCL_FUNCTION_PTR(clCreateImage3D);
LOAD_OPENCL_FUNCTION_PTR(clRetainKernel);
LOAD_OPENCL_FUNCTION_PTR(clCreateKernel);
@@ -192,6 +193,7 @@ CL_DEFINE_FUNC_PTR(clReleaseKernel);
CL_DEFINE_FUNC_PTR(clCreateProgramWithSource);
CL_DEFINE_FUNC_PTR(clCreateBuffer);
CL_DEFINE_FUNC_PTR(clCreateImage2D);
CL_DEFINE_FUNC_PTR(clImportMemoryARM);
CL_DEFINE_FUNC_PTR(clCreateImage3D);
CL_DEFINE_FUNC_PTR(clRetainKernel);
CL_DEFINE_FUNC_PTR(clCreateKernel);

View File

@@ -90,6 +90,7 @@ using clRetainKernelFunc = cl_int (*)(cl_kernel kernel);
using clCreateBufferFunc = cl_mem (*)(cl_context, cl_mem_flags, size_t, void *, cl_int *);
using clCreateImage2DFunc = cl_mem (*)(cl_context, cl_mem_flags, const cl_image_format *, size_t, size_t, size_t,
void *, cl_int *);
using clImportMemoryARMFunc = cl_mem (*)(cl_context, cl_mem_flags, const cl_image_format *, void *, ssize_t, cl_int *);
using clCreateImage3DFunc = cl_mem (*)(cl_context, cl_mem_flags, const cl_image_format *, size_t, size_t, size_t,
size_t, size_t, void *, cl_int *);
using clCreateProgramWithSourceFunc = cl_program (*)(cl_context, cl_uint, const char **, const size_t *, cl_int *);
@@ -143,6 +144,7 @@ CL_DECLARE_FUNC_PTR(clReleaseKernel);
CL_DECLARE_FUNC_PTR(clCreateProgramWithSource);
CL_DECLARE_FUNC_PTR(clCreateBuffer);
CL_DECLARE_FUNC_PTR(clCreateImage2D);
CL_DECLARE_FUNC_PTR(clImportMemoryARM);
CL_DECLARE_FUNC_PTR(clCreateImage3D);
CL_DECLARE_FUNC_PTR(clRetainKernel);
CL_DECLARE_FUNC_PTR(clCreateKernel);
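
For reference, the prototype published with the Khronos cl_arm_import_memory extension is shown below; the clImportMemoryARMFunc typedef above substitutes const cl_image_format * for const cl_import_properties_arm * and ssize_t for size_t, which happens to be layout-compatible on the ARM targets this path serves, but is worth knowing when reading the wrapper:

    // Canonical signature from the cl_arm_import_memory extension specification.
    cl_mem clImportMemoryARM(cl_context context,
                             cl_mem_flags flags,
                             const cl_import_properties_arm *properties,
                             void *memory,
                             size_t size,
                             cl_int *errcode_ret);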

View File

@@ -70,6 +70,7 @@ void GroupConvolutionBaseCPUKernel::FreeSubKernel() {
delete sub_conv;
sub_conv = nullptr;
}
group_convs_.clear();
}
int GroupConvolutionBaseCPUKernel::PreProcess() {
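
The added group_convs_.clear() presumably guards against FreeSubKernel being reached twice (once on an error path, again during destruction): after the loop the vector still holds dangling pointers, and a second pass would double-delete them. The next file applies the same one-line fix. A minimal illustration of why the clear() matters:

    #include <vector>

    struct SubKernel { int id; };

    void FreeAll(std::vector<SubKernel *> *kernels) {
      for (auto *kernel : *kernels) delete kernel;
      kernels->clear();  // drop the dangling pointers so a repeat call is a no-op
    }

    int main() {
      std::vector<SubKernel *> kernels{new SubKernel{0}, new SubKernel{1}};
      FreeAll(&kernels);
      FreeAll(&kernels);  // safe only because of the clear()
      return 0;
    }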

View File

@@ -113,6 +113,7 @@ void GroupConvCreator::FreeGroupConvs() {
}
delete sub_conv;
}
group_convs_.clear();
}
int GroupConvCreator::NewInputTensor(std::vector<lite::Tensor *> *tensors) {

View File

@@ -133,8 +133,8 @@ void ArgMinMaxOpenCLKernel::SetGlobalLocal() {
int ArgMinMaxOpenCLKernel::InitWeights() {
auto allocator = ocl_runtime_->GetAllocator();
int dtype_size = ocl_runtime_->GetFp16Enable() ? sizeof(int16_t) : sizeof(float);
buff_ = allocator->Malloc(in_tensors_[0]->ElementsNum() * dtype_size);
ids_ = allocator->Malloc(in_tensors_[0]->ElementsNum() * sizeof(int32_t));
buff_ = allocator->Malloc(in_tensors_[0]->ElementsNum() * dtype_size, lite::opencl::MemType::BUF);
ids_ = allocator->Malloc(in_tensors_[0]->ElementsNum() * sizeof(int32_t), lite::opencl::MemType::BUF);
return RET_OK;
}
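
This and the remaining kernel hunks are the same mechanical change: the single-argument Malloc override in opencl_allocator.h now hands out MemType::SHARED memory, so every internal weight or scratch allocation has to request a plain device buffer explicitly. A toy model of the two entry points; the behavior is faked with plain malloc, and only the shape of the API mirrors the header:

    #include <cstddef>
    #include <cstdlib>

    enum class MemType : char { BUF, IMG, SHARED };

    struct ToyAllocator {
      void *Malloc(size_t size, MemType type) { (void)type; return std::malloc(size); }
      // The overload reached through the base Allocator* now means
      // "shared / zero-copy", hence the explicit BUF at the call sites.
      void *Malloc(size_t size) { return Malloc(size, MemType::SHARED); }
    };

    int main() {
      ToyAllocator allocator;
      void *weights = allocator.Malloc(256, MemType::BUF);  // device-local buffer
      void *io = allocator.Malloc(256);                     // shared with the host
      std::free(weights);
      std::free(io);
      return 0;
    }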

View File

@@ -90,10 +90,10 @@ int BatchNormOpenCLKernel::Initweight() {
auto weight_tensor = in_tensors_.at(1);
size_t weight_size = img_info.OriginSize;
// allocated memory for weight and init value
scale_ = allocator->Malloc(weight_size);
offset_ = allocator->Malloc(weight_size);
mean_ = allocator->Malloc(weight_size);
variance_ = allocator->Malloc(weight_size);
scale_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
offset_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
mean_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
variance_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
allocator->MapBuffer(scale_, CL_MAP_WRITE, nullptr, true);
allocator->MapBuffer(offset_, CL_MAP_WRITE, nullptr, true);

View File

@@ -254,7 +254,7 @@ void Conv2DOpenCLKernel::InitFilter() {
packed_filter_ = allocator->Malloc({width, height, dtype});
} else {
size = UP_DIV(CO_SLICES_, Ogroup) * KH_ * KW_ * CI_SLICES_ * Ogroup * CI_TILE * CO_TILE * sizeof_FLT_;
packed_filter_ = allocator->Malloc(size);
packed_filter_ = allocator->Malloc(size, lite::opencl::MemType::BUF);
}
// rearrange filter
@@ -287,7 +287,7 @@ void Conv2DOpenCLKernel::InitBias() {
// align bias from C to C4
auto bias_tensor = in_tensors_.at(2);
size_t packed_bias_size = UP_ROUND(CO_SLICES_, block_size_.C) * CO_TILE * sizeof_FLT_;
packed_bias_ = allocator->Malloc(packed_bias_size);
packed_bias_ = allocator->Malloc(packed_bias_size, lite::opencl::MemType::BUF);
allocator->MapBuffer(packed_bias_, CL_MAP_WRITE, nullptr, true);
memset(packed_bias_, 0x00, packed_bias_size);

View File

@@ -144,7 +144,7 @@ int Conv2dTransposeOpenCLKernel::InitFilter() {
// IHWO to OHWI4(I)4(O)(converter format is IHWO)
// init padWeight_(buffer mem)
padWeight_ = allocator->Malloc(div_ci * div_co * C4NUM * C4NUM * kh * kw * data_size);
padWeight_ = allocator->Malloc(div_ci * div_co * C4NUM * C4NUM * kh * kw * data_size, lite::opencl::MemType::BUF);
padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true);
memset(padWeight_, 0x00, div_ci * div_co * C4NUM * C4NUM * kh * kw * data_size);
auto origin_weight = in_tensors_.at(kWeightIndex)->data_c();

View File

@@ -133,7 +133,8 @@ int FullConnectionOpenCLKernel::InitFilter() {
int co4 = UP_DIV(CO_, C4NUM);
int nhw_remainder = intensor_shape.N * intensor_shape.H * intensor_shape.W / N_;
size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
padWeight_ = allocator->Malloc(nhw_remainder * intensor_shape.Slice * co4 * C4NUM * C4NUM * dtype_size);
padWeight_ = allocator->Malloc(nhw_remainder * intensor_shape.Slice * co4 * C4NUM * C4NUM * dtype_size,
lite::opencl::MemType::BUF);
padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true);
auto padWeightFp32 = reinterpret_cast<float *>(padWeight_);
auto padWeightFp16 = reinterpret_cast<float16_t *>(padWeight_);

View File

@@ -183,7 +183,7 @@ int FusionEltwiseOpenCLKernel::InitWeights() {
auto tensor_info = GpuTensorInfo(tensor);
size_t num = tensor_info.ElementsNum;
size_t size = tensor_info.Image2DSize;
void *buffer = allocator->Malloc(size);
void *buffer = allocator->Malloc(size, lite::opencl::MemType::BUF);
allocator->MapBuffer(buffer, CL_MAP_WRITE, nullptr, true);
memset(buffer, 0x00, size);
if (tensor->data_type() == kNumberTypeFloat16) {

View File

@@ -129,7 +129,8 @@ int GatherOpenCLKernel::ConvertTensorToweight() {
auto allocator = ocl_runtime_->GetAllocator();
auto indices_tensor = in_tensors_.at(1);
auto indices_num = indices_tensor->ElementsNum();
indices_data_ = reinterpret_cast<int32_t *>(allocator->Malloc(sizeof(int32_t) * indices_num));
indices_data_ =
reinterpret_cast<int32_t *>(allocator->Malloc(sizeof(int32_t) * indices_num, lite::opencl::MemType::BUF));
allocator->MapBuffer(indices_data_, CL_MAP_WRITE, nullptr, true);
if (indices_data_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
@@ -154,7 +155,8 @@ int GatherOpenCLKernel::InitWeights() {
auto indices_tensor = in_tensors_.at(1);
auto indices_num = indices_tensor->ElementsNum();
auto allocator = ocl_runtime_->GetAllocator();
indices_data_ = reinterpret_cast<int32_t *>(allocator->Malloc(sizeof(int32_t) * indices_num));
indices_data_ =
reinterpret_cast<int32_t *>(allocator->Malloc(sizeof(int32_t) * indices_num, lite::opencl::MemType::BUF));
if (indices_data_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
return RET_ERROR;

View File

@@ -106,8 +106,8 @@ int LayerNormOpenCLKernel::Initweight() {
auto weight_tensor = in_tensors_.at(1);
size_t weight_size = img_info.Image2DSize;
// allocated memory for weight and init value
gamma_ = allocator->Malloc(weight_size);
beta_ = allocator->Malloc(weight_size);
gamma_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
beta_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
allocator->MapBuffer(gamma_, CL_MAP_WRITE, nullptr, true);
allocator->MapBuffer(beta_, CL_MAP_WRITE, nullptr, true);
memset(gamma_, 0x01, weight_size);
@@ -164,8 +164,8 @@ int LayerNormOpenCLKernel::Prepare() {
}
size_t size_dtype = use_fp16_enable_ ? sizeof(float16_t) : sizeof(float);
mean_size *= size_dtype;
mean_ = allocator->Malloc(mean_size);
var_ = allocator->Malloc(mean_size);
mean_ = allocator->Malloc(mean_size, lite::opencl::MemType::BUF);
var_ = allocator->Malloc(mean_size, lite::opencl::MemType::BUF);
std::string kernel_name = "LayerNormalization_NHWC4";
std::string kernel_name_mean_var = "ComputeMeanVar";
std::string source = layer_norm_source;

View File

@@ -130,7 +130,7 @@ int MatMulOpenCLKernel::InitWeights() {
int b = weight_shape_4d[1];
size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
padWeight_ = allocator->Malloc(a * b * ci4 * co4 * C4NUM * C4NUM * dtype_size);
padWeight_ = allocator->Malloc(a * b * ci4 * co4 * C4NUM * C4NUM * dtype_size, lite::opencl::MemType::BUF);
padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true);
auto padWeightFp32 = reinterpret_cast<float *>(padWeight_);
auto padWeightFp16 = reinterpret_cast<float16_t *>(padWeight_);

View File

@@ -46,7 +46,7 @@ int PReluOpenCLKernel::InitWeights() {
int C_ = weight_tensor->ElementsNum();
auto sizeof_FLT = enable_fp16_ ? sizeof(float16_t) : sizeof(float);
size_t weight_size = UP_ROUND(C_, C4NUM) * sizeof_FLT;
weight_vector_ = allocator->Malloc(weight_size);
weight_vector_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
allocator->MapBuffer(weight_vector_, CL_MAP_WRITE, nullptr, true);
memset(weight_vector_, 0x00, weight_size);
if (weight_tensor->data_type() == kNumberTypeFloat16) {

View File

@@ -62,7 +62,7 @@ int SparseToDenseOpenCLKernel::InitWeights() {
} else {
auto sizeof_FLT = enable_fp16_ ? sizeof(float16_t) : sizeof(float);
size_t weight_size = UP_ROUND(size, C4NUM) * sizeof_FLT;
weight_vector_ = allocator->Malloc(weight_size);
weight_vector_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
allocator->MapBuffer(weight_vector_, CL_MAP_WRITE, nullptr, true);
memset(weight_vector_, 0x00, weight_size);
if (weight_tensor->data_type() == kNumberTypeFloat16) {

View File

@@ -94,13 +94,13 @@ void SplitOpenCLKernel::AlignSplitSizes(SplitParameter *param, const std::vector
int shape_dim = in_shape.at(param->split_dim_);
if (num_split_ == 1) {
size_t num_split = UP_DIV(shape_dim, param->split_sizes_[0]);
split_sizes_ = reinterpret_cast<int *>(allocator->Malloc(num_split * sizeof(int)));
split_sizes_ = reinterpret_cast<int *>(allocator->Malloc(num_split * sizeof(int), lite::opencl::MemType::BUF));
for (int i = 0; i < num_split - 1; ++i) {
split_sizes_[i] = (i + 1) * param->split_sizes_[0];
}
} else {
int sum = 0;
split_sizes_ = reinterpret_cast<int *>(allocator->Malloc(num_split_ * sizeof(int)));
split_sizes_ = reinterpret_cast<int *>(allocator->Malloc(num_split_ * sizeof(int), lite::opencl::MemType::BUF));
for (int i = 0; i < num_split_ - 1; ++i) {
sum += param->split_sizes_[i];
split_sizes_[i] = sum;

View File

@@ -55,7 +55,7 @@ void StrassenOpenCLKernel::AllocatorMemoryForStrassen(int NumA, int NumB) {
size_t dtype_size = enable_fp16_ ? sizeof(cl_half) : sizeof(cl_float);
size_t memB = NumB * NumB * dtype_size;
for (int depth = 0; depth < MAXDEPTH; depth++) {
B_temp[depth] = allocator->Malloc(memB);
B_temp[depth] = allocator->Malloc(memB, lite::opencl::MemType::BUF);
A_temp[depth] = allocator->Malloc(img_size);
M1[depth] = allocator->Malloc(img_size);
M2[depth] = allocator->Malloc(img_size);
@@ -73,7 +73,7 @@ int StrassenOpenCLKernel::InitWeights() {
int NumA = in_tensors_[0]->shape()[0];
int NumB = in_tensors_[1]->shape()[0];
size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
padWeight_ = allocator->Malloc(NumA * NumB * dtype_size);
padWeight_ = allocator->Malloc(NumA * NumB * dtype_size, lite::opencl::MemType::BUF);
padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true);
auto padWeightFp32 = reinterpret_cast<float *>(padWeight_);
auto padWeightFp16 = reinterpret_cast<float16_t *>(padWeight_);

View File

@ -102,7 +102,7 @@ void WinogradOpenCLKernel::InitFilter() {
packed_filter_ = allocator->Malloc({width, height, dtype});
} else {
size = UP_DIV(CO_SLICES_, Ogroup) * 6 * 6 * CI_SLICES_ * Ogroup * CI_TILE * CO_TILE * sizeof_FLT_;
packed_filter_ = allocator->Malloc(size);
packed_filter_ = allocator->Malloc(size, MemType::BUF);
}
// rearrange filter

View File

@@ -74,7 +74,7 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_);
#ifdef SUBGRAPH_SPLIT
auto search_sub_graph = SearchSubGraph(src_model_, this->graph_output_node_indexes_);
auto search_sub_graph = SearchSubGraph(context_, src_model_, this->graph_output_node_indexes_);
search_sub_graph.SubGraphSplitByOutput();
#endif
@@ -357,6 +357,7 @@ kernel::LiteKernel *Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_ten
const std::vector<Tensor *> &out_tensors, OpParameter *op_parameter,
const kernel::KernelKey &desc) {
MS_ASSERT(op_parameter != nullptr);
if (context_->IsGpuEnabled()) {
// support more data type like int32
kernel::KernelKey gpu_desc{kGPU, kNumberTypeFloat32, desc.type};
@@ -433,6 +434,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(op_parameter->type_)};
kernel::LiteKernel *kernel = nullptr;
#ifdef SUPPORT_GPU
// if (node->device_type_ == DT_GPU || node->device_type_ == DEFAULT) {
kernel = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc);
if (kernel != nullptr) {
return kernel;
@@ -447,8 +449,10 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
return nullptr;
}
}
// }
#endif
#ifdef SUPPORT_NPU
// if (node->device_type_ == DT_NPU || node->device_type_ == DEFAULT) {
kernel = FindNpuKernel(in_tensors, out_tensors, op_parameter, desc);
if (kernel != nullptr) {
return kernel;
@@ -463,6 +467,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
return nullptr;
}
}
// }
#endif
if (prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kTypeUnknown) {
kernel = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat16);
@@ -617,34 +622,37 @@ bool Scheduler::KernelFitCurrentSubGraph(const kernel::SubGraphType subgraph_typ
}
std::vector<kernel::LiteKernel *> Scheduler::FindAllSubGraphKernels(
kernel::LiteKernel *head_kernel, std::map<const kernel::LiteKernel *, bool> *sinked_kernel_map) {
MS_ASSERT(head_kernel != nullptr);
MS_ASSERT(sinked_kernel_map != nullptr);
std::vector<kernel::LiteKernel *> head_kernels, std::map<const kernel::LiteKernel *, bool> *sinked_kernel_map) {
std::vector<kernel::LiteKernel *> sub_kernels;
if (head_kernel->Type() == schema::PrimitiveType_Switch || head_kernel->Type() == schema::PrimitiveType_Merge) {
(*sinked_kernel_map)[head_kernel] = true;
sub_kernels.emplace_back(head_kernel);
return sub_kernels;
}
std::queue<kernel::LiteKernel *> kernel_queue;
kernel_queue.emplace(head_kernel);
auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel);
while (!kernel_queue.empty()) {
auto cur_kernel = kernel_queue.front();
kernel_queue.pop();
(*sinked_kernel_map)[cur_kernel] = true;
sub_kernels.emplace_back(cur_kernel);
auto post_kernels = cur_kernel->out_kernels();
for (auto post_kernel : post_kernels) {
if (post_kernel->subgraph_type() != kernel::kNotSubGraph || post_kernel->Type() == schema::PrimitiveType_Merge ||
post_kernel->Type() == schema::PrimitiveType_Switch) {
continue;
}
if (cur_sub_graph_type == mindspore::lite::Scheduler::GetKernelSubGraphType(post_kernel)) {
auto post_kernel_inputs = post_kernel->in_kernels();
if (std::all_of(post_kernel_inputs.begin(), post_kernel_inputs.end(),
[&](kernel::LiteKernel *kernel) { return (*sinked_kernel_map)[kernel]; })) {
kernel_queue.emplace(post_kernel);
for (kernel::LiteKernel *head_kernel : head_kernels) {
MS_ASSERT(head_kernel != nullptr);
MS_ASSERT(sinked_kernel_map != nullptr);
if (head_kernel->Type() == schema::PrimitiveType_Switch || head_kernel->Type() == schema::PrimitiveType_Merge) {
(*sinked_kernel_map)[head_kernel] = true;
sub_kernels.emplace_back(head_kernel);
return sub_kernels;
}
std::queue<kernel::LiteKernel *> kernel_queue;
kernel_queue.emplace(head_kernel);
auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel);
while (!kernel_queue.empty()) {
auto cur_kernel = kernel_queue.front();
kernel_queue.pop();
(*sinked_kernel_map)[cur_kernel] = true;
sub_kernels.emplace_back(cur_kernel);
auto post_kernels = cur_kernel->out_kernels();
for (auto post_kernel : post_kernels) {
if (post_kernel->subgraph_type() != kernel::kNotSubGraph ||
post_kernel->Type() == schema::PrimitiveType_Merge || post_kernel->Type() == schema::PrimitiveType_Switch) {
continue;
}
if (cur_sub_graph_type == mindspore::lite::Scheduler::GetKernelSubGraphType(post_kernel)) {
auto post_kernel_inputs = post_kernel->in_kernels();
if (std::all_of(post_kernel_inputs.begin(), post_kernel_inputs.end(),
[&](kernel::LiteKernel *kernel) { return (*sinked_kernel_map)[kernel]; })) {
kernel_queue.emplace(post_kernel);
}
}
}
}
@@ -659,11 +667,15 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> src_kernel,
(*is_kernel_finish)[kernel] = false;
}
while (true) {
std::vector<kernel::LiteKernel *> head_kernels;
auto head_kernel_iter = std::find_if(src_kernel.begin(), src_kernel.end(), [&](const kernel::LiteKernel *kernel) {
auto kernel_inputs = kernel->in_kernels();
if ((*is_kernel_finish)[kernel]) {
return false;
}
if (std::find(head_kernels.begin(), head_kernels.end(), kernel) != head_kernels.end()) {
return false;
}
// when merge is removed, this if is removed automatically
if (kernel->Type() == schema::PrimitiveType_Merge) {
return MergeOpIsReady(kernel, (*is_kernel_finish));
@@ -675,25 +687,33 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> src_kernel,
if (head_kernel_iter == src_kernel.end()) {
break;
}
auto head_kernel = *head_kernel_iter;
if (head_kernel->subgraph_type() != kernel::kNotSubGraph) {
(*is_kernel_finish)[head_kernel] = true;
dst_kernel->push_back(head_kernel);
/* npu support split */
/* ConstructSubGraphs(head_kernel->nodes(), dst_kernel, is_kernel_finish); */
continue;
}
if (head_kernel->desc().arch == mindspore::kernel::kAPU) {
MS_LOG(ERROR) << "Not support APU now";
return RET_NOT_SUPPORT;
}
auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel);
auto sub_kernels = FindAllSubGraphKernels(head_kernel, is_kernel_finish);
head_kernels.push_back(head_kernel);
auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernels[0]);
auto sub_kernels = FindAllSubGraphKernels(head_kernels, is_kernel_finish);
auto subgraph = CreateSubGraphKernel(sub_kernels, nullptr, nullptr, cur_sub_graph_type);
if (subgraph == nullptr) {
MS_LOG(ERROR) << "Create SubGraphKernel failed";
return RET_ERROR;
}
dst_kernel->emplace_back(subgraph);
}
} /* end when all kernel converted */
for (auto *subgraph : *dst_kernel) {
auto ret = subgraph->Init();
if (ret != RET_OK) {
@@ -702,7 +722,7 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> src_kernel,
}
}
return RET_OK;
}
} // namespace mindspore::lite
bool Scheduler::MergeOpIsReady(const kernel::LiteKernel *kernel,
std::map<const kernel::LiteKernel *, bool> is_kernel_finish) {
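
FindAllSubGraphKernels now seeds its breadth-first walk from a vector of head kernels instead of a single one, but the grouping rule is unchanged: a successor is absorbed only if it has the same subgraph type and all of its input kernels are already sinked. A self-contained sketch of that traversal over a toy graph; Node stands in for kernel::LiteKernel and int for the subgraph type:

    #include <map>
    #include <queue>
    #include <vector>

    struct Node {
      int type;  // stands in for GetKernelSubGraphType()
      std::vector<Node *> in, out;
    };

    std::vector<Node *> CollectSubGraph(const std::vector<Node *> &heads,
                                        std::map<Node *, bool> *sinked) {
      std::vector<Node *> group;
      for (Node *head : heads) {
        std::queue<Node *> queue;
        queue.push(head);
        const int group_type = head->type;
        while (!queue.empty()) {
          Node *cur = queue.front();
          queue.pop();
          if ((*sinked)[cur]) continue;  // may be reachable from two heads
          (*sinked)[cur] = true;
          group.push_back(cur);
          for (Node *post : cur->out) {
            if (post->type != group_type) continue;
            bool all_inputs_sinked = true;
            for (Node *pre : post->in) all_inputs_sinked &= (*sinked)[pre];
            if (all_inputs_sinked) queue.push(post);
          }
        }
      }
      return group;
    }

    int main() {
      Node a{0, {}, {}}, b{0, {}, {}}, c{1, {}, {}};
      a.out = {&b, &c};
      b.in = {&a};
      c.in = {&a};
      std::map<Node *, bool> sinked{{&a, false}, {&b, false}, {&c, false}};
      auto group = CollectSubGraph({&a}, &sinked);  // collects a and b; c differs in type
      return group.size() == 2 ? 0 : 1;
    }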

View File

@@ -92,7 +92,7 @@ class Scheduler {
bool KernelFitCurrentSubGraph(const kernel::SubGraphType subgraph_type, const kernel::LiteKernel &kernel);
std::vector<kernel::LiteKernel *> FindAllSubGraphKernels(
kernel::LiteKernel *head_kernel, std::map<const kernel::LiteKernel *, bool> *sinked_kernel_map);
std::vector<kernel::LiteKernel *> head_kernels, std::map<const kernel::LiteKernel *, bool> *sinked_kernel_map);
// other methods
static TypeId GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_tensors);

View File

@@ -18,11 +18,10 @@
#include <vector>
#include <utility>
#include "src/tensor.h"
#include "schema/inner/ops_generated.h"
#include "schema/inner/model_generated.h"
#include "schema/ops_generated.h"
#include "schema/model_generated.h"
namespace mindspore::lite {
#ifdef SUBGRAPH_SPLIT
const schema::Primitive *SearchSubGraph::CreatePartialPrimitive(int64_t subgraph_index) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto val_offset = schema::CreatePartialFusion(fbb, subgraph_index);
@@ -46,7 +45,7 @@ void SearchSubGraph::ConvertSubGraphToModel() {
if (subgraph.nodes_.empty()) {
continue;
}
mindspore::kernel::KERNEL_ARCH device = subgraph.device_;
// DeviceType device = subgraph.device_;
int new_sub_index = model_->sub_graphs_.size();
int partial_index = model_->all_nodes_.size();
@@ -72,7 +71,7 @@ void SearchSubGraph::ConvertSubGraphToModel() {
new_sub_graph->node_indices_.push_back(node_index);
VectorErase(&main_graphs->node_indices_, node_index);
VectorErase(&subgraph.nodes_, node_index);
model_->all_nodes_[node_index]->device_type_ = device;
// model_->all_nodes_[node_index]->device_type_ = device;
}
for (uint32_t head_index : subgraph.heads_) {
@@ -134,7 +133,7 @@ void SearchSubGraph::InsertNode(uint32_t index, Subgraph *subgraph) {
/* remove const node */
for (int i = input.size() - 1; i >= 0; i--) {
if (tensors_[input[i]].type_ == CONST) {
input.erase(input.begin() + i);
VectorErase(&input, input[i]);
}
}
@@ -223,36 +222,62 @@ void SearchSubGraph::InitSearchTensor() {
}
void SearchSubGraph::InitSubgraphDevice() {
sub_graphs_[0].device_ = kernel::KERNEL_ARCH::kCPU;
sub_graphs_[1].device_ = kernel::KERNEL_ARCH::kALL;
}
void SearchSubGraph::InitMainGraphDevice() {
kernel::KERNEL_ARCH main_device = kernel::KERNEL_ARCH::kALL;
Model::SubGraph *main_graph = model_->sub_graphs_.front();
for (uint32_t node_index : main_graph->node_indices_) {
Model::Node *node = model_->all_nodes_[node_index];
node->device_type_ = main_device;
for (size_t i = 0; i < sub_graphs_.size(); i++) {
sub_graphs_[i].device_ = (i % 2 == 0) ? DT_CPU : DT_GPU;
}
}
void SearchSubGraph::InitMainGraphDevice() {
// DeviceType main_device = DT_GPU;
// Model::SubGraph *main_graph = model_->sub_graphs_.front();
// for (uint32_t node_index : main_graph->node_indices_) {
// Model::Node *node = model_->all_nodes_[node_index];
// node->device_type_ = main_device;
}
void SearchSubGraph::SubgraphFusion() {
Subgraph new_npu_sub;
Subgraph &npu_sub1 = sub_graphs_[1];
Subgraph &npu_sub2 = sub_graphs_[2];
new_npu_sub.nodes_.insert(new_npu_sub.nodes_.end(), npu_sub1.nodes_.begin(), npu_sub1.nodes_.end());
new_npu_sub.nodes_.insert(new_npu_sub.nodes_.end(), npu_sub2.nodes_.begin(), npu_sub2.nodes_.end());
new_npu_sub.heads_.insert(new_npu_sub.heads_.end(), npu_sub1.heads_.begin(), npu_sub1.heads_.end());
new_npu_sub.heads_.insert(new_npu_sub.heads_.end(), npu_sub2.heads_.begin(), npu_sub2.heads_.end());
new_npu_sub.ends_.insert(new_npu_sub.ends_.end(), npu_sub1.ends_.begin(), npu_sub1.ends_.end());
new_npu_sub.ends_.insert(new_npu_sub.ends_.end(), npu_sub2.ends_.begin(), npu_sub2.ends_.end());
sub_graphs_.erase(sub_graphs_.begin() + 2);
sub_graphs_.erase(sub_graphs_.begin() + 1);
sub_graphs_.insert(sub_graphs_.end(), std::move(new_npu_sub));
while (sub_graphs_.size() > 2) {
size_t sub1_index = 0;
int sub2_index = -1;
for (; sub1_index < sub_graphs_.size(); sub1_index++) {
for (size_t tmp2 = sub1_index + 1; tmp2 < sub_graphs_.size(); tmp2++) {
if (sub_graphs_[sub1_index].device_ == sub_graphs_[tmp2].device_) {
sub2_index = tmp2;
break;
}
}
if (sub2_index != -1) {
break;
}
}
MS_ASSERT(sub2_index > sub1_index);
Subgraph new_npu_sub;
Subgraph &npu_sub1 = sub_graphs_[sub1_index];
Subgraph &npu_sub2 = sub_graphs_[sub2_index];
new_npu_sub.nodes_.insert(new_npu_sub.nodes_.end(), npu_sub1.nodes_.begin(), npu_sub1.nodes_.end());
new_npu_sub.nodes_.insert(new_npu_sub.nodes_.end(), npu_sub2.nodes_.begin(), npu_sub2.nodes_.end());
new_npu_sub.heads_.insert(new_npu_sub.heads_.end(), npu_sub1.heads_.begin(), npu_sub1.heads_.end());
new_npu_sub.heads_.insert(new_npu_sub.heads_.end(), npu_sub2.heads_.begin(), npu_sub2.heads_.end());
new_npu_sub.ends_.insert(new_npu_sub.ends_.end(), npu_sub1.ends_.begin(), npu_sub1.ends_.end());
new_npu_sub.ends_.insert(new_npu_sub.ends_.end(), npu_sub2.ends_.begin(), npu_sub2.ends_.end());
sub_graphs_.erase(sub_graphs_.begin() + sub2_index);
sub_graphs_.erase(sub_graphs_.begin() + sub1_index);
sub_graphs_.insert(sub_graphs_.end(), std::move(new_npu_sub));
}
return;
}
void SearchSubGraph::SubGraphSplitByOutput() {
if (!context_->IsGpuEnabled() || output_nodes_.size() > 4) {
return;
}
if (context_->IsCpuFloat16Enabled() || context_->IsGpuFloat16Enabled()) {
return;
}
InitSearchTensor();
InitSearchSubGraph();
@@ -265,5 +290,4 @@ void SearchSubGraph::SubGraphSplitByOutput() {
InitMainGraphDevice();
}
#endif
} // namespace mindspore::lite
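
SubgraphFusion previously merged one hard-coded pair (sub_graphs_[1] and sub_graphs_[2]); it now repeatedly finds any two subgraphs assigned to the same device and merges them until at most two remain. With only DT_CPU and DT_GPU in play, the pigeonhole principle guarantees a same-device pair exists whenever more than two subgraphs are left, so the loop terminates. A toy version of the same loop:

    #include <cstddef>
    #include <utility>
    #include <vector>

    struct Sub {  // toy subgraph: a device tag plus node ids
      int device;
      std::vector<int> nodes;
    };

    void Fuse(std::vector<Sub> *subs) {
      while (subs->size() > 2) {
        size_t first = 0;
        int second = -1;
        for (; first < subs->size(); first++) {  // find a same-device pair
          for (size_t t = first + 1; t < subs->size(); t++) {
            if ((*subs)[first].device == (*subs)[t].device) {
              second = static_cast<int>(t);
              break;
            }
          }
          if (second != -1) break;
        }
        Sub merged = std::move((*subs)[first]);
        merged.nodes.insert(merged.nodes.end(), (*subs)[second].nodes.begin(),
                            (*subs)[second].nodes.end());
        subs->erase(subs->begin() + second);  // erase the higher index first
        subs->erase(subs->begin() + first);
        subs->push_back(std::move(merged));
      }
    }

    int main() {
      std::vector<Sub> subs{{0, {1}}, {1, {2}}, {0, {3}}, {1, {4}}};
      Fuse(&subs);  // ends with one subgraph per device
      return subs.size() == 2 ? 0 : 1;
    }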

View File

@@ -22,9 +22,9 @@
#include "include/model.h"
#include "src/lite_kernel.h"
#include "src/lite_model.h"
#include "src/inner_context.h"
namespace mindspore::lite {
#ifdef SUBGRAPH_SPLIT
class SearchSubGraph {
enum TensorType { NORMAL, CONST, INPUT };
@@ -39,11 +39,11 @@ class SearchSubGraph {
std::vector<uint32_t> heads_;
std::vector<uint32_t> ends_;
bool search_terminate_ = false;
mindspore::kernel::KERNEL_ARCH device_;
DeviceType device_;
};
public:
SearchSubGraph(Model *model, std::vector<size_t> output_nodes) {
SearchSubGraph(const InnerContext *context, Model *model, std::vector<size_t> output_nodes) : context_(context) {
output_nodes_.insert(output_nodes_.end(), output_nodes.begin(), output_nodes.end());
node_list_ = model->all_nodes_;
model_ = reinterpret_cast<LiteModel *>(model);
@@ -65,14 +65,13 @@ class SearchSubGraph {
void InitMainGraphDevice();
private:
const InnerContext *context_ = nullptr;
LiteModel *model_ = nullptr;
std::vector<Tensor> tensors_;
std::vector<Subgraph> sub_graphs_;
std::vector<size_t> output_nodes_;
std::vector<Model::Node *> node_list_;
};
#endif
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_SUB_GRAPH_SPLIT_H_