diff --git a/include/api/context.h b/include/api/context.h index e6c5697b173..1c4e93d2192 100644 --- a/include/api/context.h +++ b/include/api/context.h @@ -292,9 +292,6 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext { /// \return The device id. uint32_t GetDeviceID() const; - inline void SetDumpConfigPath(const std::string &cfg_path); - inline std::string GetDumpConfigPath() const; - /// \brief Set AIPP configuration file path. /// /// \param[in] cfg_path AIPP configuration file path. @@ -379,9 +376,6 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext { inline std::string GetBufferOptimizeMode() const; private: - void SetDumpConfigPath(const std::vector &cfg_path); - std::vector GetDumpConfigPathChar() const; - void SetInsertOpConfigPath(const std::vector &cfg_path); std::vector GetInsertOpConfigPathChar() const; @@ -406,9 +400,6 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext { std::vector GetBufferOptimizeModeChar() const; }; -void Ascend310DeviceInfo::SetDumpConfigPath(const std::string &cfg_path) { SetDumpConfigPath(StringToChar(cfg_path)); } -std::string Ascend310DeviceInfo::GetDumpConfigPath() const { return CharToString(GetDumpConfigPathChar()); } - void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) { SetInsertOpConfigPath(StringToChar(cfg_path)); } diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 52b884de930..7af8aec78ff 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1121,7 +1121,8 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph return new_parameter; } -KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { +KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs, + bool common_opt) { std::unordered_map other_graph_cnode; auto graph = NewKernelGraph(); MS_EXCEPTION_IF_NULL(graph); @@ -1161,7 +1162,9 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con // Update Graph Dynamic Shape Attr UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); UpdateGraphAquireGilAttr(NOT_NULL(graph)); - opt::BackendCommonOptimization(graph); + if (common_opt) { + opt::BackendCommonOptimization(graph); + } graph->SetInputNodes(); SetInputNodeUsage(graph, manager); graph->SetOptimizerFlag(); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index af9783e149f..25a7010227f 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -116,7 +116,8 @@ class SessionBasic : public std::enable_shared_from_this { bool CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph); - std::shared_ptr ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); + std::shared_ptr ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs, + bool common_opt = true); std::shared_ptr ConstructKernelGraph(const FuncGraphPtr &func_graph, std::vector *all_out_graph); diff --git a/mindspore/ccsrc/cxx_api/CMakeLists.txt b/mindspore/ccsrc/cxx_api/CMakeLists.txt index da27374fbff..4008e840e83 100644 --- a/mindspore/ccsrc/cxx_api/CMakeLists.txt +++ b/mindspore/ccsrc/cxx_api/CMakeLists.txt @@ -32,6 +32,7 @@ if(ENABLE_D OR ENABLE_ACL) if(NOT ENABLE_D) list(APPEND API_ACL_SRC $) + list(APPEND API_ACL_SRC $) endif() endif() diff --git a/mindspore/ccsrc/cxx_api/context.cc b/mindspore/ccsrc/cxx_api/context.cc index 5f3c429faf6..c74570dddcb 100644 --- a/mindspore/ccsrc/cxx_api/context.cc +++ b/mindspore/ccsrc/cxx_api/context.cc @@ -29,7 +29,6 @@ constexpr auto kModelOptionGPUTrtInferMode = "mindspore.option.gpu.trt_infer_mod constexpr auto kModelOptionGPUPrecisionMode = "mindspore.option.gpu.precision_mode"; constexpr auto kModelOptionAscend910DeviceID = kModelOptionDeviceID; constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID; -constexpr auto kModelOptionAscend310DumpCfgPath = "mindspore.option.ascend310.dump_config_file_path"; constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path"; constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format"; constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map"; @@ -193,16 +192,6 @@ uint32_t Ascend310DeviceInfo::GetDeviceID() const { return GetValue(data_, kModelOptionAscend310DeviceID); } -void Ascend310DeviceInfo::SetDumpConfigPath(const std::vector &cfg_path) { - MS_EXCEPTION_IF_NULL(data_); - data_->params[kModelOptionAscend310DumpCfgPath] = CharToString(cfg_path); -} -std::vector Ascend310DeviceInfo::GetDumpConfigPathChar() const { - MS_EXCEPTION_IF_NULL(data_); - const std::string &ref = GetValue(data_, kModelOptionAscend310DumpCfgPath); - return StringToChar(ref); -} - void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector &cfg_path) { MS_EXCEPTION_IF_NULL(data_); data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path); diff --git a/mindspore/ccsrc/cxx_api/graph/acl/acl_env_guard.cc b/mindspore/ccsrc/cxx_api/graph/acl/acl_env_guard.cc index e1ca983cdb3..bb8c7e29517 100644 --- a/mindspore/ccsrc/cxx_api/graph/acl/acl_env_guard.cc +++ b/mindspore/ccsrc/cxx_api/graph/acl/acl_env_guard.cc @@ -21,8 +21,8 @@ namespace mindspore { std::shared_ptr AclEnvGuard::global_acl_env_; std::mutex AclEnvGuard::global_acl_env_mutex_; -AclEnvGuard::AclEnvGuard(std::string_view cfg_file) { - errno_ = aclInit(cfg_file.data()); +AclEnvGuard::AclEnvGuard() { + errno_ = aclInit(nullptr); if (errno_ != ACL_ERROR_NONE && errno_ != ACL_ERROR_REPEAT_INITIALIZE) { MS_LOG(ERROR) << "Execute aclInit Failed"; return; @@ -38,18 +38,15 @@ AclEnvGuard::~AclEnvGuard() { MS_LOG(INFO) << "Acl finalize success"; } -std::shared_ptr AclEnvGuard::GetAclEnv(std::string_view cfg_file) { +std::shared_ptr AclEnvGuard::GetAclEnv() { std::shared_ptr acl_env; std::lock_guard lock(global_acl_env_mutex_); acl_env = global_acl_env_; if (acl_env != nullptr) { MS_LOG(INFO) << "Acl has been initialized, skip."; - if (!cfg_file.empty()) { - MS_LOG(WARNING) << "Dump config file option " << cfg_file << " is ignored."; - } } else { - acl_env = std::make_shared(cfg_file); + acl_env = std::make_shared(); aclError ret = acl_env->GetErrno(); if (ret != ACL_ERROR_NONE && ret != ACL_ERROR_REPEAT_INITIALIZE) { MS_LOG(ERROR) << "Execute aclInit Failed"; diff --git a/mindspore/ccsrc/cxx_api/graph/acl/acl_env_guard.h b/mindspore/ccsrc/cxx_api/graph/acl/acl_env_guard.h index 8b4ae76c68a..d5d33e6bd44 100644 --- a/mindspore/ccsrc/cxx_api/graph/acl/acl_env_guard.h +++ b/mindspore/ccsrc/cxx_api/graph/acl/acl_env_guard.h @@ -23,10 +23,10 @@ namespace mindspore { class __attribute__((visibility("default"))) AclEnvGuard { public: - explicit AclEnvGuard(std::string_view cfg_file); + explicit AclEnvGuard(); ~AclEnvGuard(); aclError GetErrno() const { return errno_; } - static std::shared_ptr GetAclEnv(std::string_view cfg_file); + static std::shared_ptr GetAclEnv(); private: static std::shared_ptr global_acl_env_; diff --git a/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc index d41370a996e..063961f8179 100644 --- a/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc +++ b/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc @@ -91,7 +91,7 @@ Status AclGraphImpl::InitEnv() { return kSuccess; } - acl_env_ = AclEnvGuard::GetAclEnv(""); + acl_env_ = AclEnvGuard::GetAclEnv(); if (acl_env_ == nullptr) { MS_LOG(ERROR) << "Acl init failed."; return kMCDeviceError; diff --git a/mindspore/ccsrc/cxx_api/graph/acl/model_process.cc b/mindspore/ccsrc/cxx_api/graph/acl/model_process.cc index 3f246dacf26..09cac8212d7 100644 --- a/mindspore/ccsrc/cxx_api/graph/acl/model_process.cc +++ b/mindspore/ccsrc/cxx_api/graph/acl/model_process.cc @@ -165,7 +165,7 @@ Status ModelProcess::InitInputsBuffer() { aclDataType data_type = aclmdlGetInputDataType(model_desc_, i); std::vector shape(dims.dims, dims.dims + dims.dimCount); const char *input_name_char = aclmdlGetInputNameByIndex(model_desc_, i); - std::string input_name = (input_name_char == nullptr) ? input_name_char : std::string(); + std::string input_name = (input_name_char != nullptr) ? input_name_char : std::string(); if (input_name.empty()) { MS_LOG(WARNING) << "Get name of input " << i << " failed."; } @@ -249,7 +249,7 @@ Status ModelProcess::InitOutputsBuffer() { aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i); std::vector shape(dims.dims, dims.dims + dims.dimCount); const char *output_name_char = aclmdlGetOutputNameByIndex(model_desc_, i); - std::string output_name = (output_name_char == nullptr) ? output_name_char : std::string(); + std::string output_name = (output_name_char != nullptr) ? output_name_char : std::string(); if (output_name.empty()) { MS_LOG(WARNING) << "Get name of output " << i << " failed."; } diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc index c6d3abf924c..ab64b4478b9 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc @@ -24,8 +24,6 @@ #include "acl/acl_base.h" namespace mindspore { -API_FACTORY_REG(ModelImpl, Ascend310, AclModel); - Status AclModel::Build() { MS_LOG(INFO) << "Start build model."; MS_EXCEPTION_IF_NULL(graph_); @@ -42,13 +40,8 @@ Status AclModel::Build() { return kSuccess; } - std::unique_ptr options = std::make_unique(model_context_); + std::shared_ptr options = std::make_shared(model_context_); MS_EXCEPTION_IF_NULL(options); - std::string dump_cfg = options->GetDumpCfgPath(); - if (!dump_cfg.empty()) { - MS_LOG(INFO) << "Options dump config file path " << dump_cfg; - (void)AclEnvGuard::GetAclEnv(dump_cfg); - } std::string options_key = options->GenAclOptionsKey(); std::shared_ptr graph; if (auto iter = dynamic_size_graph_map_.find(options_key); iter != dynamic_size_graph_map_.end()) { @@ -72,7 +65,7 @@ Status AclModel::Build() { } options->RenameInput(input_names); MS_EXCEPTION_IF_NULL(func_graph); - model_converter_.set_options(options.get()); + model_converter_.set_options(options); auto om_data = model_converter_.LoadMindIR(func_graph); if (om_data.Data() == nullptr || om_data.DataSize() == 0) { MS_LOG(ERROR) << "Load MindIR failed."; diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model.h b/mindspore/ccsrc/cxx_api/model/acl/acl_model.h index 5520299216f..79b93a72edb 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model.h +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model.h @@ -47,7 +47,7 @@ class AclModel : public ModelImpl { private: ModelConverter model_converter_; - std::unique_ptr options_; + std::shared_ptr options_; std::map> dynamic_size_graph_map_; }; } // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc new file mode 100644 index 00000000000..46094a23522 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc @@ -0,0 +1,475 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cxx_api/model/acl/acl_model_multi.h" +#include +#include +#include +#include +#include +#include "backend/session/session_basic.h" +#include "backend/session/session_factory.h" +#include "cxx_api/factory.h" +#include "vm/backend.h" +#include "vm/transform.h" +#include "acl/acl_rt.h" +#include "mindspore/core/load_mindir/infer_mindir.h" +#include "debug/trace.h" + +namespace mindspore { +API_FACTORY_REG(ModelImpl, Ascend310, AclModelMulti); + +namespace { +class MSTensorRef : public BaseRef { + public: + static VectorRef Convert(const std::vector &tensors) { + VectorRef res; + std::transform(tensors.begin(), tensors.end(), std::back_inserter(res), + [](const MSTensor &t) { return MSTensorRef(t); }); + return res; + } + + static std::vector Convert(const BaseRef &args) { + std::vector res; + if (utils::isa(args)) { + VectorRef args_vec = utils::cast(args); + for (size_t i = 0; i < args_vec.size(); ++i) { + const auto &item = args_vec[i]; + if (!utils::isa(item)) { + MS_LOG(EXCEPTION) << "Invalid item " << item.ToString() << " at index " << i; + } + auto wrapper = utils::cast(item); + res.push_back(wrapper.ms_tensor_); + } + } else if (utils::isa(args)) { + auto wrapper = utils::cast(args); + res.push_back(wrapper.ms_tensor_); + } else { + MS_LOG(EXCEPTION) << "Invalid BaseRef " << args.ToString() << " must be MSTensorRef or VectorRef{MSTensorRef...}"; + } + + return res; + } + + MS_DECLARE_PARENT(MSTensorRef, BaseRef); + explicit MSTensorRef(const MSTensor &tensor) : ms_tensor_(tensor) {} + ~MSTensorRef() override = default; + + const MSTensor &GetTensor() const { return ms_tensor_; } + std::shared_ptr copy() const override { + MSTensor *tensor = ms_tensor_.Clone(); + auto res = std::make_shared(static_cast(*tensor)); + MSTensor::DestroyTensorPtr(tensor); + return res; + } + + uint32_t type() const override { return tid(); } + std::string ToString() const override { return ms_tensor_.Name(); } + bool operator==(const BaseRef &other) const override { + if (!utils::isa(other)) { + return false; + } + return *this == utils::cast(other); + } + + bool operator==(MSTensorRef &other) { + return (ms_tensor_.Name() == other.ms_tensor_.Name()) && (ms_tensor_.Shape() == other.ms_tensor_.Shape()) && + (ms_tensor_.MutableData() == other.ms_tensor_.MutableData()) && + (ms_tensor_.DataSize() == other.ms_tensor_.DataSize()) && + (ms_tensor_.DataType() == other.ms_tensor_.DataType()); + } + + private: + MSTensor ms_tensor_; +}; + +class MultiGraphAclSession : public session::SessionBasic { + public: + MultiGraphAclSession() = default; + ~MultiGraphAclSession() override = default; + void Init(uint32_t device_id) override; + GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; + void RunGraph(GraphId graph_id, const std::vector &inputs, VectorRef *outputs); + void SetOptions(const std::shared_ptr &options) { options_ = options; } + + private: + std::map graphs_ = {}; + std::shared_ptr options_ = nullptr; +}; + +void MultiGraphAclSession::Init(uint32_t device_id) { InitExecutor(kDavinciMultiGraphInferenceDevice, device_id); } + +GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { + class FirstGraphModeGuard { + public: + explicit FirstGraphModeGuard(const std::shared_ptr &options) : options_(options) { + if (options_ != nullptr) { + options_->SetFirstGraph(true); + } + } + ~FirstGraphModeGuard() { + if (options_ != nullptr) { + options_->SetFirstGraph(false); + } + } + + private: + std::shared_ptr options_; + }; + MS_LOG(INFO) << "Start MultiGraph Compile."; + auto kernel_graph = ConstructKernelGraph(lst, outputs, false); + MS_EXCEPTION_IF_NULL(kernel_graph); + ModelConverter model_converter_; + model_converter_.set_options(options_); + FirstGraphModeGuard guard(options_); + auto om_data = model_converter_.LoadMindIR(kernel_graph); + if (om_data.Data() == nullptr || om_data.DataSize() == 0) { + MS_LOG(ERROR) << "Load MindIR failed."; + return kMCFailed; + } + std::shared_ptr graph = std::make_shared(std::make_shared(om_data, ModelType::kOM)); + MS_EXCEPTION_IF_NULL(graph); + graphs_[kernel_graph->graph_id()] = GraphCell(graph); + MS_LOG(INFO) << "Mulit graph compile success, graph id " << kernel_graph->graph_id(); + return kernel_graph->graph_id(); +} + +void MultiGraphAclSession::RunGraph(GraphId graph_id, const std::vector &inputs, VectorRef *outputs) { + MS_EXCEPTION_IF_NULL(outputs); + MS_LOG(INFO) << "Start run graph " << graph_id; + auto iter = graphs_.find(graph_id); + if (iter == graphs_.end()) { + MS_LOG(EXCEPTION) << "Graph id " << graph_id << " not found."; + } + std::vector out_tensors; + auto ret = iter->second.Run(inputs, &out_tensors); + if (ret != kSuccess) { + MS_LOG(EXCEPTION) << "Graph id " << graph_id << " run failed."; + } + (*outputs) = MSTensorRef::Convert(out_tensors); +} + +class AclBackend : public compile::MsBackend { + public: + AclBackend(const std::string &name, const std::string &target, const std::shared_ptr &options) + : MsBackend(name, target, options->GetDeviceID()) { + auto session = std::dynamic_pointer_cast(MsBackend::target_sess_); + MS_EXCEPTION_IF_NULL(session); + session->SetOptions(options); + } + + ~AclBackend() override = default; + + VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) override { + std::vector inputs; + for (const auto &arg : args) { + if (!utils::isa(arg)) { + MS_LOG(EXCEPTION) << "Invalid item " << arg.ToString(); + } + auto wrapper = utils::cast(arg); + inputs.emplace_back(wrapper.GetTensor()); + } + + VectorRef outputs; + MS_EXCEPTION_IF_NULL(target_sess_); + auto exec_sess = std::dynamic_pointer_cast(target_sess_); + MS_EXCEPTION_IF_NULL(exec_sess); + exec_sess->RunGraph(g, inputs, &outputs); + return outputs; + } + + bool GetCond(const BaseRef &c, bool *value) override { + MS_EXCEPTION_IF_NULL(value); + if (!utils::isa(c)) { + MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef."; + return false; + } + auto wrapper = utils::cast(c); + if (wrapper.GetTensor().DataType() != DataType::kNumberTypeBool) { + MS_LOG(ERROR) << "Invalid data type " << wrapper.GetTensor().DataType() << " must be bool."; + return false; + } + auto data = wrapper.GetTensor().Data(); + if (data == nullptr) { + return false; + } + (*value) = *reinterpret_cast(data.get()); + return true; + } + + bool GetIndex(const BaseRef &c, int64_t *value) override { + MS_EXCEPTION_IF_NULL(value); + if (!utils::isa(c)) { + MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef."; + return false; + } + + auto wrapper = utils::cast(c); + if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt32) { + auto data = wrapper.GetTensor().Data(); + if (data == nullptr) { + return false; + } + auto value_int32 = *reinterpret_cast(data.get()); + (*value) = static_cast(value_int32); + return true; + } else if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt64) { + auto data = wrapper.GetTensor().Data(); + if (data == nullptr) { + return false; + } + (*value) = *reinterpret_cast(data.get()); + return true; + } else { + MS_LOG(ERROR) << "Index must be Int type."; + return false; + } + } +}; + +class AclCompileGraph : public compile::CompileGraph { + public: + explicit AclCompileGraph(const std::shared_ptr &backend, + const std::vector &cut_list) + : CompileGraph(backend, cut_list) {} + ~AclCompileGraph() override = default; + + void AddInst(const compile::Instruction &inst, const MSTensorRef &arg) { + VectorRef args; + args.push_back(arg); + compile::CompileGraph::AddInst(inst, args); + } + + int64_t Ref(const AnfNodePtr &node) override { + MS_EXCEPTION_IF_NULL(node); + MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_; + if (slots_.count(node) == 0 && node->isa()) { + if (IsValueNode(node)) { + MS_LOG(DEBUG) << "Push graph."; + compile::CompileGraph::AddInst(compile::Instruction::kGraph, GetValueNode(node)); + } else { + MS_LOG(DEBUG) << "Push."; + if (IsValueNode(node)) { + MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + } else if (IsValueNode(node)) { + auto tensor_node = std::dynamic_pointer_cast(node->cast()->value()); + MS_EXCEPTION_IF_NULL(tensor_node); + std::string name = ""; + std::vector shape = tensor_node->shape_c(); + DataType type = static_cast(tensor_node->data_type_c()); + auto mstensor_node = MSTensor::CreateRefTensor(name, type, shape, tensor_node->data_c(), tensor_node->Size()); + MSTensorRef mstensor_ref(*mstensor_node); + AddInst(compile::Instruction::kPush, mstensor_ref); + MSTensor::DestroyTensorPtr(mstensor_node); + } else { + compile::CompileGraph::AddInst(compile::Instruction::kPush, GetValueNode(node)); + } + } + Push(node); + } + MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node] + << ", return: " << slots_[node] - height_; + return slots_[node] - height_; + } +}; + +class AclCompileGraphs : public compile::CompileGraphs { + public: + explicit AclCompileGraphs(const std::shared_ptr &backend, + const std::vector &cut_list) + : CompileGraphs(backend, cut_list) { + MS_EXCEPTION_IF_NULL(backend); + MS_LOG(DEBUG) << "Start vm: " << backend->name(); + transform_ = std::make_shared(backend, cut_list); + Reset(); + } + ~AclCompileGraphs() override = default; + void Compile(const FuncGraphPtr &graph) override { + MS_LOG(DEBUG) << "Start"; + mapping_[graph] = SizeToLong(insts_.size()); + if (transform_ != nullptr) { + auto insts = transform_->Run(graph, false); + if (!insts.empty()) { + (void)insts_.insert(insts_.end(), insts.begin(), insts.end()); + } + } + MS_LOG(DEBUG) << "End"; + } +}; + +std::shared_ptr CreateBackend(const std::shared_ptr &options) { + MS_EXCEPTION_IF_NULL(options); + return std::make_shared(kMsConvert, kDavinciMultiGraphInferenceDevice, options); +} + +bool HasMultiGraph(const FuncGraphPtr &fg) { + MS_EXCEPTION_IF_NULL(fg); + std::vector all_nodes = TopoSort(fg->get_return()); + for (const auto &node : all_nodes) { + MS_EXCEPTION_IF_NULL(node); + if (IsValueNode(node)) { + MS_LOG(INFO) << fg->ToString() << " has FuncGraph node " << node->DebugString() << " is multi graph."; + return true; + } + } + return false; +} +} // namespace +Status AclModelMulti::Build() { + if (!is_multi_graph_.has_value()) { + is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph()); + } + + if (!is_multi_graph_.value()) { + return AclModel::Build(); + } + + if (vm_ != nullptr) { + MS_LOG(INFO) << "Multi graph model has been built, skip."; + return kSuccess; + } + MS_LOG(INFO) << "Start build multi graph model."; + // perpare func graph + auto manager = MakeManager(); + manager->AddFuncGraph(ModelImpl::GetFuncGraph()); + ModelImpl::GetFuncGraph()->set_manager(manager); + // set inputs + SetInputs(); + // infer mindir + abstract::AbstractBasePtrList broaded_args; + auto fg = ModelImpl::GetFuncGraph(); + MS_EXCEPTION_IF_NULL(fg); + const auto &inputs = fg->get_inputs(); + (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(broaded_args), + [](const AnfNodePtr &n) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(n); + auto abstract = n->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + if (abstract->GetValueTrack() != kAnyValue) { + return abstract->Broaden(); + } + return abstract; + }); + (void)InferMindir(ModelImpl::GetFuncGraph(), broaded_args); + // create vm + auto backend = CreateBackend(std::make_shared(model_context_)); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + backend->set_is_multi_graph_sink(false); + context_ptr->set_param(MS_CTX_DEVICE_TARGET, kDavinciMultiGraphInferenceDevice); + context_ptr->set_param(MS_CTX_IS_MULTI_GRAPH_SINK, false); + context_ptr->set_param(MS_CTX_ENABLE_LOOP_SINK, false); + auto compile = std::make_shared(backend, compile::GetMsNonlinearOps()); + + vm_ = compile->CompileAndLink(ModelImpl::GetFuncGraph()); + backend_ = std::move(backend); + MS_LOG(INFO) << "Build multi graph model success."; + return kSuccess; +} + +Status AclModelMulti::Predict(const std::vector &inputs, std::vector *outputs) { + if (!is_multi_graph_.has_value()) { + is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph()); + } + + if (!is_multi_graph_.value()) { + return AclModel::Predict(inputs, outputs); + } + + Build(); + MS_LOG(INFO) << "Start predict multi graph model."; + MS_EXCEPTION_IF_NULL(vm_); + MS_EXCEPTION_IF_NULL(outputs); + try { + (*outputs) = MSTensorRef::Convert(vm_->Eval(MSTensorRef::Convert(inputs))); + } catch (const std::exception &ex) { + MS_LOG(ERROR) << "Predict Failed, error: " << ex.what(); + return kMCFailed; + } + + if (inputs.size() != inputs_.size() && !inputs_.empty() != 0) { + MS_LOG(ERROR) << "Input Size is wrong."; + return kMCFailed; + } + + if (inputs_.empty()) { + inputs_ = inputs; + } else { + for (size_t i = 0; i < inputs_.size(); ++i) { + auto input_tensor = MSTensor::CreateTensor(inputs_[i].Name(), inputs_[i].DataType(), inputs_[i].Shape(), + inputs[i].Data().get(), inputs[i].DataSize()); + inputs_[i] = (*input_tensor); + MSTensor::DestroyTensorPtr(input_tensor); + } + } + + outputs_ = *outputs; + MS_LOG(INFO) << "Predict multi graph model success."; + return kSuccess; +} + +void AclModelMulti::SetInputs() { + if (inputs_.empty()) { + auto fg = ModelImpl::GetFuncGraph(); + MS_EXCEPTION_IF_NULL(fg); + const auto &inputs = fg->get_inputs(); + for (const auto &in : inputs) { + auto input_param = std::dynamic_pointer_cast(in); + MS_EXCEPTION_IF_NULL(input_param); + MS_EXCEPTION_IF_NULL(input_param->abstract()); + auto input_value = input_param->abstract()->GetValueTrack(); + auto tensor = input_value->cast(); + MS_EXCEPTION_IF_NULL(tensor); + + std::vector shape = tensor->shape_c(); + auto input_tensor = MSTensor::CreateTensor(input_param->name(), static_cast(tensor->data_type_c()), + shape, nullptr, tensor->Size()); + inputs_.emplace_back(*input_tensor); + MSTensor::DestroyTensorPtr(input_tensor); + } + } else { + MS_LOG(DEBUG) << "inputs_ has been set."; + } +} + +std::vector AclModelMulti::GetInputs() { + if (!is_multi_graph_.has_value()) { + is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph()); + } + + if (!is_multi_graph_.value()) { + return AclModel::GetInputs(); + } + + return inputs_; +} + +std::vector AclModelMulti::GetOutputs() { + if (!is_multi_graph_.has_value()) { + is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph()); + } + + if (!is_multi_graph_.value()) { + return AclModel::GetOutputs(); + } + + return outputs_; +} + +namespace session { +MS_REG_SESSION(kDavinciMultiGraphInferenceDevice, MultiGraphAclSession); +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.h b/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.h new file mode 100644 index 00000000000..11f177f5b09 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.h @@ -0,0 +1,52 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_CXX_API_ACL_MODEL_MULTI_H +#define MINDSPORE_CCSRC_CXX_API_ACL_MODEL_MULTI_H + +#include "cxx_api/model/acl/acl_model.h" +#include +#include +#include +#include + +namespace mindspore { +namespace compile { +class MsBackend; +class FinalVM; +} // namespace compile + +class AclModelMulti : public AclModel { + public: + AclModelMulti() : AclModel(), is_multi_graph_(std::nullopt) {} + ~AclModelMulti() = default; + + Status Build() override; + Status Predict(const std::vector &inputs, std::vector *outputs) override; + + std::vector GetInputs() override; + std::vector GetOutputs() override; + + private: + void SetInputs(); + + std::optional is_multi_graph_; + std::shared_ptr backend_; + std::shared_ptr vm_; + std::vector inputs_ = {}; + std::vector outputs_ = {}; +}; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_MULTI_H diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc index c83ca6af477..a5fd72066f7 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc @@ -14,6 +14,7 @@ * limitations under the License. */ #include "cxx_api/model/acl/acl_model_options.h" +#include #include #include "utils/log_adapter.h" #include "external/ge/ge_api_types.h" @@ -54,7 +55,6 @@ AclModelOptions::AclModelOptions(const std::shared_ptr &context) { op_select_impl_mode_ = ascend310_info->GetOpSelectImplMode(); fusion_switch_cfg_path_ = ascend310_info->GetFusionSwitchConfigPath(); device_id_ = ascend310_info->GetDeviceID(); - dump_cfg_path_ = ascend310_info->GetDumpConfigPath(); buffer_optimize_mode_ = ascend310_info->GetBufferOptimizeMode(); const char *soc_name = aclrtGetSocName(); if (soc_name == nullptr) { @@ -98,6 +98,14 @@ std::tuple, std::map first_graph_options = { + ge::ir_option::INSERT_OP_FILE, + ge::ir_option::INPUT_FORMAT, + ge::ir_option::INPUT_SHAPE, + }; + + const std::set multi_graph_unsupported_options = {ge::ir_option::OUTPUT_TYPE}; + std::map init_options; std::map build_options; for (auto [ms_option, acl_option_key] : init_options_map) { @@ -115,6 +123,20 @@ std::tuple, std::map #include #include +#include #include "include/api/types.h" #include "include/api/status.h" #include "include/api/context.h" @@ -32,11 +33,11 @@ class AclModelOptions { ~AclModelOptions() = default; std::string GenAclOptionsKey() const; uint32_t GetDeviceID() const { return device_id_; } - std::string GetDumpCfgPath() const { return dump_cfg_path_; } - void RenameInput(const std::vector &name); + void RenameInput(const std::vector &); // return tuple std::tuple, std::map> GenAclOptions() const; + void SetFirstGraph(bool is_first_graph) { first_graph_flag_ = is_first_graph; } private: std::string output_node_; // todo: at convert.cc::BuildGraph(), no atc options @@ -55,7 +56,7 @@ class AclModelOptions { std::map> input_shape_map_; // other options uint32_t device_id_; - std::string dump_cfg_path_; + std::optional first_graph_flag_; }; } // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc b/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc index e25809bfe3f..94ba23c9066 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc +++ b/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc @@ -226,8 +226,9 @@ Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) { std::map init_options; std::map build_options; - if (options_ != nullptr) { - std::tie(init_options, build_options) = options_->GenAclOptions(); + auto option = options_.lock(); + if (option != nullptr) { + std::tie(init_options, build_options) = option->GenAclOptions(); } return BuildAirModel(df_graph, init_options, build_options); diff --git a/mindspore/ccsrc/cxx_api/model/acl/model_converter.h b/mindspore/ccsrc/cxx_api/model/acl/model_converter.h index e9652a10665..dfd54fbded4 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/model_converter.h +++ b/mindspore/ccsrc/cxx_api/model/acl/model_converter.h @@ -30,12 +30,12 @@ namespace mindspore { class ModelConverter { public: - ModelConverter() : options_(nullptr) {} + ModelConverter() : options_() {} ~ModelConverter() = default; Buffer LoadMindIR(const FuncGraphPtr &func_graph); - void set_options(AclModelOptions *options) { options_ = options; } + void set_options(const std::weak_ptr &options) { options_ = options; } private: transform::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph); @@ -43,7 +43,7 @@ class ModelConverter { const std::map &build_options); Buffer LoadAscendIRInner(const Buffer &model_data); - AclModelOptions *options_; + std::weak_ptr options_; }; } // namespace mindspore #endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H diff --git a/mindspore/ccsrc/minddata/dataset/core/ascend_resource.cc b/mindspore/ccsrc/minddata/dataset/core/ascend_resource.cc index e896f983d39..cdae60c242e 100644 --- a/mindspore/ccsrc/minddata/dataset/core/ascend_resource.cc +++ b/mindspore/ccsrc/minddata/dataset/core/ascend_resource.cc @@ -24,7 +24,6 @@ namespace dataset { Status AscendResource::InitResource(uint32_t device_id) { ResourceInfo resource; - resource.aclConfigPath = ""; resource.deviceIds.insert(device_id); ascend_resource_ = ResourceManager::GetInstance(); APP_ERROR ret = ascend_resource_->InitResource(resource); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_crop_jpeg_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_crop_jpeg_op.cc index 9f6f88267af..20d0d8932a8 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_crop_jpeg_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_crop_jpeg_op.cc @@ -77,7 +77,6 @@ Status DvppCropJpegOp::Compute(const std::shared_ptr &input, std::shared imageinfo.heightStride = yuv_shape_[3]; imageinfo.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420; ResourceInfo resource; - resource.aclConfigPath = ""; resource.deviceIds.insert(0); std::shared_ptr instance = ResourceManager::GetInstance(); APP_ERROR ret = instance->InitResource(resource); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_jpeg_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_jpeg_op.cc index f3f827ae36a..8452c00d9ef 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_jpeg_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_jpeg_op.cc @@ -70,7 +70,6 @@ Status DvppDecodeJpegOp::Compute(const std::shared_ptr &input, std::shar imageInfo.lenOfByte = filesize; imageInfo.data = static_cast(buffer); ResourceInfo resource; - resource.aclConfigPath = ""; resource.deviceIds.insert(0); std::shared_ptr instance = ResourceManager::GetInstance(); APP_ERROR ret = instance->InitResource(resource); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_png_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_png_op.cc index b3a6cbda30f..702323d98d5 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_png_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_png_op.cc @@ -68,7 +68,6 @@ Status DvppDecodePngOp::Compute(const std::shared_ptr &input, std::share imageInfo.lenOfByte = filesize; imageInfo.data = static_cast(buffer); ResourceInfo resource; - resource.aclConfigPath = ""; resource.deviceIds.insert(0); std::shared_ptr instance = ResourceManager::GetInstance(); APP_ERROR ret = instance->InitResource(resource); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.cc index 4fd7b1b9ef1..d950a980d90 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.cc @@ -69,7 +69,6 @@ Status DvppDecodeResizeCropJpegOp::Compute(const std::shared_ptr &input, imageInfo.lenOfByte = filesize; imageInfo.data = static_cast(buffer); ResourceInfo resource; - resource.aclConfigPath = ""; resource.deviceIds.insert(0); std::shared_ptr instance = ResourceManager::GetInstance(); APP_ERROR ret = instance->InitResource(resource); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_jpeg_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_jpeg_op.cc index cf39305745f..67fbfa3316d 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_jpeg_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_jpeg_op.cc @@ -68,7 +68,6 @@ Status DvppDecodeResizeJpegOp::Compute(const std::shared_ptr &input, std imageInfo.lenOfByte = filesize; imageInfo.data = static_cast(buffer); ResourceInfo resource; - resource.aclConfigPath = ""; resource.deviceIds.insert(0); std::shared_ptr instance = ResourceManager::GetInstance(); APP_ERROR ret = instance->InitResource(resource); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_resize_jpeg_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_resize_jpeg_op.cc index c2b093fdab5..4732eadbe4e 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_resize_jpeg_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_resize_jpeg_op.cc @@ -78,7 +78,6 @@ Status DvppResizeJpegOp::Compute(const std::shared_ptr &input, std::shar imageinfo.heightStride = yuv_shape_[3]; imageinfo.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420; ResourceInfo resource; - resource.aclConfigPath = ""; resource.deviceIds.insert(0); std::shared_ptr instance = ResourceManager::GetInstance(); APP_ERROR ret = instance->InitResource(resource); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.cc index b80126ef54d..7dbeb4b6924 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.cc @@ -85,22 +85,11 @@ APP_ERROR ResourceManager::InitResource(ResourceInfo &resourceInfo) { MS_LOG(INFO) << "Acl has been initialized, skip."; return APP_ERR_OK; } - std::string &aclConfigPath = resourceInfo.aclConfigPath; APP_ERROR ret = APP_ERR_OK; - if (aclConfigPath.length() == 0) { - // Init acl without aclconfig - acl_env_ = mindspore::AclEnvGuard::GetAclEnv(""); - } else { - ret = ExistFile(aclConfigPath); - if (ret != APP_ERR_OK) { - MS_LOG(ERROR) << "Acl config file not exist, ret = " << ret << "."; - return ret; - } - acl_env_ = mindspore::AclEnvGuard::GetAclEnv(aclConfigPath); - } + acl_env_ = mindspore::AclEnvGuard::GetAclEnv(); if (acl_env_ == nullptr) { MS_LOG(ERROR) << "Failed to init acl."; - return ret; + return APP_ERR_COMM_FAILURE; } std::copy(resourceInfo.deviceIds.begin(), resourceInfo.deviceIds.end(), std::back_inserter(deviceIds_)); MS_LOG(INFO) << "Initialized acl successfully."; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.h index daed1f9faed..467f7b64522 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.h @@ -53,7 +53,6 @@ struct DeviceResInfo { struct ResourceInfo { std::set deviceIds; - std::string aclConfigPath; std::string singleOpFolderPath; std::unordered_map deviceResInfos; // map }; diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc index 279295e8a63..0c24492553b 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.cc +++ b/mindspore/ccsrc/transform/graph_ir/convert.cc @@ -1872,7 +1872,6 @@ OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) { MS_LOG(WARNING) << "set attr value for const failed"; } -#if (defined ENABLE_GE) auto const_op = std::static_pointer_cast(op); if (const_op == nullptr) { MS_LOG(ERROR) << "Get Constant operator failed"; @@ -1881,7 +1880,6 @@ OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) { auto ge_tensor = const_op->get_attr_value(); auto ge_desc = ge_tensor.GetTensorDesc(); (void)const_op->update_output_desc_y(ge_desc); -#endif op_cache_[node.get()] = op; return op_cache_[node.get()]; diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter.cc b/mindspore/ccsrc/transform/graph_ir/op_adapter.cc index 2fe0f297143..c93fad4cf0c 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_adapter.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter.cc @@ -317,11 +317,6 @@ size_t OpAdapterImpl::GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { std::shared_ptr OpAdapterImpl::CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type, const std::string &format) { - if (shape_ptr == nullptr) { - MS_LOG(ERROR) << "Shape ptr is nullptr"; - return nullptr; - } - if (type == nullptr) { MS_LOG(ERROR) << "Type ptr is nullptr"; return nullptr; @@ -331,8 +326,8 @@ std::shared_ptr OpAdapterImpl::CreateOutputDesc(const abstract::Sh if (kObjectTypeTensorType == me_type) { me_type = dyn_cast(type)->element()->type_id(); } - auto desc = TransformUtil::GetGeTensorDesc(shape_ptr->shape(), me_type, format); - return desc; + + return TransformUtil::GetGeTensorDesc((shape_ptr == nullptr) ? ShapeVector{} : shape_ptr->shape(), me_type, format); } Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, @@ -472,7 +467,7 @@ void OpAdapterImpl::updateOutputDesc(const OperatorPtr &op, const abstract::Base return; } MS_EXCEPTION_IF_NULL(node); - MS_LOG(INFO) << "Op name is " << op->GetName(); + MS_LOG(INFO) << "Op name is " << op->GetName() << " anf is " << node->DebugString(); auto normal_shape_ptr = dyn_cast(shp); auto no_shape_ptr = dyn_cast(shp); diff --git a/mindspore/ccsrc/vm/backend.h b/mindspore/ccsrc/vm/backend.h index ba6e025c99e..1a3470f942a 100644 --- a/mindspore/ccsrc/vm/backend.h +++ b/mindspore/ccsrc/vm/backend.h @@ -78,7 +78,7 @@ class MsBackend : public Backend { ~MsBackend() override = default; LinConvertResult MsConvert(const GraphSegmentPtr &segment, const std::string &target = ""); - VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = ""); + virtual VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = ""); VectorRef MsSimuRunGraph(const GraphId &g); GraphId CompileGraph(NotNull fg) override; @@ -90,7 +90,7 @@ class MsBackend : public Backend { void SetDebugger() override; #endif - private: + protected: session::SessionPtr target_sess_; session::SessionPtr other_sess_; std::string target_device_; diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 3f3f670859e..3d518cb6610 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -145,6 +145,22 @@ void CompileGraph::PushParameters(const FuncGraphPtr &graph) { } } +void CompileGraph::PushInputs(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + std::vector parameters = graph->parameters(); + for (size_t i = parameters.size(); i != 0; i--) { + MS_EXCEPTION_IF_NULL(parameters[i - 1]); + auto param = parameters[i - 1]->cast(); + MS_EXCEPTION_IF_NULL(param); + if (param->has_default()) { + MS_LOG(DEBUG) << "Parameter " << (i - 1) << ": " << param->DebugString() << " has default value, skip."; + continue; + } + Push(param); + MS_LOG(DEBUG) << "Push parameter " << (i - 1) << ": " << param->DebugString(true); + } +} + int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) { MS_EXCEPTION_IF_NULL(segment); MS_LOG(DEBUG) << "LinConvert start"; @@ -257,11 +273,16 @@ bool CompileGraph::Compile(const FuncGraphPtr &graph) { return true; } -InstSet CompileGraph::Run(const FuncGraphPtr &graph) { +InstSet CompileGraph::Run(const FuncGraphPtr &graph, bool push_weight) { MS_EXCEPTION_IF_NULL(graph); Reset(); - PushParameters(graph); + if (push_weight) { + PushParameters(graph); + } else { + PushInputs(graph); + } + int64_t param_height = height_; MS_EXCEPTION_IF_NULL(graph->get_return()); MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true); diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h index 4f0e19f422f..da7d853666f 100644 --- a/mindspore/ccsrc/vm/transform.h +++ b/mindspore/ccsrc/vm/transform.h @@ -53,14 +53,14 @@ class CompileGraph { public: explicit CompileGraph(const BackendPtr &backend, const std::vector &cut_list = nonlinear_ops); - ~CompileGraph() = default; + virtual ~CompileGraph() = default; - InstSet Run(const FuncGraphPtr &func_graph); + InstSet Run(const FuncGraphPtr &func_graph, bool push_weight = true); bool IsCut(const AnfNodePtr &node); void Push(const AnfNodePtr &node); void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; } void Ret(int64_t nargs); - int64_t Ref(const AnfNodePtr &node); + virtual int64_t Ref(const AnfNodePtr &node); void set_height(int64_t h) { height_ = h; @@ -76,8 +76,9 @@ class CompileGraph { inst_.clear(); } - private: + protected: void PushParameters(const FuncGraphPtr &func_graph); + void PushInputs(const FuncGraphPtr &graph); bool Compile(const FuncGraphPtr &func_graph); int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = ""); int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); @@ -114,18 +115,18 @@ class CompileGraphs { public: explicit CompileGraphs(const BackendPtr &backend, const std::vector &cut_list = nonlinear_ops); - ~CompileGraphs() = default; + virtual ~CompileGraphs() = default; void Reset() { insts_.clear(); mapping_.clear(); } - void Compile(const FuncGraphPtr &func_graph); + virtual void Compile(const FuncGraphPtr &func_graph); FinalVMPtr Link(); FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph); - private: + protected: InstSet insts_; std::unordered_map mapping_; CompileGraphPtr transform_; diff --git a/mindspore/core/abstract/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc index 9c24080c4e0..3e9dccfd75c 100644 --- a/mindspore/core/abstract/prim_maths.cc +++ b/mindspore/core/abstract/prim_maths.cc @@ -424,17 +424,11 @@ AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &pri ShapeVector y_shape_max = y->shape()->max_shape().empty() ? y_shape : y->shape()->max_shape(); auto out_shape = BroadcastShape(x_shape, y_shape); - if (out_shape.empty()) { - MS_LOG(EXCEPTION) << "Less op BroadcastShape fail: " << args_spec_list[0]->ToString() << "," - << args_spec_list[1]->ToString(); - } auto out_shape_min = BroadcastShape(x_shape_min, y_shape_min); auto out_shape_max = BroadcastShape(x_shape_max, y_shape_max); - auto output_type = std::make_shared(); - auto ret = - std::make_shared(output_type, std::make_shared(out_shape, out_shape_min, out_shape_max)); - return ret; + return std::make_shared(output_type, + std::make_shared(out_shape, out_shape_min, out_shape_max)); } } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/load_mindir/infer_mindir.cc b/mindspore/core/load_mindir/infer_mindir.cc new file mode 100644 index 00000000000..fe10bc01f52 --- /dev/null +++ b/mindspore/core/load_mindir/infer_mindir.cc @@ -0,0 +1,503 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "load_mindir/infer_mindir.h" +#include +#include +#include +#include +#include +#include +#include "ir/func_graph.h" +#include "abstract/abstract_function.h" +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" + +namespace mindspore { +namespace { +class MindIREngine { + public: + explicit MindIREngine(const FuncGraphPtr &root) : func_graph_(root), nodeuser_map_(root->manager()->node_users()) {} + ~MindIREngine() = default; + MindIREngine(const MindIREngine &) = delete; + MindIREngine &operator=(const MindIREngine &) = delete; + + bool InferShape(const AbstractBasePtrList &args); + + private: + using AbstractBasePtrListPtr = std::shared_ptr; + + void Init(const AbstractBasePtrList &args); + static AbstractBasePtr InferPrimitiveShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list); + void EvalCommonPrimitive(const PrimitivePtr &prim, const CNodePtr &node, const AbstractBasePtrListPtr &args); + void EvalPartialPrimitive(const PrimitivePtr &prim, const CNodePtr &node, const AbstractBasePtrListPtr &args); + void EvalReturnPrimitive(const PrimitivePtr &prim, const CNodePtr &node); + void InferParameter(const AnfNodePtr &node); + void InferValueNode(const AnfNodePtr &node); + void InferCNode(const AnfNodePtr &node); + void EvalAbstractFunction(const abstract::AbstractFuncAtomPtr &abstractFunc, const CNodePtr &node, + const AbstractBasePtrListPtr &args); + void EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr &func, const CNodePtr &node, + const AbstractBasePtrListPtr &args); + void EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr &func, const CNodePtr &node, + const AbstractBasePtrListPtr &args); + void EvalPartialAbastract(const abstract::PartialAbstractClosurePtr &func, const CNodePtr &node, + const AbstractBasePtrListPtr &args); + bool CheckCNodeNotReady(const CNodePtr &node); + void UpdateReady(const AnfNodePtr &node); + void SaveNodeInferResult(const AnfNodePtr &node, const AbstractBasePtr &result); + AbstractBasePtr GetCNodeOperatorAbstract(const AnfNodePtr &node); + + FuncGraphPtr func_graph_; + std::map node_input_depends_; + std::map infer_resut_; + std::map func_graph_result_; + std::map> func_graph_visited_; + std::deque ready_; + std::set todo_; + NodeUsersMap nodeuser_map_; +}; + +// Infer the root function graph. +bool MindIREngine::InferShape(const AbstractBasePtrList &args) { + Init(args); + while (!ready_.empty()) { + auto current = ready_.front(); + MS_EXCEPTION_IF_NULL(current); + ready_.pop_front(); + if (current->isa()) { + InferCNode(current); + } else if (current->isa()) { + InferValueNode(current); + } else if (current->isa()) { + InferParameter(current); + } else { + MS_LOG(WARNING) << " There is something changed. Please check the code."; + } + } + + // Set abstract of node. + for (const auto &item : infer_resut_) { + item.first->set_abstract(item.second); + } + + if (todo_.empty()) { + MS_LOG(DEBUG) << "Finish to Infere."; + return true; + } + MS_LOG(WARNING) << "Not finished to infer: " << todo_.size(); + for (const auto &node : todo_) { + MS_LOG(DEBUG) << "Node uninfered: " << node->DebugString(); + } + return false; +} + +void MindIREngine::Init(const AbstractBasePtrList &args) { + MS_EXCEPTION_IF_NULL(func_graph_); + auto manager = func_graph_->manager(); + MS_EXCEPTION_IF_NULL(manager); + for (const auto &node : manager->all_nodes()) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + todo_.insert(node); + node_input_depends_[node] = cnode->inputs().size(); + } else if (node->isa()) { + auto param = node->cast(); + MS_EXCEPTION_IF_NULL(param); + if (param->has_default()) { + node_input_depends_[node] = 0; + infer_resut_[node] = param->default_param()->ToAbstract(); + ready_.push_back(node); + } else { + node_input_depends_[node] = 1; + todo_.insert(node); + } + } else { + // Value Node + node_input_depends_[node] = 0; + ready_.push_back(node); + } + } + + auto inputs = func_graph_->get_inputs(); + if (inputs.size() != args.size()) { + MS_LOG(EXCEPTION) << "The input parameters is not Compatible. mindir:" << inputs.size() + << " inputs: " << args.size() << " FuncGraph:" << func_graph_->ToString(); + } + // Root Func Parameters + for (size_t i = 0; i < args.size(); ++i) { + this->SaveNodeInferResult(inputs[i], args[i]); + } + MS_LOG(DEBUG) << "Finish init. Size of nodes:" << manager->all_nodes().size(); +} + +// Infer primitive using C++ implement. +AbstractBasePtr MindIREngine::InferPrimitiveShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) { + MS_EXCEPTION_IF_NULL(prim); + try { + static auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap(); + auto ret = prim_eval_implement_map.find(prim); + if (ret != prim_eval_implement_map.end()) { + // fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr + MS_EXCEPTION_IF_NULL(ret->second.infer_shape_impl_); + return ret->second.infer_shape_impl_(nullptr, prim, args_spec_list); + } else { + // if the infer function has been not founded in the front infer map find it in the backend infer map instead + static auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap(); + auto ret_backend = prim_backend_eval_impl_map.find(prim); + if (ret_backend != prim_backend_eval_impl_map.end()) { + MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_); + return ret_backend->second.infer_shape_impl_(nullptr, prim, args_spec_list); + } + } + MS_LOG(WARNING) << "Get infer shape function failed, primitive name:" << prim->name() + << " primitive type:" << prim->type_name() << " It will keep the prevalue witch danger."; + } catch (const std::exception &ex) { + MS_LOG(WARNING) << "Catch primitive:" << prim->ToString() << " InferPrimitiveShape exception:" << ex.what() + << " It will keep the prevalue witch danger."; + } + return nullptr; +} + +void MindIREngine::EvalCommonPrimitive(const PrimitivePtr &prim, const CNodePtr &node, + const AbstractBasePtrListPtr &args) { + AbstractBasePtrList args_spec_list; + // Args has been resolved by partial + if (args != nullptr) { + args_spec_list.insert(args_spec_list.end(), args->begin(), args->end()); + } else { + (void)std::transform(node->inputs().begin() + 1, node->inputs().end(), std::back_inserter(args_spec_list), + [this](const AnfNodePtr &arg) { return infer_resut_[arg]; }); + } + + // Call C++ infer + auto result = InferPrimitiveShape(prim, args_spec_list); + if (result == nullptr) { + MS_LOG(INFO) << node->ToString() + << " can't be inferred shape. It will keep the prevalue witch danger. Prim: " << prim->ToString(); + result = node->abstract(); + } + SaveNodeInferResult(node, result); +} + +void MindIREngine::EvalReturnPrimitive(const PrimitivePtr &prim, const CNodePtr &node) { + if (node->inputs().size() < 2) { + MS_LOG(EXCEPTION) << node->DebugString() << " input size < 2"; + } + auto result = infer_resut_[node->inputs()[1]]; + auto funcName = node->func_graph()->ToString(); + auto it = func_graph_result_.find(funcName); + if (it != func_graph_result_.end()) { + result = result->Join(it->second); + } + this->func_graph_result_[funcName] = result; + SaveNodeInferResult(node, result); + MS_LOG(DEBUG) << funcName << " result: " << result->ToString(); + + // Set the result of the node whose Operator is this funcGraph + for (const auto &item : func_graph_visited_[funcName]) { + SaveNodeInferResult(item, result); + } +} + +void MindIREngine::EvalPartialPrimitive(const PrimitivePtr &prim, const CNodePtr &node, + const AbstractBasePtrListPtr &args) { + // Args has been resolved + if (args != nullptr) { + if (args->size() < 2) { + MS_LOG(EXCEPTION) << node->DebugString() << " input size < 2"; + } + auto real_func = (*args)[0]->cast(); + if (real_func == nullptr) { + MS_LOG(EXCEPTION) << (*args)[0]->ToString() << " is not a function abstract."; + } + AbstractBasePtrList partial_args_list; + partial_args_list.insert(partial_args_list.end(), args->begin() + 1, args->end()); + auto partial_func = std::make_shared(real_func, partial_args_list, node); + SaveNodeInferResult(node, partial_func); + return; + } + // Not Resolved. + if (node->inputs().size() < 2) { + MS_LOG(EXCEPTION) << node->DebugString() << " input size < 2"; + } + auto &func = infer_resut_[node->inputs()[1]]; + auto real_func = func->cast(); + if (real_func == nullptr) { + MS_LOG(EXCEPTION) << func->ToString() << " is not a function abstract."; + } + AbstractBasePtrList partial_args_list; + (void)std::transform(node->inputs().begin() + 2, node->inputs().end(), std::back_inserter(partial_args_list), + [this](const AnfNodePtr &arg) { return infer_resut_[arg]; }); + auto partial_func = std::make_shared(real_func, partial_args_list, node); + SaveNodeInferResult(node, partial_func); +} + +void MindIREngine::EvalPartialAbastract(const abstract::PartialAbstractClosurePtr &func, const CNodePtr &node, + const AbstractBasePtrListPtr &args) { + AbstractBasePtrListPtr partial_args_list = std::make_shared(); + // Join arguments in partial and the rest arguments from args_conf_list. + auto func_args = func->args(); + partial_args_list->insert(partial_args_list->end(), func_args.begin(), func_args.end()); + if (args == nullptr) { + // Not Recursive + (void)std::transform(node->inputs().begin() + 1, node->inputs().end(), std::back_inserter(*partial_args_list), + [this](const AnfNodePtr &arg) { return infer_resut_[arg]; }); + } else { + // Recursive + partial_args_list->insert(partial_args_list->end(), args->begin(), args->end()); + } + + // Get real function + abstract::AbstractFuncAtomPtrList abstractFuncList; + auto build_fuction = [&abstractFuncList](const abstract::AbstractFuncAtomPtr &poss) { + abstractFuncList.push_back(poss); + }; + func->fn()->Visit(build_fuction); + for (const auto &abstractFunc : abstractFuncList) { + EvalAbstractFunction(abstractFunc, node, partial_args_list); + } +} + +void MindIREngine::SaveNodeInferResult(const AnfNodePtr &node, const AbstractBasePtr &result) { + auto answer = result; + auto it = infer_resut_.find(node); + if (it != infer_resut_.end()) { + MS_LOG(DEBUG) << node->ToString() << " result: " << it->second->ToString(); + answer = result->Join(it->second); + if (*answer == *(it->second)) { + MS_LOG(DEBUG) << node->ToString() << " The value is not changed."; + return; + } + } + MS_LOG(DEBUG) << node->ToString() << " result: " << answer->ToString(); + infer_resut_[node] = answer; + UpdateReady(node); +} + +void MindIREngine::EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr &func, const CNodePtr &node, + const AbstractBasePtrListPtr &args) { + auto prim = func->prim(); + // Return Primitive + if (prim->name() == prim::kPrimReturn->name()) { + EvalReturnPrimitive(prim, node); + return; + } + // Partial Primitive + if (prim->name() == prim::kPrimPartial->name()) { + EvalPartialPrimitive(prim, node, args); + return; + } + // common Primitive + EvalCommonPrimitive(prim, node, args); +} + +bool MindIREngine::CheckCNodeNotReady(const CNodePtr &node) { + int depend = 0; + for (const auto &input : node->inputs()) { + depend += infer_resut_.find(input) != infer_resut_.end() ? 0 : 1; + } + this->node_input_depends_[node] = depend; + return depend; +} + +void MindIREngine::EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr &func, const CNodePtr &node, + const AbstractBasePtrListPtr &args) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(func); + MS_EXCEPTION_IF_NULL(func->func_graph()); + // Has Processd + MS_LOG(DEBUG) << node->ToString() << " FuncGraph: " << func->ToString(); + auto funcName = func->func_graph()->ToString(); + auto it = func_graph_result_.find(funcName); + if (it != func_graph_result_.end()) { + MS_LOG(DEBUG) << "The abstract of " << node->ToString() << " = abstract of " << func->ToString(); + SaveNodeInferResult(node, it->second); + + // Process only one return valueNode function graph + auto func_inputs = func->func_graph()->get_inputs(); + // args has been resolved in partial. + if (args != nullptr) { + if (func_inputs.size() != args->size()) { + MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size() + << " CNode:" << node->DebugString() << " input size:" << args->size(); + } + for (size_t i = 0; i < func_inputs.size(); ++i) { + infer_resut_[func_inputs[i]] = + (*args)[i]; // Not use SaveNodeInferResult because this function has been evaluated. + todo_.erase(func_inputs[i]); + } + return; + } + // args is not resolved. + auto &cnode_inputs = node->inputs(); + if (func_inputs.size() != cnode_inputs.size() - 1) { + MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size() + << " CNode:" << node->DebugString() << " input size:" << cnode_inputs.size(); + } + for (size_t i = 0; i < func_inputs.size(); ++i) { + infer_resut_[func_inputs[i]] = infer_resut_[cnode_inputs[i + 1]]; + todo_.erase(func_inputs[i]); + } + return; + } + + // Be handling + auto visitIt = func_graph_visited_.find(funcName); + if (visitIt != func_graph_visited_.end()) { + visitIt->second.insert(node); + return; + } + func_graph_visited_[funcName] = std::set({node}); + + // Call the funcGraph + auto func_inputs = func->func_graph()->get_inputs(); + + // args has been resolved in partial. + if (args != nullptr) { + if (func_inputs.size() != args->size()) { + MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size() + << " CNode:" << node->DebugString() << " input size:" << args->size(); + } + for (size_t i = 0; i < func_inputs.size(); ++i) { + SaveNodeInferResult(func_inputs[i], (*args)[i]); + } + return; + } + // args is not resolved. + auto &cnode_inputs = node->inputs(); + if (func_inputs.size() != cnode_inputs.size() - 1) { + MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size() + << " CNode:" << node->DebugString() << " input size:" << cnode_inputs.size(); + } + + for (size_t i = 0; i < func_inputs.size(); ++i) { + SaveNodeInferResult(func_inputs[i], infer_resut_[cnode_inputs[i + 1]]); + } +} + +void MindIREngine::InferParameter(const AnfNodePtr &node) { UpdateReady(node); } + +void MindIREngine::InferValueNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto value_node = node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = GetValueNode(node); + MS_EXCEPTION_IF_NULL(value); + AbstractBasePtr result; + if (value->isa()) { + auto func_graph = value->cast(); + auto temp_context = abstract::AnalysisContext::DummyContext(); + result = std::make_shared(func_graph, temp_context, node); + } else if (value->isa()) { + auto prim = value->cast(); + result = std::make_shared(prim, node); + } else { + result = value->ToAbstract(); + } + + if (result->isa()) { + result = result->Broaden(); + } + SaveNodeInferResult(node, result); +} + +AbstractBasePtr MindIREngine::GetCNodeOperatorAbstract(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto op = cnode->inputs()[0]; + auto it = infer_resut_.find(op); + if (it != infer_resut_.end()) { + return it->second; + } + MS_LOG(EXCEPTION) << "Can't get the abstract of Node:" << op->DebugString(); +} + +// If args is nullPtr, it is called by InferCNode, else it is called recursively by EvalPartialAbastract. +void MindIREngine::EvalAbstractFunction(const abstract::AbstractFuncAtomPtr &func, const CNodePtr &node, + const AbstractBasePtrListPtr &args) { + MS_EXCEPTION_IF_NULL(func); + if (func->isa()) { + // C++ Primitive + auto prim = func->cast(); + EvalPrimitiveAbastract(prim, node, args); + } else if (func->isa()) { + // FuncGraph + auto funcGraph = func->cast(); + EvalFuncGraphAbastract(funcGraph, node, args); + } else if (func->isa()) { + // Partial + auto partialPrim = func->cast(); + EvalPartialAbastract(partialPrim, node, args); + } else { + MS_LOG(EXCEPTION) << "MindIR can't process the abstractFunc: " << func->DumpText(); + } +} + +void MindIREngine::UpdateReady(const AnfNodePtr &node) { + todo_.erase(node); + auto it = nodeuser_map_.find(node); + if (it == nodeuser_map_.end()) { + return; + } + const auto &users = it->second; + MS_LOG(DEBUG) << node->ToString() << " has users: " << users.size(); + for (const auto &user : users) { + int count = node_input_depends_[user.first]; + node_input_depends_[user.first] = count - 1; + if (count <= 1) { + ready_.push_back(user.first); + MS_LOG(DEBUG) << "Node:" << user.first->ToString() << " is ready."; + if (count < 1) { + MS_LOG(INFO) << " There is something to do. Node:" << node->ToString() << " user:" << user.first->DebugString(); + } + } + } +} + +void MindIREngine::InferCNode(const AnfNodePtr &node) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (CheckCNodeNotReady(cnode)) { + MS_LOG(INFO) << "The node is not ready: " << cnode->DebugString(); + return; + } + AbstractBasePtr possible_func = GetCNodeOperatorAbstract(cnode); + if (possible_func->BuildType()->type_id() == kObjectTypeUndeterminedType) { + MS_LOG(EXCEPTION) << "EvalCNode eval Undetermined"; + } + abstract::AbstractFunctionPtr func = dyn_cast(possible_func); + if (func == nullptr) { + MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << possible_func->ToString() << "."; + MS_EXCEPTION(ValueError) << "This may be not defined, and it can't be a operator. Please check code."; + } + abstract::AbstractFuncAtomPtrList abstractFuncList; + auto build_fuction = [&abstractFuncList](const abstract::AbstractFuncAtomPtr &poss) { + abstractFuncList.push_back(poss); + }; + func->Visit(build_fuction); + for (const auto &abstractFunc : abstractFuncList) { + EvalAbstractFunction(abstractFunc, cnode, nullptr); + } +} +} // namespace +bool InferMindir(const FuncGraphPtr &root, const AbstractBasePtrList &args) { + auto engine = std::make_shared(root); + return engine->InferShape(args); +} +} // namespace mindspore diff --git a/mindspore/core/load_mindir/infer_mindir.h b/mindspore/core/load_mindir/infer_mindir.h new file mode 100644 index 00000000000..debf4ea5d16 --- /dev/null +++ b/mindspore/core/load_mindir/infer_mindir.h @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_LOAD_MINDIR_INFER_MINDIR_H +#define MINDSPORE_CORE_LOAD_MINDIR_INFER_MINDIR_H +#include "base/base.h" +#include "ir/anf.h" + +namespace mindspore { +bool InferMindir(const FuncGraphPtr &root, const AbstractBasePtrList &args); +} // namespace mindspore + +#endif // MINDSPORE_CORE_LOAD_MINDIR_INFER_MINDIR_H diff --git a/mindspore/core/utils/ms_context.h b/mindspore/core/utils/ms_context.h index a80a346902f..4729c9cb2a1 100644 --- a/mindspore/core/utils/ms_context.h +++ b/mindspore/core/utils/ms_context.h @@ -51,6 +51,7 @@ const char kCPUDevice[] = "CPU"; const char kGPUDevice[] = "GPU"; const char kAscendDevice[] = "Ascend"; const char kDavinciInferenceDevice[] = "AscendInference"; +const char kDavinciMultiGraphInferenceDevice[] = "AscendMultiGraphInference"; const char kGpuInferenceDevice[] = "GpuInference"; const char kDavinciDevice[] = "Davinci"; const char KNpuLog[] = "_npu_log"; diff --git a/mindspore/lite/src/cxx_api/context.cc b/mindspore/lite/src/cxx_api/context.cc index 37aedd46b9b..c4ef7deb2fa 100644 --- a/mindspore/lite/src/cxx_api/context.cc +++ b/mindspore/lite/src/cxx_api/context.cc @@ -36,7 +36,6 @@ constexpr auto kModelOptionProvider = "mindspore.option.provider"; constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device"; constexpr auto kModelOptionDeviceID = "mindspore.option.device_id"; constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID; -constexpr auto kModelOptionAscend310DumpCfgPath = "mindspore.option.ascend310.dump_config_file_path"; constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path"; constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format"; constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map"; @@ -330,23 +329,6 @@ uint32_t Ascend310DeviceInfo::GetDeviceID() const { return GetValue(data_, kModelOptionAscend310DeviceID); } -void Ascend310DeviceInfo::SetDumpConfigPath(const std::vector &cfg_path) { - if (data_ == nullptr) { - MS_LOG(ERROR) << "Invalid context."; - return; - } - data_->params[kModelOptionAscend310DumpCfgPath] = CharToString(cfg_path); -} - -std::vector Ascend310DeviceInfo::GetDumpConfigPathChar() const { - if (data_ == nullptr) { - MS_LOG(ERROR) << "Invalid context."; - return std::vector(); - } - const std::string &ref = GetValue(data_, kModelOptionAscend310DumpCfgPath); - return StringToChar(ref); -} - void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector &cfg_path) { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; diff --git a/tests/st/cpp/model/tesn_control.cc b/tests/st/cpp/model/tesn_control.cc new file mode 100644 index 00000000000..42d86050d41 --- /dev/null +++ b/tests/st/cpp/model/tesn_control.cc @@ -0,0 +1,405 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "common/common_test.h" +#include "include/api/model.h" +#include "include/api/serialization.h" +#include "include/api/context.h" + +using namespace mindspore; + +static constexpr char kIfbyIfFile[] = "/home/workspace/mindspore_dataset/mindir/control/ifbyif.mindir"; +static constexpr char kSimpleWhileFile[] = "/home/workspace/mindspore_dataset/mindir/control/simple_while.mindir"; +static constexpr char kMixIfWhileFile[] = "/home/workspace/mindspore_dataset/mindir/control/mix_while_if.mindir"; +static constexpr char kRecursiveFile[] = "/home/workspace/mindspore_dataset/mindir/control/fibonacci.mindir"; +static constexpr char kSingleForFile[] = "/home/workspace/mindspore_dataset/mindir/control/single_for.mindir"; +static constexpr char kSingleOrFile[] = "/home/workspace/mindspore_dataset/mindir/control/single_or.mindir"; +static constexpr char kSingleSwitchFile[] = "/home/workspace/mindspore_dataset/mindir/control/switch_layer_net.mindir"; +static constexpr float kConstValue = 0.1234; +static const std::vector input_data(2 * 3 * 4 * 5, kConstValue); + +class TestControl : public ST::Common { + public: + TestControl() {} +}; + +TEST_F(TestControl, InferIfbyIf) { + auto context = ContextAutoSet(); + + Graph graph; + ASSERT_TRUE(Serialization::Load(kIfbyIfFile, ModelType::kMindIR, &graph)); + Model control_model; + ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess); + + // assert inputs + std::vector inputs_before = control_model.GetInputs(); + ASSERT_EQ(5, inputs_before.size()); + EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32); + EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeFloat32); + EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeBool); + EXPECT_EQ(inputs_before[3].DataType(), DataType::kNumberTypeBool); + EXPECT_EQ(inputs_before[4].DataType(), DataType::kNumberTypeFloat32); + ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float)); + ASSERT_EQ(inputs_before[1].DataSize(), sizeof(float)); + ASSERT_EQ(inputs_before[2].DataSize(), sizeof(bool)); + ASSERT_EQ(inputs_before[3].DataSize(), sizeof(bool)); + ASSERT_EQ(inputs_before[4].DataSize(), sizeof(float) * input_data.size()); + ASSERT_EQ(inputs_before[0].Shape().size(), 1); + EXPECT_EQ(inputs_before[0].Shape()[0], 1); + ASSERT_EQ(inputs_before[1].Shape().size(), 1); + EXPECT_EQ(inputs_before[1].Shape()[0], 1); + ASSERT_EQ(inputs_before[2].Shape().size(), 1); + EXPECT_EQ(inputs_before[2].Shape()[0], 1); + ASSERT_EQ(inputs_before[3].Shape().size(), 1); + EXPECT_EQ(inputs_before[3].Shape()[0], 1); + ASSERT_EQ(inputs_before[4].Shape().size(), 4); + EXPECT_EQ(inputs_before[4].Shape()[0], 2); + EXPECT_EQ(inputs_before[4].Shape()[1], 3); + EXPECT_EQ(inputs_before[4].Shape()[2], 4); + EXPECT_EQ(inputs_before[4].Shape()[3], 5); + + // prepare input + std::vector outputs; + std::vector inputs; + + float x = 2.345678, y = 1.234567; + bool cond1 = true, cond2 = false; + inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x, + sizeof(float)); + inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y, + sizeof(float)); + inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &cond1, + sizeof(bool)); + inputs.emplace_back(inputs_before[3].Name(), inputs_before[3].DataType(), inputs_before[3].Shape(), &cond2, + sizeof(bool)); + inputs.emplace_back(inputs_before[4].Name(), inputs_before[4].DataType(), inputs_before[4].Shape(), input_data.data(), + sizeof(float) * input_data.size()); + + // infer + ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); + + // assert output + ASSERT_TRUE(outputs.size() == 1); + auto out = outputs[0]; + ASSERT_TRUE(out.DataSize() == sizeof(float) * input_data.size()); + auto out_data = out.Data(); + auto p = reinterpret_cast(out_data.get()); + for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) { + ASSERT_LE(std::abs(p[i] - kConstValue * 24), 1e-3); + } +} + +TEST_F(TestControl, InferSimpleWhile) { + auto context = ContextAutoSet(); + + Graph graph; + ASSERT_TRUE(Serialization::Load(kSimpleWhileFile, ModelType::kMindIR, &graph)); + Model control_model; + ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess); + + // assert inputs + std::vector inputs_before = control_model.GetInputs(); + ASSERT_EQ(3, inputs_before.size()); + EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeBool); + EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeBool); + EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeFloat32); + ASSERT_EQ(inputs_before[0].DataSize(), sizeof(bool)); + ASSERT_EQ(inputs_before[1].DataSize(), sizeof(bool)); + ASSERT_EQ(inputs_before[2].DataSize(), sizeof(float) * input_data.size()); + ASSERT_EQ(inputs_before[0].Shape().size(), 1); + EXPECT_EQ(inputs_before[0].Shape()[0], 1); + ASSERT_EQ(inputs_before[1].Shape().size(), 1); + EXPECT_EQ(inputs_before[1].Shape()[0], 1); + ASSERT_EQ(inputs_before[2].Shape().size(), 4); + EXPECT_EQ(inputs_before[2].Shape()[0], 2); + EXPECT_EQ(inputs_before[2].Shape()[1], 3); + EXPECT_EQ(inputs_before[2].Shape()[2], 4); + EXPECT_EQ(inputs_before[2].Shape()[3], 5); + + // prepare input + std::vector outputs; + std::vector inputs; + { + bool x = true, y = false; + inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x, + sizeof(bool)); + inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y, + sizeof(bool)); + inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), + input_data.data(), sizeof(float) * input_data.size()); + } + + // infer + ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); + + // assert output + ASSERT_TRUE(outputs.size() == 1); + auto out = outputs[0]; + ASSERT_TRUE(out.DataSize() == sizeof(float) * input_data.size()); + auto out_data = out.Data(); + auto p = reinterpret_cast(out_data.get()); + for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) { + ASSERT_LE(std::abs(p[i] - kConstValue * 3), 1e-3); + } +} + +TEST_F(TestControl, InferRecursive) { + auto context = ContextAutoSet(); + + Graph graph; + ASSERT_TRUE(Serialization::Load(kRecursiveFile, ModelType::kMindIR, &graph)); + Model control_model; + ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess); + + // assert inputs + std::vector inputs_before = control_model.GetInputs(); + ASSERT_EQ(1, inputs_before.size()); + EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32); + ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t)); + ASSERT_EQ(inputs_before[0].Shape().size(), 1); + EXPECT_EQ(inputs_before[0].Shape()[0], 1); + + // prepare input + std::vector outputs; + std::vector inputs; + { + int32_t x = 7; + inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x, + sizeof(int32_t)); + } + + // infer + ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); + + // assert output + ASSERT_TRUE(outputs.size() == 1); + auto out = outputs[0]; + ASSERT_TRUE(out.DataSize() == sizeof(int32_t)); + auto out_data = out.Data(); + auto p = reinterpret_cast(out_data.get()); + ASSERT_EQ(*p, 21); +} + +TEST_F(TestControl, InferMixedWhileIf) { + auto context = ContextAutoSet(); + + Graph graph; + ASSERT_TRUE(Serialization::Load(kMixIfWhileFile, ModelType::kMindIR, &graph)); + Model control_model; + ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess); + + // assert inputs + std::vector inputs_before = control_model.GetInputs(); + ASSERT_EQ(inputs_before.size(), 5); + EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32); + EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32); + EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32); + EXPECT_EQ(inputs_before[3].DataType(), DataType::kNumberTypeInt32); + EXPECT_EQ(inputs_before[4].DataType(), DataType::kNumberTypeInt32); + ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t)); + ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t)); + ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t)); + ASSERT_EQ(inputs_before[3].DataSize(), sizeof(int32_t)); + ASSERT_EQ(inputs_before[4].DataSize(), sizeof(int32_t)); + ASSERT_EQ(inputs_before[0].Shape().size(), 1); + EXPECT_EQ(inputs_before[0].Shape()[0], 1); + ASSERT_EQ(inputs_before[1].Shape().size(), 1); + EXPECT_EQ(inputs_before[1].Shape()[0], 1); + ASSERT_EQ(inputs_before[2].Shape().size(), 1); + EXPECT_EQ(inputs_before[2].Shape()[0], 1); + ASSERT_EQ(inputs_before[3].Shape().size(), 1); + EXPECT_EQ(inputs_before[3].Shape()[0], 1); + ASSERT_EQ(inputs_before[4].Shape().size(), 1); + EXPECT_EQ(inputs_before[4].Shape()[0], 1); + + // prepare input + std::vector outputs; + std::vector inputs; + { + int32_t x = 2, y = 14, z = 1, c2 = 14, c4 = 0; + inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x, + sizeof(int32_t)); + inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y, + sizeof(int32_t)); + inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &z, + sizeof(int32_t)); + inputs.emplace_back(inputs_before[3].Name(), inputs_before[3].DataType(), inputs_before[3].Shape(), &c2, + sizeof(int32_t)); + inputs.emplace_back(inputs_before[4].Name(), inputs_before[4].DataType(), inputs_before[4].Shape(), &c4, + sizeof(int32_t)); + } + + // infer + ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); + + // assert output + ASSERT_TRUE(outputs.size() == 1); + auto out = outputs[0]; + ASSERT_TRUE(out.DataSize() == sizeof(int32_t)); + auto out_data = out.Data(); + auto p = reinterpret_cast(out_data.get()); + ASSERT_EQ(*p, 350); +} + +TEST_F(TestControl, InferSingleFor) { + auto context = ContextAutoSet(); + + Graph graph; + ASSERT_TRUE(Serialization::Load(kSingleForFile, ModelType::kMindIR, &graph)); + Model control_model; + ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess); + + // assert inputs + std::vector inputs_before = control_model.GetInputs(); + ASSERT_EQ(inputs_before.size(), 3); + EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32); + EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32); + EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32); + ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t)); + ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t)); + ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t)); + ASSERT_EQ(inputs_before[0].Shape().size(), 1); + EXPECT_EQ(inputs_before[0].Shape()[0], 1); + ASSERT_EQ(inputs_before[1].Shape().size(), 1); + EXPECT_EQ(inputs_before[1].Shape()[0], 1); + ASSERT_EQ(inputs_before[2].Shape().size(), 1); + EXPECT_EQ(inputs_before[2].Shape()[0], 1); + + // prepare input + std::vector outputs; + std::vector inputs; + { + int32_t x = 2, y = 5, z = 4; + inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x, + sizeof(int32_t)); + inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y, + sizeof(int32_t)); + inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &z, + sizeof(int32_t)); + } + + // infer + ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); + + // assert output + ASSERT_TRUE(outputs.size() == 1); + auto out = outputs[0]; + ASSERT_TRUE(out.DataSize() == sizeof(int32_t)); + auto out_data = out.Data(); + auto p = reinterpret_cast(out_data.get()); + ASSERT_EQ(*p, 125); +} + +TEST_F(TestControl, InferSingleOr) { + auto context = ContextAutoSet(); + + Graph graph; + ASSERT_TRUE(Serialization::Load(kSingleOrFile, ModelType::kMindIR, &graph)); + Model control_model; + ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess); + + // assert inputs + std::vector inputs_before = control_model.GetInputs(); + ASSERT_EQ(inputs_before.size(), 2); + EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32); + EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeFloat32); + ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float) * 2); + ASSERT_EQ(inputs_before[1].DataSize(), sizeof(float) * 2); + ASSERT_EQ(inputs_before[0].Shape().size(), 1); + EXPECT_EQ(inputs_before[0].Shape()[0], 2); + ASSERT_EQ(inputs_before[1].Shape().size(), 1); + EXPECT_EQ(inputs_before[1].Shape()[0], 2); + + // prepare input + std::vector outputs; + std::vector inputs; + { + static const std::vector input_data1 = {0, 1}; + static const std::vector input_data2 = {0, 0}; + inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), + input_data1.data(), sizeof(float) * input_data1.size()); + inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), + input_data2.data(), sizeof(int32_t) * input_data2.size()); + } + + // infer + ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); + + // assert output + ASSERT_TRUE(outputs.size() == 1); + auto out = outputs[0]; + ASSERT_TRUE(out.DataSize() == sizeof(float)); + auto out_data = out.Data(); + auto p = reinterpret_cast(out_data.get()); + ASSERT_EQ(*p, 1); +} + +TEST_F(TestControl, InferSingleSwitch) { + auto context = ContextAutoSet(); + + Graph graph; + ASSERT_TRUE(Serialization::Load(kSingleSwitchFile, ModelType::kMindIR, &graph)); + Model control_model; + ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess); + + // assert inputs + std::vector inputs_before = control_model.GetInputs(); + ASSERT_EQ(inputs_before.size(), 3); + EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32); + EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32); + EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32); + ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float) * 224 * 224); + ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t)); + ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t)); + ASSERT_EQ(inputs_before[0].Shape().size(), 4); + EXPECT_EQ(inputs_before[0].Shape()[0], 1); + EXPECT_EQ(inputs_before[0].Shape()[1], 1); + EXPECT_EQ(inputs_before[0].Shape()[2], 224); + EXPECT_EQ(inputs_before[0].Shape()[3], 224); + ASSERT_EQ(inputs_before[1].Shape().size(), 1); + EXPECT_EQ(inputs_before[1].Shape()[0], 1); + ASSERT_EQ(inputs_before[2].Shape().size(), 1); + EXPECT_EQ(inputs_before[2].Shape()[0], 1); + + // prepare input + std::vector outputs; + std::vector inputs; + { + static const std::vector input_data1(1 * 1 * 224 * 224, 1); + int32_t index1 = 0; + int32_t index2 = -1; + inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), + input_data1.data(), sizeof(float) * input_data1.size()); + inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &index1, + sizeof(int32_t)); + inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &index2, + sizeof(int32_t)); + } + + // infer + ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); + + // assert output + ASSERT_TRUE(outputs.size() == 1); + auto out = outputs[0]; + ASSERT_TRUE(out.DataSize() == sizeof(float) * 224 * 224); + auto out_data = out.Data(); + auto p = reinterpret_cast(out_data.get()); + for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) { + ASSERT_EQ(p[i], 1); + } +}