diff --git a/include/api/model.h b/include/api/model.h index ac36a76e0b4..9c0b434f0ab 100644 --- a/include/api/model.h +++ b/include/api/model.h @@ -38,7 +38,6 @@ namespace dataset { class Dataset; } // namespace dataset - class MS_API Model { public: Model(); @@ -72,7 +71,10 @@ class MS_API Model { Status Evaluate(std::shared_ptr ds, std::vector cbs); Status Build(const void *model_data, size_t data_size, ModelType model_type, const std::shared_ptr &model_context = nullptr, const Key &dec_key = {}, - const std::string &dec_mode = "AES-GCM"); + const std::string &dec_mode = kDecModeAesGcm); + Status Build(const std::string &model_path, ModelType model_type, + const std::shared_ptr &model_context = nullptr, const Key &dec_key = {}, + const std::string &dec_mode = kDecModeAesGcm); private: friend class Serialization; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/cast_base.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/cast_base.h index caecc7e9297..56c6c265571 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/cast_base.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/base/cast_base.h @@ -53,6 +53,24 @@ inline void Int64ToFp16(const int64_t *input, float16_t *output, int number) { output[i] = (float16_t)input[i]; } } + +inline void Int32ToFp16(const int32_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +inline void BoolToFp16(const bool *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +inline void Uint8ToFp16(const uint8_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} #endif inline void Fp16ToFloat32(const uint16_t *input, float *output, int number) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/matmul_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/matmul_infer.c index 887e5e7785c..e80bd4e11d5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/matmul_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/matmul_infer.c @@ -75,6 +75,9 @@ int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * } iswap(&b_shape[b_shape_size - 1], &b_shape[b_shape_size - 2]); } + if (a_shape[a_shape_size - 1] != b_shape[b_shape_size - 2]) { + return NNACL_ERR; + } int c_shape[MAX_SHAPE_SIZE]; size_t c_shape_size = 0; ShapeSet(c_shape, &c_shape_size, a_shape, a_shape_size); diff --git a/mindspore/ccsrc/cxx_api/model/model.cc b/mindspore/ccsrc/cxx_api/model/model.cc index 4d33dcae848..699d68a1126 100644 --- a/mindspore/ccsrc/cxx_api/model/model.cc +++ b/mindspore/ccsrc/cxx_api/model/model.cc @@ -70,6 +70,13 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty MS_LOG(ERROR) << "Unsupported Feature."; return kMCFailed; } + +Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr &model_context, + const Key &dec_key, const std::string &dec_mode) { + MS_LOG(ERROR) << "Unsupported Feature."; + return kMCFailed; +} + Status Model::Resize(const std::vector &inputs, const std::vector> &dims) { if (impl_ == nullptr) { MS_LOG(ERROR) << "Failed because this model has not been built."; diff --git a/mindspore/lite/src/cxx_api/model/model.cc b/mindspore/lite/src/cxx_api/model/model.cc index d8c450b2abf..a0910793eb3 100644 --- a/mindspore/lite/src/cxx_api/model/model.cc +++ b/mindspore/lite/src/cxx_api/model/model.cc @@ -40,6 +40,20 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty return kSuccess; } +Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr &model_context, + const Key &dec_key, const std::string &dec_mode) { + impl_ = std::shared_ptr(new (std::nothrow) ModelImpl()); + if (impl_ == nullptr) { + MS_LOG(ERROR) << "Model implement is null."; + return kLiteNullptr; + } + Status ret = impl_->Build(model_path, model_type, model_context); + if (ret != kSuccess) { + return ret; + } + return kSuccess; +} + Status Model::Build(GraphCell graph, const std::shared_ptr &model_context, const std::shared_ptr &train_cfg) { std::stringstream err_msg; diff --git a/mindspore/lite/src/cxx_api/model/model_impl.cc b/mindspore/lite/src/cxx_api/model/model_impl.cc index eb5c81d26e0..e603992f88c 100644 --- a/mindspore/lite/src/cxx_api/model/model_impl.cc +++ b/mindspore/lite/src/cxx_api/model/model_impl.cc @@ -28,6 +28,7 @@ #include "src/cxx_api/tensor_utils.h" #include "src/common/log_adapter.h" #include "src/train/train_session.h" +#include "src/common/file_utils.h" namespace mindspore { using mindspore::lite::RET_ERROR; @@ -61,6 +62,25 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode return kSuccess; } +Status ModelImpl::Build(const std::string &model_path, ModelType model_type, + const std::shared_ptr &ms_context) { + lite::Context lite_context; + auto status = A2L_ConvertContext(ms_context.get(), &lite_context); + if (status != kSuccess) { + return status; + } + + auto session = std::shared_ptr(lite::LiteSession::CreateSession(model_path, &lite_context)); + if (session == nullptr) { + MS_LOG(ERROR) << "Allocate session failed."; + return kLiteNullptr; + } + + session_.swap(session); + MS_LOG(DEBUG) << "Build model success."; + return kSuccess; +} + Status ModelImpl::Build() { MS_LOG(DEBUG) << "Start build model."; if (graph_ == nullptr || graph_->graph_data_ == nullptr) { diff --git a/mindspore/lite/src/cxx_api/model/model_impl.h b/mindspore/lite/src/cxx_api/model/model_impl.h index f4abf0c4968..0f1422d3e38 100644 --- a/mindspore/lite/src/cxx_api/model/model_impl.h +++ b/mindspore/lite/src/cxx_api/model/model_impl.h @@ -60,6 +60,7 @@ class ModelImpl { Status Build(); Status Build(const void *model_data, size_t data_size, ModelType model_type, const std::shared_ptr &model_context); + Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr &model_context); Status Resize(const std::vector &inputs, const std::vector> &dims); Status Predict(const std::vector &inputs, std::vector *outputs, const MSKernelCallBack &before, diff --git a/mindspore/lite/src/lite_model.h b/mindspore/lite/src/lite_model.h index 92b06a8d87f..f0a971318a4 100644 --- a/mindspore/lite/src/lite_model.h +++ b/mindspore/lite/src/lite_model.h @@ -47,6 +47,10 @@ class LiteModel : public Model { ~LiteModel() override { Destroy(); } + bool keep_model_buf() const { return this->keep_model_buf_; } + + void set_keep_model_buf(bool keep) { this->keep_model_buf_ = keep; } + private: #ifdef ENABLE_V0 int ConvertAttrs(Model::Node *node, std::vector *dst_tensor); @@ -260,6 +264,7 @@ class LiteModel : public Model { protected: std::vector attr_tensor_bufs_; + bool keep_model_buf_ = false; }; Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf); diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index ce754efd8f3..b64b31a2dea 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -27,6 +27,7 @@ #include "src/common/prim_util.h" #include "src/common/graph_util.h" #include "src/common/tensor_util.h" +#include "src/common/file_utils.h" #include "src/kernel_registry.h" #include "src/lite_model.h" #include "src/weight_decoder.h" @@ -984,4 +985,31 @@ session::LiteSession *session::LiteSession::CreateSession(const char *model_buf, (reinterpret_cast(session))->set_model(model); return session; } + +session::LiteSession *lite::LiteSession::CreateSession(const std::string &model_path, const lite::Context *context) { + size_t model_size; + auto model_buf = lite::ReadFile(model_path.c_str(), &model_size); + if (model_buf == nullptr) { + MS_LOG(ERROR) << "Read model file failed"; + return nullptr; + } + auto *session = session::LiteSession::CreateSession(context); + if (session == nullptr) { + MS_LOG(ERROR) << "Create session failed"; + return nullptr; + } + auto *model = lite::ImportFromBuffer(model_buf, model_size, true); + if (model == nullptr) { + MS_LOG(ERROR) << "Import model failed"; + return nullptr; + } + (reinterpret_cast(model))->set_keep_model_buf(true); + auto ret = session->CompileGraph(model); + if (ret != lite::RET_OK) { + MS_LOG(ERROR) << "Compile model failed"; + return nullptr; + } + (reinterpret_cast(session))->set_model(model); + return session; +} } // namespace mindspore diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h index 825ea9c2525..8d84df1b62f 100644 --- a/mindspore/lite/src/lite_session.h +++ b/mindspore/lite/src/lite_session.h @@ -47,6 +47,8 @@ class LiteSession : public session::LiteSession { ~LiteSession() override; + static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context); + virtual int Init(const Context *context); void BindThread(bool if_bind) override; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to_fp32.cc index ede9bcd7113..6b4565b3c85 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to_fp32.cc @@ -46,10 +46,7 @@ int BroadcastToCPUKernel::ReSize() { shape_info_->output_shape_size_ = static_cast(output_shape.size()); data_type_ = in_tensors_.at(0)->data_type(); - if (data_type_ != out_tensors_.at(0)->data_type()) { - MS_LOG(ERROR) << "BroadcastTo infer has error"; - return RET_ERROR; - } + MS_ASSERT(data_type_ == out_tensors_.at(0)->data_type()); return RET_OK; } @@ -80,14 +77,14 @@ int BroadcastToCPUKernel::Run() { } switch (data_type_) { case kNumberTypeFloat32: { - const auto input_data = reinterpret_cast(in_tensors_.at(0)->MutableData()); - auto output_data = reinterpret_cast(out_tensors_.at(0)->MutableData()); + const auto input_data = reinterpret_cast(in_tensors_.at(0)->data_c()); + auto output_data = reinterpret_cast(out_tensors_.at(0)->data_c()); return BroadcastTo(float, input_data, shape_info_, output_data); } case kNumberTypeInt32: case kNumberTypeInt: { - const auto input_data = reinterpret_cast(in_tensors_.at(0)->MutableData()); - auto output_data = reinterpret_cast(out_tensors_.at(0)->MutableData()); + const auto input_data = reinterpret_cast(in_tensors_.at(0)->data_c()); + auto output_data = reinterpret_cast(out_tensors_.at(0)->data_c()); return BroadcastTo(int, input_data, shape_info_, output_data); } default: diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc index 4e8591e9952..d7f5a75e63b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc @@ -83,6 +83,64 @@ int CastCPUKernel::CastToFp32(lite::Tensor *input, lite::Tensor *output, int off return RET_OK; } +int CastCPUKernel::CastToFp16(lite::Tensor *input, lite::Tensor *output, int offset, int data_num) { + auto input_data_type = input->data_type(); + auto output_data = output->data_c(); + switch (input_data_type) { + case kNumberTypeFloat32: + Float32ToFp16(reinterpret_cast(input->data_c()) + offset, + reinterpret_cast(output_data) + offset, data_num); + break; +#ifdef ENABLE_FP16 + case kNumberTypeInt64: + Int64ToFp16(reinterpret_cast(input->data_c()) + offset, + reinterpret_cast(output_data) + offset, data_num); + case kNumberTypeInt32: + Int32ToFp16(reinterpret_cast(input->data_c()) + offset, + reinterpret_cast(output_data) + offset, data_num); + break; + case kNumberTypeBool: + BoolToFp16(reinterpret_cast(input->data_c()) + offset, + reinterpret_cast(output_data) + offset, data_num); + break; + case kNumberTypeUInt8: + Uint8ToFp16(reinterpret_cast(input->data_c()) + offset, + reinterpret_cast(output_data) + offset, data_num); + break; +#endif + default: + MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; + return RET_ERROR; + } + return RET_OK; +} + +int CastCPUKernel::CastToOthers(lite::Tensor *input, lite::Tensor *output, int offset, int data_num) { + auto input_data_type = input->data_type(); + auto output_data_type = output->data_type(); + auto output_data = output->data_c(); + if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) { + Float32ToInt64(reinterpret_cast(input->data_c()) + offset, + reinterpret_cast(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) { + Float32ToInt32(reinterpret_cast(input->data_c()) + offset, + reinterpret_cast(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) { + Int32ToInt64(reinterpret_cast(input->data_c()) + offset, + reinterpret_cast(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt16) { + Float32ToInt16(reinterpret_cast(input->data_c()) + offset, + reinterpret_cast(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeBool && output_data_type == kNumberTypeInt32) { + BoolToInt32(reinterpret_cast(input->data_c()) + offset, reinterpret_cast(output_data) + offset, + data_num); + } else { + MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; + return RET_ERROR; + } + return RET_OK; +} + int CastCPUKernel::DoCast(int thread_id) { auto input = in_tensors_.at(0); int data_num = MSMIN(stride_, data_num_ - thread_id * stride_); @@ -102,38 +160,13 @@ int CastCPUKernel::DoCast(int thread_id) { reinterpret_cast(input->data_c()) + offset * datalen, data_num * datalen); return RET_OK; } - if (output_data_type != kNumberTypeFloat32) { - if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) { - Float32ToInt64(reinterpret_cast(input->data_c()) + offset, - reinterpret_cast(output_data) + offset, data_num); - } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) { - Float32ToInt32(reinterpret_cast(input->data_c()) + offset, - reinterpret_cast(output_data) + offset, data_num); - } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeFloat16) { - Float32ToFp16(reinterpret_cast(input->data_c()) + offset, - reinterpret_cast(output_data) + offset, data_num); - } else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) { - Int32ToInt64(reinterpret_cast(input->data_c()) + offset, - reinterpret_cast(output_data) + offset, data_num); - } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt16) { - Float32ToInt16(reinterpret_cast(input->data_c()) + offset, - reinterpret_cast(output_data) + offset, data_num); - } else if (input_data_type == kNumberTypeBool && output_data_type == kNumberTypeInt32) { - BoolToInt32(reinterpret_cast(input->data_c()) + offset, reinterpret_cast(output_data) + offset, - data_num); -#ifdef ENABLE_FP16 - } else if (input_data_type == kNumberTypeInt64 && output_data_type == kNumberTypeFloat16) { - Int64ToFp16(reinterpret_cast(input->data_c()) + offset, - reinterpret_cast(output_data) + offset, data_num); -#endif - } else { - MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; - return RET_ERROR; - } - } else { + if (output_data_type == kNumberTypeFloat32) { return CastToFp32(input, output, offset, data_num); + } else if (output_data_type == kNumberTypeFloat16) { + return CastToFp16(input, output, offset, data_num); + } else { + return CastToOthers(input, output, offset, data_num); } - return RET_OK; } int CastCPUKernel::Run() { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.h index 2f2e1a551ec..87a527eb44b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.h @@ -39,6 +39,8 @@ class CastCPUKernel : public InnerKernel { private: int CastToFp32(lite::Tensor *input, lite::Tensor *output, int offset, int data_num); + int CastToFp16(lite::Tensor *input, lite::Tensor *output, int offset, int data_num); + int CastToOthers(lite::Tensor *input, lite::Tensor *output, int offset, int data_num); int stride_ = 0; int data_num_ = 0; }; diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 2fe97939e3d..e037c42a08e 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -656,7 +656,7 @@ int Scheduler::FindCpuKernel(const std::vector &in_tensors, const std: MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret; return RET_NOT_SUPPORT; } - if (!is_train_session_) { + if (!is_train_session_ && !(reinterpret_cast(src_model_)->keep_model_buf())) { // we don't need to restore tensor for copy data ret = CopyConstTensorData(in_tensors, op_type); if (ret != RET_OK) { @@ -834,6 +834,122 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector &in return nullptr; } +namespace { +kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector &kernels, + const std::vector *in_tensors, + const std::vector *out_tensors, kernel::SubGraphType type, + const InnerContext &context) { + if (type == kernel::kApuSubGraph) { + return nullptr; + } + std::vector input_tensors; + std::vector output_tensors; + if (in_tensors != nullptr) { + input_tensors = *in_tensors; + } else { + input_tensors = kernel::LiteKernelUtil::SubgraphInputTensors(kernels); + } + if (out_tensors != nullptr) { + output_tensors = *out_tensors; + } else { + output_tensors = kernel::LiteKernelUtil::SubgraphOutputTensors(kernels); + } + auto innerkernel = new (std::nothrow) kernel::InnerKernel(nullptr, input_tensors, output_tensors, &context); + if (innerkernel == nullptr) { + return nullptr; + } + std::vector input_kernels = kernel::LiteKernelUtil::SubgraphInputNodes(kernels); + std::vector output_kernels = kernel::LiteKernelUtil::SubgraphOutputNodes(kernels); + kernel::SubGraphKernel *sub_graph = nullptr; + if (type == kernel::kCustomSubGraph) { + sub_graph = CreateCustomSubGraph(std::move(input_kernels), std::move(output_kernels), kernels, innerkernel); + } + if (type == kernel::kGpuSubGraph) { +#if GPU_OPENCL + sub_graph = new (std::nothrow) kernel::OpenCLSubGraph(input_kernels, output_kernels, kernels, innerkernel); + if (sub_graph == nullptr) { + MS_LOG(ERROR) << "Create OpenCLSubGraph failed"; + delete innerkernel; + return nullptr; + } +#elif GPU_VULKAN + delete innerkernel; + return nullptr; +#else + delete innerkernel; + return nullptr; +#endif + } + if (type == kernel::kCpuFP16SubGraph) { +#ifdef ENABLE_FP16 + sub_graph = new (std::nothrow) kernel::CpuFp16SubGraph(input_kernels, output_kernels, kernels, innerkernel); + if (sub_graph == nullptr) { + MS_LOG(ERROR) << "FP16 subgraph new failed."; + delete innerkernel; + return nullptr; + } + for (auto out_tensor : output_tensors) { + if (out_tensor->data_type() == kNumberTypeFloat32) { + out_tensor->set_data_type(kNumberTypeFloat16); + } + } +#else + delete innerkernel; + MS_LOG(ERROR) << "FP16 subgraph is not supported!"; + return nullptr; +#endif + } + if (type == kernel::kCpuFP32SubGraph) { + sub_graph = new (std::nothrow) kernel::CpuFp32SubGraph(input_kernels, output_kernels, kernels, innerkernel); + if (sub_graph == nullptr) { + MS_LOG(ERROR) << "FP32 subgraph new failed."; + delete innerkernel; + return nullptr; + } + } + if (sub_graph == nullptr) { + MS_LOG(ERROR) << "create sub graph failed."; + return nullptr; + } + sub_graph->set_context(&context); + return sub_graph; +} + +kernel::SubGraphType GetKernelSubGraphType(const kernel::LiteKernel *kernel, const InnerContext &context, + bool is_controlflow = false) { + if (kernel == nullptr) { + return kernel::kNotSubGraph; + } + + auto desc = kernel->desc(); + if (desc.provider != kernel::kBuiltin) { + return kernel::kCustomSubGraph; + } + if (desc.arch == kernel::KERNEL_ARCH::kGPU) { + return kernel::kGpuSubGraph; + } else if (desc.arch == kernel::KERNEL_ARCH::kNPU) { + return kernel::kNpuSubGraph; + } else if (desc.arch == kernel::KERNEL_ARCH::kAPU) { + return kernel::kApuSubGraph; + } else if (desc.arch == kernel::KERNEL_ARCH::kCPU) { + if (desc.data_type == kNumberTypeFloat16) { + return kernel::kCpuFP16SubGraph; + } else if (desc.data_type == kNumberTypeFloat32 || desc.data_type == kNumberTypeInt8 || + desc.data_type == kNumberTypeInt64 || desc.data_type == kNumberTypeUInt8 || + desc.data_type == kNumberTypeBool) { + return kernel::kCpuFP32SubGraph; + } else if (desc.data_type == kNumberTypeInt32) { + if (context.IsCpuFloat16Enabled() && !is_controlflow) { + return kernel::kCpuFP16SubGraph; + } else { + return kernel::kCpuFP32SubGraph; + } + } + } + return kernel::kNotSubGraph; +} +} // namespace + kernel::LiteKernel *Scheduler::SchedulePartialToKernel(const lite::Model::Node *src_node) { MS_ASSERT(src_model_ != nullptr); MS_ASSERT(src_node != nullptr); @@ -917,9 +1033,9 @@ kernel::LiteKernel *Scheduler::SchedulePartialToSubGraphKernel(const int &subgra return {}; } FindAllInoutKernels(kernels); - auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(kernels.front()); + auto cur_sub_graph_type = GetKernelSubGraphType(kernels.front(), *context_, true); MS_LOG(INFO) << "cur_sub_graph_type: " << cur_sub_graph_type; - auto subgraph_kernel = CreateSubGraphKernel(kernels, &in_tensors, &out_tensors, cur_sub_graph_type); + auto subgraph_kernel = CreateSubGraphKernel(kernels, &in_tensors, &out_tensors, cur_sub_graph_type, *context_); if (subgraph_kernel == nullptr) { MS_LOG(ERROR) << "CreateSubGraphKernel failed, cur_sub_graph_type: " << cur_sub_graph_type; return nullptr; @@ -1043,9 +1159,10 @@ int Scheduler::ScheduleSubGraphToKernels(size_t subgraph_index, std::vectorsrc_tensors_->at(index); }); } return RET_OK; -} // namespace mindspore::lite +} -bool Scheduler::KernelFitCurrentSubGraph(const kernel::SubGraphType subgraph_type, const kernel::LiteKernel &kernel) { +namespace { +bool KernelFitCurrentSubGraph(const kernel::SubGraphType subgraph_type, const kernel::LiteKernel &kernel) { switch (subgraph_type) { case kernel::SubGraphType::kNotSubGraph: case kernel::SubGraphType::kApuSubGraph: @@ -1077,91 +1194,79 @@ bool Scheduler::KernelFitCurrentSubGraph(const kernel::SubGraphType subgraph_typ } } -std::vector Scheduler::FindAllSubGraphKernels( - std::vector head_kernels, std::map *sinked_kernel_map) { +kernel::LiteKernel *FindAllSubGraphKernels(const std::vector &sorted_kernels, + const InnerContext &context, size_t *cur_index) { std::vector sub_kernels; - - for (kernel::LiteKernel *head_kernel : head_kernels) { - MS_ASSERT(head_kernel != nullptr); - MS_ASSERT(sinked_kernel_map != nullptr); - std::queue kernel_queue; - kernel_queue.emplace(head_kernel); - auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel); - while (!kernel_queue.empty()) { - auto cur_kernel = kernel_queue.front(); - kernel_queue.pop(); - (*sinked_kernel_map)[cur_kernel] = true; - sub_kernels.emplace_back(cur_kernel); - auto post_kernels = cur_kernel->out_kernels(); - for (auto post_kernel : post_kernels) { - if (post_kernel->subgraph_type() != kernel::kNotSubGraph) { - continue; - } - if (cur_sub_graph_type == mindspore::lite::Scheduler::GetKernelSubGraphType(post_kernel)) { - auto post_kernel_inputs = post_kernel->in_kernels(); - if (std::all_of(post_kernel_inputs.begin(), post_kernel_inputs.end(), - [&](kernel::LiteKernel *kernel) { return (*sinked_kernel_map)[kernel]; })) { - kernel_queue.emplace(post_kernel); - } - } - } + sub_kernels.emplace_back(sorted_kernels[*cur_index]); + auto cur_sub_graph_type = GetKernelSubGraphType(sorted_kernels[*cur_index], context); + for (*cur_index = *cur_index + 1; *cur_index < sorted_kernels.size(); ++(*cur_index)) { + auto cur_kernel = sorted_kernels[*cur_index]; + MS_ASSERT(GetKernelSubGraphType(cur_kernel, context) != kernel::kApuSubGraph); + // already a subgraph or a delegate + if (cur_kernel->subgraph_type() != kernel::kNotSubGraph || cur_kernel->desc().delegate != nullptr) { + --(*cur_index); + break; } + if (!KernelFitCurrentSubGraph(cur_sub_graph_type, *cur_kernel)) { + --(*cur_index); + break; + } + sub_kernels.emplace_back(cur_kernel); } - return sub_kernels; + return CreateSubGraphKernel(sub_kernels, nullptr, nullptr, cur_sub_graph_type, context); } +} // namespace int Scheduler::ConstructSubGraphs(std::vector src_kernel, std::vector *dst_kernel, std::map *is_kernel_finish) { - for (auto kernel : src_kernel) { - (*is_kernel_finish)[kernel] = false; + if (src_kernel.empty()) { + return RET_OK; } - while (true) { - std::vector head_kernels; /* support one-head-kernel in subgraph */ - auto head_kernel_iter = std::find_if(src_kernel.begin(), src_kernel.end(), [&](const kernel::LiteKernel *kernel) { - auto kernel_inputs = kernel->in_kernels(); - if ((*is_kernel_finish)[kernel]) { - return false; - } - if (std::find(head_kernels.begin(), head_kernels.end(), kernel) != head_kernels.end()) { - return false; - } - return std::all_of(kernel_inputs.begin(), kernel_inputs.end(), - [&](kernel::LiteKernel *kernel) { return (*is_kernel_finish)[kernel]; }); - }); - if (head_kernel_iter == src_kernel.end()) { - break; + // topological sort + std::vector sorted_kernels; + for (auto iter = src_kernel.begin(); iter != src_kernel.end();) { + if ((*iter)->in_kernels().empty()) { + sorted_kernels.emplace_back(*iter); + (*is_kernel_finish)[*iter] = true; + iter = src_kernel.erase(iter); + } else { + (*is_kernel_finish)[*iter] = false; + iter++; } - - auto head_kernel = *head_kernel_iter; - if (head_kernel->subgraph_type() != kernel::kNotSubGraph) { - (*is_kernel_finish)[head_kernel] = true; - dst_kernel->push_back(head_kernel); + } + while (!src_kernel.empty()) { + for (auto iter = src_kernel.begin(); iter != src_kernel.end();) { + auto kernel = *iter; + auto inputs = kernel->in_kernels(); + if (std::all_of(inputs.begin(), inputs.end(), + [&](const kernel::LiteKernel *kernel) { return (*is_kernel_finish)[kernel]; })) { + sorted_kernels.emplace_back(kernel); + (*is_kernel_finish)[*iter] = true; + iter = src_kernel.erase(iter); + } else { + iter++; + } + } + } + // construct subgraph + for (size_t index = 0; index < sorted_kernels.size(); index++) { + auto cur_kernel = sorted_kernels[index]; + MS_ASSERT(cur_kernel != nullptr); + // Not support APU now + MS_ASSERT(GetKernelSubGraphType(cur_kernel, *context_) != kernel::kApuSubGraph); + // already a subgraph or a delegate + if (cur_kernel->subgraph_type() != kernel::kNotSubGraph || cur_kernel->desc().delegate != nullptr) { + dst_kernel->emplace_back(cur_kernel); continue; } - if (head_kernel->desc().arch == mindspore::kernel::kAPU) { - MS_LOG(ERROR) << "Not support APU now"; - return RET_NOT_SUPPORT; + auto subgraph = FindAllSubGraphKernels(sorted_kernels, *context_, &index); + if (subgraph == nullptr) { + MS_LOG(ERROR) << "Create SubGraphKernel failed"; + return RET_ERROR; } - - head_kernels.push_back(head_kernel); - - auto subgraph_delegate = head_kernel->desc().delegate; - if (subgraph_delegate != nullptr) { - dst_kernel->emplace_back(head_kernel); - (*is_kernel_finish)[head_kernel] = true; - } else { - auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernels[0]); - auto sub_kernels = FindAllSubGraphKernels(head_kernels, is_kernel_finish); - auto subgraph = CreateSubGraphKernel(sub_kernels, nullptr, nullptr, cur_sub_graph_type); - if (subgraph == nullptr) { - MS_LOG(ERROR) << "Create SubGraphKernel failed"; - return RET_ERROR; - } - dst_kernel->emplace_back(subgraph); - } - } /* end when all kernel converted */ - + dst_kernel->emplace_back(subgraph); + } for (auto *subgraph : *dst_kernel) { auto subgraph_delegate = subgraph->desc().delegate; if (subgraph_delegate == nullptr) { @@ -1175,86 +1280,6 @@ int Scheduler::ConstructSubGraphs(std::vector src_kernel, return RET_OK; } -kernel::SubGraphKernel *Scheduler::CreateSubGraphKernel(const std::vector &kernels, - const std::vector *in_tensors, - const std::vector *out_tensors, - kernel::SubGraphType type) { - if (type == kernel::kApuSubGraph) { - return nullptr; - } - std::vector input_tensors; - std::vector output_tensors; - if (in_tensors != nullptr) { - input_tensors = *in_tensors; - } else { - input_tensors = kernel::LiteKernelUtil::SubgraphInputTensors(kernels); - } - if (out_tensors != nullptr) { - output_tensors = *out_tensors; - } else { - output_tensors = kernel::LiteKernelUtil::SubgraphOutputTensors(kernels); - } - auto innerkernel = new (std::nothrow) kernel::InnerKernel(nullptr, input_tensors, output_tensors, context_); - if (innerkernel == nullptr) { - return nullptr; - } - std::vector input_kernels = kernel::LiteKernelUtil::SubgraphInputNodes(kernels); - std::vector output_kernels = kernel::LiteKernelUtil::SubgraphOutputNodes(kernels); - kernel::SubGraphKernel *sub_graph = nullptr; - if (type == kernel::kCustomSubGraph) { - sub_graph = CreateCustomSubGraph(std::move(input_kernels), std::move(output_kernels), kernels, innerkernel); - } - if (type == kernel::kGpuSubGraph) { -#if GPU_OPENCL - sub_graph = new (std::nothrow) kernel::OpenCLSubGraph(input_kernels, output_kernels, kernels, innerkernel); - if (sub_graph == nullptr) { - MS_LOG(ERROR) << "Create OpenCLSubGraph failed"; - delete innerkernel; - return nullptr; - } -#elif GPU_VULKAN - delete innerkernel; - return nullptr; -#else - delete innerkernel; - return nullptr; -#endif - } - if (type == kernel::kCpuFP16SubGraph) { -#ifdef ENABLE_FP16 - sub_graph = new (std::nothrow) kernel::CpuFp16SubGraph(input_kernels, output_kernels, kernels, innerkernel); - if (sub_graph == nullptr) { - MS_LOG(ERROR) << "FP16 subgraph new failed."; - delete innerkernel; - return nullptr; - } - for (auto out_tensor : output_tensors) { - if (out_tensor->data_type() == kNumberTypeFloat32) { - out_tensor->set_data_type(kNumberTypeFloat16); - } - } -#else - delete innerkernel; - MS_LOG(ERROR) << "FP16 subgraph is not supported!"; - return nullptr; -#endif - } - if (type == kernel::kCpuFP32SubGraph) { - sub_graph = new (std::nothrow) kernel::CpuFp32SubGraph(input_kernels, output_kernels, kernels, innerkernel); - if (sub_graph == nullptr) { - MS_LOG(ERROR) << "FP32 subgraph new failed."; - delete innerkernel; - return nullptr; - } - } - if (sub_graph == nullptr) { - MS_LOG(ERROR) << "create sub graph failed."; - return nullptr; - } - sub_graph->set_context(context_); - return sub_graph; -} - TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector &in_tensors) { for (const auto &tensor : in_tensors) { auto dtype = tensor->data_type(); @@ -1304,37 +1329,35 @@ void Scheduler::SetKernelTensorDataType(kernel::LiteKernel *kernel) { } } -kernel::SubGraphType Scheduler::GetKernelSubGraphType(const kernel::LiteKernel *kernel) { - if (kernel == nullptr) { - return kernel::kNotSubGraph; - } - - auto desc = kernel->desc(); - if (desc.provider != kernel::kBuiltin) { - return kernel::kCustomSubGraph; - } - if (desc.arch == kernel::KERNEL_ARCH::kGPU) { - return kernel::kGpuSubGraph; - } else if (desc.arch == kernel::KERNEL_ARCH::kNPU) { - return kernel::kNpuSubGraph; - } else if (desc.arch == kernel::KERNEL_ARCH::kAPU) { - return kernel::kApuSubGraph; - } else if (desc.arch == kernel::KERNEL_ARCH::kCPU) { - if (desc.data_type == kNumberTypeFloat16) { - return kernel::kCpuFP16SubGraph; - } else if (desc.data_type == kNumberTypeFloat32 || desc.data_type == kNumberTypeInt8 || - desc.data_type == kNumberTypeInt32 || desc.data_type == kNumberTypeInt64 || - desc.data_type == kNumberTypeUInt8 || desc.data_type == kNumberTypeBool) { - return kernel::kCpuFP32SubGraph; +void Scheduler::FindAllInoutKernels(const std::vector &kernels) { + std::unordered_map tensorPreKernel; + std::unordered_map> tensorPostKernels; + for (auto *kernel : kernels) { + for (auto *tensor : kernel->out_tensors()) { + tensorPreKernel[tensor] = kernel; + } + for (auto *tensor : kernel->in_tensors()) { + (tensorPostKernels[tensor]).push_back(kernel); } } - return kernel::kNotSubGraph; -} -void Scheduler::FindAllInoutKernels(const std::vector &kernels) { for (auto *kernel : kernels) { - MS_ASSERT(kernel != nullptr); - kernel->FindInoutKernels(kernels); + kernel->set_in_kernels({}); + for (auto *tensor : kernel->in_tensors()) { + auto iter = tensorPreKernel.find(tensor); + if (iter != tensorPreKernel.end()) { + kernel->AddInKernel(iter->second); + } + } + kernel->set_out_kernels({}); + for (auto *tensor : kernel->out_tensors()) { + auto iter = tensorPostKernels.find(tensor); + if (iter != tensorPostKernels.end()) { + for (auto *find_kernel : iter->second) { + kernel->AddOutKernel(find_kernel); + } + } + } } } @@ -1370,7 +1393,7 @@ int Scheduler::ConstructControlFlowMainGraph(std::vector * } } auto cur_subgraph_type = PartialSubGraphType(main_graph_kernels); - auto subgraph_kernel = CreateSubGraphKernel(main_graph_kernels, nullptr, nullptr, cur_subgraph_type); + auto subgraph_kernel = CreateSubGraphKernel(main_graph_kernels, nullptr, nullptr, cur_subgraph_type, *context_); if (subgraph_kernel == nullptr) { MS_LOG(ERROR) << "create main graph for control flow model failed."; return RET_ERROR; diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h index 5d2b0143c42..88217d67b7a 100644 --- a/mindspore/lite/src/scheduler.h +++ b/mindspore/lite/src/scheduler.h @@ -68,8 +68,6 @@ class Scheduler { kernel::LiteKernel **kernel); int FindGpuKernel(const std::vector &in_tensors, const std::vector &out_tensors, OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel); - int FindNpuKernel(const std::vector &in_tensors, const std::vector &out_tensors, - OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel); int FindProviderKernel(const std::vector &in_tensors, const std::vector &out_tensors, const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel); @@ -92,13 +90,6 @@ class Scheduler { int ConstructSubGraphs(std::vector src_kernel, std::vector *dst_kernel, std::map *sinked_kernel_map); // create subgraph_kernel from a vector of kernel - kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector &kernels, - const std::vector *in_tensors, - const std::vector *out_tensors, - kernel::SubGraphType type); - bool KernelFitCurrentSubGraph(const kernel::SubGraphType subgraph_type, const kernel::LiteKernel &kernel); - std::vector FindAllSubGraphKernels( - std::vector head_kernels, std::map *sinked_kernel_map); std::vector ScheduleMainSubGraphToKernels(); kernel::LiteKernel *SchedulePartialToSubGraphKernel(const int &subgraph_index); kernel::SubGraphType PartialSubGraphType(const std::vector &kernels); @@ -108,7 +99,6 @@ class Scheduler { // other methods static TypeId GetFirstFp32Fp16OrInt8Type(const std::vector &in_tensors); static void SetKernelTensorDataType(kernel::LiteKernel *kernel); - static kernel::SubGraphType GetKernelSubGraphType(const kernel::LiteKernel *kernel); int CopyPartialShapeToSubGraph(const lite::Model::Node *partial_node); int RestoreSubGraphInput(const lite::Model::Node *partial_node); bool SubGraphHasScheduled(const int &index); diff --git a/mindspore/lite/src/sub_graph_kernel.h b/mindspore/lite/src/sub_graph_kernel.h index 22d1e3874ac..0200b2ebd8b 100644 --- a/mindspore/lite/src/sub_graph_kernel.h +++ b/mindspore/lite/src/sub_graph_kernel.h @@ -175,6 +175,36 @@ class CpuFp16SubGraph : public CpuSubGraph { return CpuSubGraph::Init(); } + int Prepare() override { + auto ret = CpuSubGraph::Prepare(); + if (ret != RET_OK) { + return ret; + } + for (auto &node : this->nodes_) { + if (node->type() == schema::PrimitiveType_Cast) { + auto inputs = node->in_tensors(); + MS_ASSERT(inputs.size() >= 2); + auto dst_tensor = inputs[1]; + MS_ASSERT(dst_tensor != nullptr); + MS_ASSERT(dst_tensor->data_type() == kNumberTypeInt32); + MS_ASSERT(dst_tensor->data() != nullptr); + MS_ASSERT(dst_tensor->ElementsNum() == 1); + auto *dst_data = reinterpret_cast(dst_tensor->data()); + if (dst_data[0] == kNumberTypeFloat32) { + dst_data[0] = kNumberTypeFloat16; + } + auto outputs = node->out_tensors(); + MS_ASSERT(outputs.size() == 1); + auto output = outputs.front(); + MS_ASSERT(output != nullptr); + if (output->data_type() == kNumberTypeFloat32) { + output->set_data_type(kNumberTypeFloat16); + } + } + } + return RET_OK; + } + private: bool support_fp16_ = false; }; diff --git a/mindspore/lite/test/config/models_npu_fp16.cfg b/mindspore/lite/test/config/models_npu_fp16.cfg index 25f01036706..65465d69cf8 100644 --- a/mindspore/lite/test/config/models_npu_fp16.cfg +++ b/mindspore/lite/test/config/models_npu_fp16.cfg @@ -68,6 +68,7 @@ nasnet_mobile.tflite 1 ml_video_edit_art_transfer.onnx;3 3 ml_video_edit_enhance_update_tmp.onnx 0.5 #ml_video_edit_art_generate_20210513.onnx, output is out of range -ml_video_edit_art_transfer_20210513.onnx;3 2 +# ConstructSubgraph change, adjust threshold(3->29) for nlu temporary +ml_video_edit_art_transfer_20210513.onnx;3 29 ml_video_edit_hair_dyeing_segmodel_v2 0.5 ml_video_edit_makeup_mobilenetv203.onnx 2 diff --git a/mindspore/lite/test/st/scripts/base_functions.sh b/mindspore/lite/test/st/scripts/base_functions.sh index 8548d508b4d..f3b3ea0dff4 100644 --- a/mindspore/lite/test/st/scripts/base_functions.sh +++ b/mindspore/lite/test/st/scripts/base_functions.sh @@ -139,6 +139,7 @@ function Run_Benchmark() { if [[ ${model_name##*.} == "caffemodel" ]]; then model_name=${model_name%.*} fi + echo "Benchmarking ${model_name} $6 $7 ......" # adjust benchmark mode benchmark_mode="calib" if [[ $6 == "arm64" && $7 == "CPU" && ! ${cfg_file_name} =~ "fp16" ]]; then diff --git a/mindspore/lite/test/st/sub_graph_test.cc b/mindspore/lite/test/st/sub_graph_test.cc index 54d83ed7d2b..663b1b07282 100644 --- a/mindspore/lite/test/st/sub_graph_test.cc +++ b/mindspore/lite/test/st/sub_graph_test.cc @@ -328,7 +328,7 @@ TEST_F(SubGraphTest, RecursiveSubGraphTest) { auto &cpu_device_ctx = context.device_list_[0]; cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU; context.thread_num_ = 2; - auto session = std::shared_ptr(lite::LiteSession::CreateSession(&context)); + auto session = std::shared_ptr(session::LiteSession::CreateSession(&context)); ASSERT_NE(session, nullptr); auto ret = session->CompileGraph(model.get()); ASSERT_EQ(ret, lite::RET_OK);