!22915 310 support cond graph

Merge pull request !22915 from zhoufeng/310-support-cond-graph
This commit is contained in:
i-robot 2021-09-08 06:46:55 +00:00 committed by Gitee
commit 516a74f985
37 changed files with 1555 additions and 122 deletions

View File

@ -292,9 +292,6 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
/// \return The device id. /// \return The device id.
uint32_t GetDeviceID() const; uint32_t GetDeviceID() const;
inline void SetDumpConfigPath(const std::string &cfg_path);
inline std::string GetDumpConfigPath() const;
/// \brief Set AIPP configuration file path. /// \brief Set AIPP configuration file path.
/// ///
/// \param[in] cfg_path AIPP configuration file path. /// \param[in] cfg_path AIPP configuration file path.
@ -379,9 +376,6 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
inline std::string GetBufferOptimizeMode() const; inline std::string GetBufferOptimizeMode() const;
private: private:
void SetDumpConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetDumpConfigPathChar() const;
void SetInsertOpConfigPath(const std::vector<char> &cfg_path); void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetInsertOpConfigPathChar() const; std::vector<char> GetInsertOpConfigPathChar() const;
@ -406,9 +400,6 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
std::vector<char> GetBufferOptimizeModeChar() const; std::vector<char> GetBufferOptimizeModeChar() const;
}; };
void Ascend310DeviceInfo::SetDumpConfigPath(const std::string &cfg_path) { SetDumpConfigPath(StringToChar(cfg_path)); }
std::string Ascend310DeviceInfo::GetDumpConfigPath() const { return CharToString(GetDumpConfigPathChar()); }
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) { void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
SetInsertOpConfigPath(StringToChar(cfg_path)); SetInsertOpConfigPath(StringToChar(cfg_path));
} }

View File

@ -1121,7 +1121,8 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
return new_parameter; return new_parameter;
} }
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
bool common_opt) {
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode; std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
auto graph = NewKernelGraph(); auto graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
@ -1161,7 +1162,9 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
// Update Graph Dynamic Shape Attr // Update Graph Dynamic Shape Attr
UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
UpdateGraphAquireGilAttr(NOT_NULL(graph)); UpdateGraphAquireGilAttr(NOT_NULL(graph));
opt::BackendCommonOptimization(graph); if (common_opt) {
opt::BackendCommonOptimization(graph);
}
graph->SetInputNodes(); graph->SetInputNodes();
SetInputNodeUsage(graph, manager); SetInputNodeUsage(graph, manager);
graph->SetOptimizerFlag(); graph->SetOptimizerFlag();

View File

@ -116,7 +116,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
bool CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph); bool CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph);
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
bool common_opt = true);
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph, std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
std::vector<KernelGraphPtr> *all_out_graph); std::vector<KernelGraphPtr> *all_out_graph);

View File

@ -32,6 +32,7 @@ if(ENABLE_D OR ENABLE_ACL)
if(NOT ENABLE_D) if(NOT ENABLE_D)
list(APPEND API_ACL_SRC $<TARGET_OBJECTS:_mindspore_transform_graph_ir_obj>) list(APPEND API_ACL_SRC $<TARGET_OBJECTS:_mindspore_transform_graph_ir_obj>)
list(APPEND API_ACL_SRC $<TARGET_OBJECTS:_mindspore_vm_obj>)
endif() endif()
endif() endif()

View File

@ -29,7 +29,6 @@ constexpr auto kModelOptionGPUTrtInferMode = "mindspore.option.gpu.trt_infer_mod
constexpr auto kModelOptionGPUPrecisionMode = "mindspore.option.gpu.precision_mode"; constexpr auto kModelOptionGPUPrecisionMode = "mindspore.option.gpu.precision_mode";
constexpr auto kModelOptionAscend910DeviceID = kModelOptionDeviceID; constexpr auto kModelOptionAscend910DeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID; constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionAscend310DumpCfgPath = "mindspore.option.ascend310.dump_config_file_path";
constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path"; constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path";
constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format"; constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format";
constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map"; constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
@ -193,16 +192,6 @@ uint32_t Ascend310DeviceInfo::GetDeviceID() const {
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID); return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
} }
void Ascend310DeviceInfo::SetDumpConfigPath(const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310DumpCfgPath] = CharToString(cfg_path);
}
std::vector<char> Ascend310DeviceInfo::GetDumpConfigPathChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DumpCfgPath);
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) { void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path); data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);

View File

@ -21,8 +21,8 @@ namespace mindspore {
std::shared_ptr<AclEnvGuard> AclEnvGuard::global_acl_env_; std::shared_ptr<AclEnvGuard> AclEnvGuard::global_acl_env_;
std::mutex AclEnvGuard::global_acl_env_mutex_; std::mutex AclEnvGuard::global_acl_env_mutex_;
AclEnvGuard::AclEnvGuard(std::string_view cfg_file) { AclEnvGuard::AclEnvGuard() {
errno_ = aclInit(cfg_file.data()); errno_ = aclInit(nullptr);
if (errno_ != ACL_ERROR_NONE && errno_ != ACL_ERROR_REPEAT_INITIALIZE) { if (errno_ != ACL_ERROR_NONE && errno_ != ACL_ERROR_REPEAT_INITIALIZE) {
MS_LOG(ERROR) << "Execute aclInit Failed"; MS_LOG(ERROR) << "Execute aclInit Failed";
return; return;
@ -38,18 +38,15 @@ AclEnvGuard::~AclEnvGuard() {
MS_LOG(INFO) << "Acl finalize success"; MS_LOG(INFO) << "Acl finalize success";
} }
std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv(std::string_view cfg_file) { std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv() {
std::shared_ptr<AclEnvGuard> acl_env; std::shared_ptr<AclEnvGuard> acl_env;
std::lock_guard<std::mutex> lock(global_acl_env_mutex_); std::lock_guard<std::mutex> lock(global_acl_env_mutex_);
acl_env = global_acl_env_; acl_env = global_acl_env_;
if (acl_env != nullptr) { if (acl_env != nullptr) {
MS_LOG(INFO) << "Acl has been initialized, skip."; MS_LOG(INFO) << "Acl has been initialized, skip.";
if (!cfg_file.empty()) {
MS_LOG(WARNING) << "Dump config file option " << cfg_file << " is ignored.";
}
} else { } else {
acl_env = std::make_shared<AclEnvGuard>(cfg_file); acl_env = std::make_shared<AclEnvGuard>();
aclError ret = acl_env->GetErrno(); aclError ret = acl_env->GetErrno();
if (ret != ACL_ERROR_NONE && ret != ACL_ERROR_REPEAT_INITIALIZE) { if (ret != ACL_ERROR_NONE && ret != ACL_ERROR_REPEAT_INITIALIZE) {
MS_LOG(ERROR) << "Execute aclInit Failed"; MS_LOG(ERROR) << "Execute aclInit Failed";

View File

@ -23,10 +23,10 @@
namespace mindspore { namespace mindspore {
class __attribute__((visibility("default"))) AclEnvGuard { class __attribute__((visibility("default"))) AclEnvGuard {
public: public:
explicit AclEnvGuard(std::string_view cfg_file); explicit AclEnvGuard();
~AclEnvGuard(); ~AclEnvGuard();
aclError GetErrno() const { return errno_; } aclError GetErrno() const { return errno_; }
static std::shared_ptr<AclEnvGuard> GetAclEnv(std::string_view cfg_file); static std::shared_ptr<AclEnvGuard> GetAclEnv();
private: private:
static std::shared_ptr<AclEnvGuard> global_acl_env_; static std::shared_ptr<AclEnvGuard> global_acl_env_;

View File

@ -91,7 +91,7 @@ Status AclGraphImpl::InitEnv() {
return kSuccess; return kSuccess;
} }
acl_env_ = AclEnvGuard::GetAclEnv(""); acl_env_ = AclEnvGuard::GetAclEnv();
if (acl_env_ == nullptr) { if (acl_env_ == nullptr) {
MS_LOG(ERROR) << "Acl init failed."; MS_LOG(ERROR) << "Acl init failed.";
return kMCDeviceError; return kMCDeviceError;

View File

@ -165,7 +165,7 @@ Status ModelProcess::InitInputsBuffer() {
aclDataType data_type = aclmdlGetInputDataType(model_desc_, i); aclDataType data_type = aclmdlGetInputDataType(model_desc_, i);
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount); std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
const char *input_name_char = aclmdlGetInputNameByIndex(model_desc_, i); const char *input_name_char = aclmdlGetInputNameByIndex(model_desc_, i);
std::string input_name = (input_name_char == nullptr) ? input_name_char : std::string(); std::string input_name = (input_name_char != nullptr) ? input_name_char : std::string();
if (input_name.empty()) { if (input_name.empty()) {
MS_LOG(WARNING) << "Get name of input " << i << " failed."; MS_LOG(WARNING) << "Get name of input " << i << " failed.";
} }
@ -249,7 +249,7 @@ Status ModelProcess::InitOutputsBuffer() {
aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i); aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i);
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount); std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
const char *output_name_char = aclmdlGetOutputNameByIndex(model_desc_, i); const char *output_name_char = aclmdlGetOutputNameByIndex(model_desc_, i);
std::string output_name = (output_name_char == nullptr) ? output_name_char : std::string(); std::string output_name = (output_name_char != nullptr) ? output_name_char : std::string();
if (output_name.empty()) { if (output_name.empty()) {
MS_LOG(WARNING) << "Get name of output " << i << " failed."; MS_LOG(WARNING) << "Get name of output " << i << " failed.";
} }

View File

@ -24,8 +24,6 @@
#include "acl/acl_base.h" #include "acl/acl_base.h"
namespace mindspore { namespace mindspore {
API_FACTORY_REG(ModelImpl, Ascend310, AclModel);
Status AclModel::Build() { Status AclModel::Build() {
MS_LOG(INFO) << "Start build model."; MS_LOG(INFO) << "Start build model.";
MS_EXCEPTION_IF_NULL(graph_); MS_EXCEPTION_IF_NULL(graph_);
@ -42,13 +40,8 @@ Status AclModel::Build() {
return kSuccess; return kSuccess;
} }
std::unique_ptr<AclModelOptions> options = std::make_unique<AclModelOptions>(model_context_); std::shared_ptr<AclModelOptions> options = std::make_shared<AclModelOptions>(model_context_);
MS_EXCEPTION_IF_NULL(options); MS_EXCEPTION_IF_NULL(options);
std::string dump_cfg = options->GetDumpCfgPath();
if (!dump_cfg.empty()) {
MS_LOG(INFO) << "Options dump config file path " << dump_cfg;
(void)AclEnvGuard::GetAclEnv(dump_cfg);
}
std::string options_key = options->GenAclOptionsKey(); std::string options_key = options->GenAclOptionsKey();
std::shared_ptr<Graph> graph; std::shared_ptr<Graph> graph;
if (auto iter = dynamic_size_graph_map_.find(options_key); iter != dynamic_size_graph_map_.end()) { if (auto iter = dynamic_size_graph_map_.find(options_key); iter != dynamic_size_graph_map_.end()) {
@ -72,7 +65,7 @@ Status AclModel::Build() {
} }
options->RenameInput(input_names); options->RenameInput(input_names);
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
model_converter_.set_options(options.get()); model_converter_.set_options(options);
auto om_data = model_converter_.LoadMindIR(func_graph); auto om_data = model_converter_.LoadMindIR(func_graph);
if (om_data.Data() == nullptr || om_data.DataSize() == 0) { if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
MS_LOG(ERROR) << "Load MindIR failed."; MS_LOG(ERROR) << "Load MindIR failed.";

View File

@ -47,7 +47,7 @@ class AclModel : public ModelImpl {
private: private:
ModelConverter model_converter_; ModelConverter model_converter_;
std::unique_ptr<AclModelOptions> options_; std::shared_ptr<AclModelOptions> options_;
std::map<std::string, std::shared_ptr<Graph>> dynamic_size_graph_map_; std::map<std::string, std::shared_ptr<Graph>> dynamic_size_graph_map_;
}; };
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,475 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cxx_api/model/acl/acl_model_multi.h"
#include <vector>
#include <utility>
#include <map>
#include <string>
#include <algorithm>
#include "backend/session/session_basic.h"
#include "backend/session/session_factory.h"
#include "cxx_api/factory.h"
#include "vm/backend.h"
#include "vm/transform.h"
#include "acl/acl_rt.h"
#include "mindspore/core/load_mindir/infer_mindir.h"
#include "debug/trace.h"
namespace mindspore {
API_FACTORY_REG(ModelImpl, Ascend310, AclModelMulti);
namespace {
class MSTensorRef : public BaseRef {
public:
static VectorRef Convert(const std::vector<MSTensor> &tensors) {
VectorRef res;
std::transform(tensors.begin(), tensors.end(), std::back_inserter(res),
[](const MSTensor &t) { return MSTensorRef(t); });
return res;
}
static std::vector<MSTensor> Convert(const BaseRef &args) {
std::vector<MSTensor> res;
if (utils::isa<VectorRef>(args)) {
VectorRef args_vec = utils::cast<VectorRef>(args);
for (size_t i = 0; i < args_vec.size(); ++i) {
const auto &item = args_vec[i];
if (!utils::isa<MSTensorRef>(item)) {
MS_LOG(EXCEPTION) << "Invalid item " << item.ToString() << " at index " << i;
}
auto wrapper = utils::cast<MSTensorRef>(item);
res.push_back(wrapper.ms_tensor_);
}
} else if (utils::isa<MSTensorRef>(args)) {
auto wrapper = utils::cast<MSTensorRef>(args);
res.push_back(wrapper.ms_tensor_);
} else {
MS_LOG(EXCEPTION) << "Invalid BaseRef " << args.ToString() << " must be MSTensorRef or VectorRef{MSTensorRef...}";
}
return res;
}
MS_DECLARE_PARENT(MSTensorRef, BaseRef);
explicit MSTensorRef(const MSTensor &tensor) : ms_tensor_(tensor) {}
~MSTensorRef() override = default;
const MSTensor &GetTensor() const { return ms_tensor_; }
std::shared_ptr<Base> copy() const override {
MSTensor *tensor = ms_tensor_.Clone();
auto res = std::make_shared<MSTensorRef>(static_cast<const MSTensor &>(*tensor));
MSTensor::DestroyTensorPtr(tensor);
return res;
}
uint32_t type() const override { return tid(); }
std::string ToString() const override { return ms_tensor_.Name(); }
bool operator==(const BaseRef &other) const override {
if (!utils::isa<MSTensorRef>(other)) {
return false;
}
return *this == utils::cast<MSTensorRef>(other);
}
bool operator==(MSTensorRef &other) {
return (ms_tensor_.Name() == other.ms_tensor_.Name()) && (ms_tensor_.Shape() == other.ms_tensor_.Shape()) &&
(ms_tensor_.MutableData() == other.ms_tensor_.MutableData()) &&
(ms_tensor_.DataSize() == other.ms_tensor_.DataSize()) &&
(ms_tensor_.DataType() == other.ms_tensor_.DataType());
}
private:
MSTensor ms_tensor_;
};
class MultiGraphAclSession : public session::SessionBasic {
public:
MultiGraphAclSession() = default;
~MultiGraphAclSession() override = default;
void Init(uint32_t device_id) override;
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
void RunGraph(GraphId graph_id, const std::vector<MSTensor> &inputs, VectorRef *outputs);
void SetOptions(const std::shared_ptr<AclModelOptions> &options) { options_ = options; }
private:
std::map<GraphId, GraphCell> graphs_ = {};
std::shared_ptr<AclModelOptions> options_ = nullptr;
};
void MultiGraphAclSession::Init(uint32_t device_id) { InitExecutor(kDavinciMultiGraphInferenceDevice, device_id); }
GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
class FirstGraphModeGuard {
public:
explicit FirstGraphModeGuard(const std::shared_ptr<AclModelOptions> &options) : options_(options) {
if (options_ != nullptr) {
options_->SetFirstGraph(true);
}
}
~FirstGraphModeGuard() {
if (options_ != nullptr) {
options_->SetFirstGraph(false);
}
}
private:
std::shared_ptr<AclModelOptions> options_;
};
MS_LOG(INFO) << "Start MultiGraph Compile.";
auto kernel_graph = ConstructKernelGraph(lst, outputs, false);
MS_EXCEPTION_IF_NULL(kernel_graph);
ModelConverter model_converter_;
model_converter_.set_options(options_);
FirstGraphModeGuard guard(options_);
auto om_data = model_converter_.LoadMindIR(kernel_graph);
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
MS_LOG(ERROR) << "Load MindIR failed.";
return kMCFailed;
}
std::shared_ptr<Graph> graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
MS_EXCEPTION_IF_NULL(graph);
graphs_[kernel_graph->graph_id()] = GraphCell(graph);
MS_LOG(INFO) << "Mulit graph compile success, graph id " << kernel_graph->graph_id();
return kernel_graph->graph_id();
}
void MultiGraphAclSession::RunGraph(GraphId graph_id, const std::vector<MSTensor> &inputs, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
MS_LOG(INFO) << "Start run graph " << graph_id;
auto iter = graphs_.find(graph_id);
if (iter == graphs_.end()) {
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " not found.";
}
std::vector<MSTensor> out_tensors;
auto ret = iter->second.Run(inputs, &out_tensors);
if (ret != kSuccess) {
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " run failed.";
}
(*outputs) = MSTensorRef::Convert(out_tensors);
}
class AclBackend : public compile::MsBackend {
public:
AclBackend(const std::string &name, const std::string &target, const std::shared_ptr<AclModelOptions> &options)
: MsBackend(name, target, options->GetDeviceID()) {
auto session = std::dynamic_pointer_cast<MultiGraphAclSession>(MsBackend::target_sess_);
MS_EXCEPTION_IF_NULL(session);
session->SetOptions(options);
}
~AclBackend() override = default;
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) override {
std::vector<MSTensor> inputs;
for (const auto &arg : args) {
if (!utils::isa<MSTensorRef>(arg)) {
MS_LOG(EXCEPTION) << "Invalid item " << arg.ToString();
}
auto wrapper = utils::cast<MSTensorRef>(arg);
inputs.emplace_back(wrapper.GetTensor());
}
VectorRef outputs;
MS_EXCEPTION_IF_NULL(target_sess_);
auto exec_sess = std::dynamic_pointer_cast<MultiGraphAclSession>(target_sess_);
MS_EXCEPTION_IF_NULL(exec_sess);
exec_sess->RunGraph(g, inputs, &outputs);
return outputs;
}
bool GetCond(const BaseRef &c, bool *value) override {
MS_EXCEPTION_IF_NULL(value);
if (!utils::isa<MSTensorRef>(c)) {
MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef.";
return false;
}
auto wrapper = utils::cast<MSTensorRef>(c);
if (wrapper.GetTensor().DataType() != DataType::kNumberTypeBool) {
MS_LOG(ERROR) << "Invalid data type " << wrapper.GetTensor().DataType() << " must be bool.";
return false;
}
auto data = wrapper.GetTensor().Data();
if (data == nullptr) {
return false;
}
(*value) = *reinterpret_cast<const bool *>(data.get());
return true;
}
bool GetIndex(const BaseRef &c, int64_t *value) override {
MS_EXCEPTION_IF_NULL(value);
if (!utils::isa<MSTensorRef>(c)) {
MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef.";
return false;
}
auto wrapper = utils::cast<MSTensorRef>(c);
if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt32) {
auto data = wrapper.GetTensor().Data();
if (data == nullptr) {
return false;
}
auto value_int32 = *reinterpret_cast<const int32_t *>(data.get());
(*value) = static_cast<int64_t>(value_int32);
return true;
} else if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt64) {
auto data = wrapper.GetTensor().Data();
if (data == nullptr) {
return false;
}
(*value) = *reinterpret_cast<const int64_t *>(data.get());
return true;
} else {
MS_LOG(ERROR) << "Index must be Int type.";
return false;
}
}
};
class AclCompileGraph : public compile::CompileGraph {
public:
explicit AclCompileGraph(const std::shared_ptr<compile::MsBackend> &backend,
const std::vector<PrimitivePtr> &cut_list)
: CompileGraph(backend, cut_list) {}
~AclCompileGraph() override = default;
void AddInst(const compile::Instruction &inst, const MSTensorRef &arg) {
VectorRef args;
args.push_back(arg);
compile::CompileGraph::AddInst(inst, args);
}
int64_t Ref(const AnfNodePtr &node) override {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
if (IsValueNode<FuncGraph>(node)) {
MS_LOG(DEBUG) << "Push graph.";
compile::CompileGraph::AddInst(compile::Instruction::kGraph, GetValueNode(node));
} else {
MS_LOG(DEBUG) << "Push.";
if (IsValueNode<Primitive>(node)) {
MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfo(node->debug_info());
} else if (IsValueNode<tensor::Tensor>(node)) {
auto tensor_node = std::dynamic_pointer_cast<tensor::Tensor>(node->cast<ValueNodePtr>()->value());
MS_EXCEPTION_IF_NULL(tensor_node);
std::string name = "";
std::vector<int64_t> shape = tensor_node->shape_c();
DataType type = static_cast<DataType>(tensor_node->data_type_c());
auto mstensor_node = MSTensor::CreateRefTensor(name, type, shape, tensor_node->data_c(), tensor_node->Size());
MSTensorRef mstensor_ref(*mstensor_node);
AddInst(compile::Instruction::kPush, mstensor_ref);
MSTensor::DestroyTensorPtr(mstensor_node);
} else {
compile::CompileGraph::AddInst(compile::Instruction::kPush, GetValueNode(node));
}
}
Push(node);
}
MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
<< ", return: " << slots_[node] - height_;
return slots_[node] - height_;
}
};
class AclCompileGraphs : public compile::CompileGraphs {
public:
explicit AclCompileGraphs(const std::shared_ptr<compile::MsBackend> &backend,
const std::vector<PrimitivePtr> &cut_list)
: CompileGraphs(backend, cut_list) {
MS_EXCEPTION_IF_NULL(backend);
MS_LOG(DEBUG) << "Start vm: " << backend->name();
transform_ = std::make_shared<AclCompileGraph>(backend, cut_list);
Reset();
}
~AclCompileGraphs() override = default;
void Compile(const FuncGraphPtr &graph) override {
MS_LOG(DEBUG) << "Start";
mapping_[graph] = SizeToLong(insts_.size());
if (transform_ != nullptr) {
auto insts = transform_->Run(graph, false);
if (!insts.empty()) {
(void)insts_.insert(insts_.end(), insts.begin(), insts.end());
}
}
MS_LOG(DEBUG) << "End";
}
};
std::shared_ptr<compile::MsBackend> CreateBackend(const std::shared_ptr<AclModelOptions> &options) {
MS_EXCEPTION_IF_NULL(options);
return std::make_shared<AclBackend>(kMsConvert, kDavinciMultiGraphInferenceDevice, options);
}
bool HasMultiGraph(const FuncGraphPtr &fg) {
MS_EXCEPTION_IF_NULL(fg);
std::vector<AnfNodePtr> all_nodes = TopoSort(fg->get_return());
for (const auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
if (IsValueNode<FuncGraph>(node)) {
MS_LOG(INFO) << fg->ToString() << " has FuncGraph node " << node->DebugString() << " is multi graph.";
return true;
}
}
return false;
}
} // namespace
Status AclModelMulti::Build() {
if (!is_multi_graph_.has_value()) {
is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph());
}
if (!is_multi_graph_.value()) {
return AclModel::Build();
}
if (vm_ != nullptr) {
MS_LOG(INFO) << "Multi graph model has been built, skip.";
return kSuccess;
}
MS_LOG(INFO) << "Start build multi graph model.";
// perpare func graph
auto manager = MakeManager();
manager->AddFuncGraph(ModelImpl::GetFuncGraph());
ModelImpl::GetFuncGraph()->set_manager(manager);
// set inputs
SetInputs();
// infer mindir
abstract::AbstractBasePtrList broaded_args;
auto fg = ModelImpl::GetFuncGraph();
MS_EXCEPTION_IF_NULL(fg);
const auto &inputs = fg->get_inputs();
(void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(broaded_args),
[](const AnfNodePtr &n) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(n);
auto abstract = n->abstract();
MS_EXCEPTION_IF_NULL(abstract);
if (abstract->GetValueTrack() != kAnyValue) {
return abstract->Broaden();
}
return abstract;
});
(void)InferMindir(ModelImpl::GetFuncGraph(), broaded_args);
// create vm
auto backend = CreateBackend(std::make_shared<AclModelOptions>(model_context_));
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
backend->set_is_multi_graph_sink(false);
context_ptr->set_param<std::string>(MS_CTX_DEVICE_TARGET, kDavinciMultiGraphInferenceDevice);
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
auto compile = std::make_shared<AclCompileGraphs>(backend, compile::GetMsNonlinearOps());
vm_ = compile->CompileAndLink(ModelImpl::GetFuncGraph());
backend_ = std::move(backend);
MS_LOG(INFO) << "Build multi graph model success.";
return kSuccess;
}
Status AclModelMulti::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
if (!is_multi_graph_.has_value()) {
is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph());
}
if (!is_multi_graph_.value()) {
return AclModel::Predict(inputs, outputs);
}
Build();
MS_LOG(INFO) << "Start predict multi graph model.";
MS_EXCEPTION_IF_NULL(vm_);
MS_EXCEPTION_IF_NULL(outputs);
try {
(*outputs) = MSTensorRef::Convert(vm_->Eval(MSTensorRef::Convert(inputs)));
} catch (const std::exception &ex) {
MS_LOG(ERROR) << "Predict Failed, error: " << ex.what();
return kMCFailed;
}
if (inputs.size() != inputs_.size() && !inputs_.empty() != 0) {
MS_LOG(ERROR) << "Input Size is wrong.";
return kMCFailed;
}
if (inputs_.empty()) {
inputs_ = inputs;
} else {
for (size_t i = 0; i < inputs_.size(); ++i) {
auto input_tensor = MSTensor::CreateTensor(inputs_[i].Name(), inputs_[i].DataType(), inputs_[i].Shape(),
inputs[i].Data().get(), inputs[i].DataSize());
inputs_[i] = (*input_tensor);
MSTensor::DestroyTensorPtr(input_tensor);
}
}
outputs_ = *outputs;
MS_LOG(INFO) << "Predict multi graph model success.";
return kSuccess;
}
void AclModelMulti::SetInputs() {
if (inputs_.empty()) {
auto fg = ModelImpl::GetFuncGraph();
MS_EXCEPTION_IF_NULL(fg);
const auto &inputs = fg->get_inputs();
for (const auto &in : inputs) {
auto input_param = std::dynamic_pointer_cast<Parameter>(in);
MS_EXCEPTION_IF_NULL(input_param);
MS_EXCEPTION_IF_NULL(input_param->abstract());
auto input_value = input_param->abstract()->GetValueTrack();
auto tensor = input_value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
std::vector<int64_t> shape = tensor->shape_c();
auto input_tensor = MSTensor::CreateTensor(input_param->name(), static_cast<DataType>(tensor->data_type_c()),
shape, nullptr, tensor->Size());
inputs_.emplace_back(*input_tensor);
MSTensor::DestroyTensorPtr(input_tensor);
}
} else {
MS_LOG(DEBUG) << "inputs_ has been set.";
}
}
std::vector<MSTensor> AclModelMulti::GetInputs() {
if (!is_multi_graph_.has_value()) {
is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph());
}
if (!is_multi_graph_.value()) {
return AclModel::GetInputs();
}
return inputs_;
}
std::vector<MSTensor> AclModelMulti::GetOutputs() {
if (!is_multi_graph_.has_value()) {
is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph());
}
if (!is_multi_graph_.value()) {
return AclModel::GetOutputs();
}
return outputs_;
}
namespace session {
MS_REG_SESSION(kDavinciMultiGraphInferenceDevice, MultiGraphAclSession);
} // namespace session
} // namespace mindspore

View File

@ -0,0 +1,52 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_ACL_MODEL_MULTI_H
#define MINDSPORE_CCSRC_CXX_API_ACL_MODEL_MULTI_H
#include "cxx_api/model/acl/acl_model.h"
#include <memory>
#include <optional>
#include <vector>
#include <string>
namespace mindspore {
namespace compile {
class MsBackend;
class FinalVM;
} // namespace compile
class AclModelMulti : public AclModel {
public:
AclModelMulti() : AclModel(), is_multi_graph_(std::nullopt) {}
~AclModelMulti() = default;
Status Build() override;
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;
private:
void SetInputs();
std::optional<bool> is_multi_graph_;
std::shared_ptr<compile::MsBackend> backend_;
std::shared_ptr<compile::FinalVM> vm_;
std::vector<MSTensor> inputs_ = {};
std::vector<MSTensor> outputs_ = {};
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_MULTI_H

View File

@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "cxx_api/model/acl/acl_model_options.h" #include "cxx_api/model/acl/acl_model_options.h"
#include <set>
#include <memory> #include <memory>
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "external/ge/ge_api_types.h" #include "external/ge/ge_api_types.h"
@ -54,7 +55,6 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
op_select_impl_mode_ = ascend310_info->GetOpSelectImplMode(); op_select_impl_mode_ = ascend310_info->GetOpSelectImplMode();
fusion_switch_cfg_path_ = ascend310_info->GetFusionSwitchConfigPath(); fusion_switch_cfg_path_ = ascend310_info->GetFusionSwitchConfigPath();
device_id_ = ascend310_info->GetDeviceID(); device_id_ = ascend310_info->GetDeviceID();
dump_cfg_path_ = ascend310_info->GetDumpConfigPath();
buffer_optimize_mode_ = ascend310_info->GetBufferOptimizeMode(); buffer_optimize_mode_ = ascend310_info->GetBufferOptimizeMode();
const char *soc_name = aclrtGetSocName(); const char *soc_name = aclrtGetSocName();
if (soc_name == nullptr) { if (soc_name == nullptr) {
@ -98,6 +98,14 @@ std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string
{&dynamic_batch_size_, ge::ir_option::DYNAMIC_BATCH_SIZE}, {&dynamic_batch_size_, ge::ir_option::DYNAMIC_BATCH_SIZE},
{&dynamic_image_size_, ge::ir_option::DYNAMIC_IMAGE_SIZE}}; {&dynamic_image_size_, ge::ir_option::DYNAMIC_IMAGE_SIZE}};
const std::set<std::string> first_graph_options = {
ge::ir_option::INSERT_OP_FILE,
ge::ir_option::INPUT_FORMAT,
ge::ir_option::INPUT_SHAPE,
};
const std::set<std::string> multi_graph_unsupported_options = {ge::ir_option::OUTPUT_TYPE};
std::map<std::string, std::string> init_options; std::map<std::string, std::string> init_options;
std::map<std::string, std::string> build_options; std::map<std::string, std::string> build_options;
for (auto [ms_option, acl_option_key] : init_options_map) { for (auto [ms_option, acl_option_key] : init_options_map) {
@ -115,6 +123,20 @@ std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string
MS_LOG(INFO) << "Option " << acl_option_key << " : " << *ms_option; MS_LOG(INFO) << "Option " << acl_option_key << " : " << *ms_option;
build_options.emplace(acl_option_key, *ms_option); build_options.emplace(acl_option_key, *ms_option);
} }
// first_graph_flag has value means being multi graph mode
if (first_graph_flag_.has_value()) {
for (const auto &option : multi_graph_unsupported_options) {
build_options.erase(option);
}
// non-input graph
if (!first_graph_flag_) {
for (const auto &option : first_graph_options) {
build_options.erase(option);
}
}
}
return {init_options, build_options}; return {init_options, build_options};
} }

View File

@ -21,6 +21,7 @@
#include <map> #include <map>
#include <tuple> #include <tuple>
#include <memory> #include <memory>
#include <optional>
#include "include/api/types.h" #include "include/api/types.h"
#include "include/api/status.h" #include "include/api/status.h"
#include "include/api/context.h" #include "include/api/context.h"
@ -32,11 +33,11 @@ class AclModelOptions {
~AclModelOptions() = default; ~AclModelOptions() = default;
std::string GenAclOptionsKey() const; std::string GenAclOptionsKey() const;
uint32_t GetDeviceID() const { return device_id_; } uint32_t GetDeviceID() const { return device_id_; }
std::string GetDumpCfgPath() const { return dump_cfg_path_; } void RenameInput(const std::vector<std::string> &);
void RenameInput(const std::vector<std::string> &name);
// return tuple<init_options, build_options> // return tuple<init_options, build_options>
std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> GenAclOptions() const; std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> GenAclOptions() const;
void SetFirstGraph(bool is_first_graph) { first_graph_flag_ = is_first_graph; }
private: private:
std::string output_node_; // todo: at convert.cc::BuildGraph(), no atc options std::string output_node_; // todo: at convert.cc::BuildGraph(), no atc options
@ -55,7 +56,7 @@ class AclModelOptions {
std::map<int, std::vector<int>> input_shape_map_; std::map<int, std::vector<int>> input_shape_map_;
// other options // other options
uint32_t device_id_; uint32_t device_id_;
std::string dump_cfg_path_; std::optional<bool> first_graph_flag_;
}; };
} // namespace mindspore } // namespace mindspore

View File

@ -226,8 +226,9 @@ Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) {
std::map<std::string, std::string> init_options; std::map<std::string, std::string> init_options;
std::map<std::string, std::string> build_options; std::map<std::string, std::string> build_options;
if (options_ != nullptr) { auto option = options_.lock();
std::tie(init_options, build_options) = options_->GenAclOptions(); if (option != nullptr) {
std::tie(init_options, build_options) = option->GenAclOptions();
} }
return BuildAirModel(df_graph, init_options, build_options); return BuildAirModel(df_graph, init_options, build_options);

View File

@ -30,12 +30,12 @@
namespace mindspore { namespace mindspore {
class ModelConverter { class ModelConverter {
public: public:
ModelConverter() : options_(nullptr) {} ModelConverter() : options_() {}
~ModelConverter() = default; ~ModelConverter() = default;
Buffer LoadMindIR(const FuncGraphPtr &func_graph); Buffer LoadMindIR(const FuncGraphPtr &func_graph);
void set_options(AclModelOptions *options) { options_ = options; } void set_options(const std::weak_ptr<AclModelOptions> &options) { options_ = options; }
private: private:
transform::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph); transform::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph);
@ -43,7 +43,7 @@ class ModelConverter {
const std::map<std::string, std::string> &build_options); const std::map<std::string, std::string> &build_options);
Buffer LoadAscendIRInner(const Buffer &model_data); Buffer LoadAscendIRInner(const Buffer &model_data);
AclModelOptions *options_; std::weak_ptr<AclModelOptions> options_;
}; };
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H #endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H

View File

@ -24,7 +24,6 @@ namespace dataset {
Status AscendResource::InitResource(uint32_t device_id) { Status AscendResource::InitResource(uint32_t device_id) {
ResourceInfo resource; ResourceInfo resource;
resource.aclConfigPath = "";
resource.deviceIds.insert(device_id); resource.deviceIds.insert(device_id);
ascend_resource_ = ResourceManager::GetInstance(); ascend_resource_ = ResourceManager::GetInstance();
APP_ERROR ret = ascend_resource_->InitResource(resource); APP_ERROR ret = ascend_resource_->InitResource(resource);

View File

@ -77,7 +77,6 @@ Status DvppCropJpegOp::Compute(const std::shared_ptr<Tensor> &input, std::shared
imageinfo.heightStride = yuv_shape_[3]; imageinfo.heightStride = yuv_shape_[3];
imageinfo.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420; imageinfo.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
ResourceInfo resource; ResourceInfo resource;
resource.aclConfigPath = "";
resource.deviceIds.insert(0); resource.deviceIds.insert(0);
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance(); std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
APP_ERROR ret = instance->InitResource(resource); APP_ERROR ret = instance->InitResource(resource);

View File

@ -70,7 +70,6 @@ Status DvppDecodeJpegOp::Compute(const std::shared_ptr<Tensor> &input, std::shar
imageInfo.lenOfByte = filesize; imageInfo.lenOfByte = filesize;
imageInfo.data = static_cast<void *>(buffer); imageInfo.data = static_cast<void *>(buffer);
ResourceInfo resource; ResourceInfo resource;
resource.aclConfigPath = "";
resource.deviceIds.insert(0); resource.deviceIds.insert(0);
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance(); std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
APP_ERROR ret = instance->InitResource(resource); APP_ERROR ret = instance->InitResource(resource);

View File

@ -68,7 +68,6 @@ Status DvppDecodePngOp::Compute(const std::shared_ptr<Tensor> &input, std::share
imageInfo.lenOfByte = filesize; imageInfo.lenOfByte = filesize;
imageInfo.data = static_cast<void *>(buffer); imageInfo.data = static_cast<void *>(buffer);
ResourceInfo resource; ResourceInfo resource;
resource.aclConfigPath = "";
resource.deviceIds.insert(0); resource.deviceIds.insert(0);
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance(); std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
APP_ERROR ret = instance->InitResource(resource); APP_ERROR ret = instance->InitResource(resource);

View File

@ -69,7 +69,6 @@ Status DvppDecodeResizeCropJpegOp::Compute(const std::shared_ptr<Tensor> &input,
imageInfo.lenOfByte = filesize; imageInfo.lenOfByte = filesize;
imageInfo.data = static_cast<void *>(buffer); imageInfo.data = static_cast<void *>(buffer);
ResourceInfo resource; ResourceInfo resource;
resource.aclConfigPath = "";
resource.deviceIds.insert(0); resource.deviceIds.insert(0);
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance(); std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
APP_ERROR ret = instance->InitResource(resource); APP_ERROR ret = instance->InitResource(resource);

View File

@ -68,7 +68,6 @@ Status DvppDecodeResizeJpegOp::Compute(const std::shared_ptr<Tensor> &input, std
imageInfo.lenOfByte = filesize; imageInfo.lenOfByte = filesize;
imageInfo.data = static_cast<void *>(buffer); imageInfo.data = static_cast<void *>(buffer);
ResourceInfo resource; ResourceInfo resource;
resource.aclConfigPath = "";
resource.deviceIds.insert(0); resource.deviceIds.insert(0);
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance(); std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
APP_ERROR ret = instance->InitResource(resource); APP_ERROR ret = instance->InitResource(resource);

View File

@ -78,7 +78,6 @@ Status DvppResizeJpegOp::Compute(const std::shared_ptr<Tensor> &input, std::shar
imageinfo.heightStride = yuv_shape_[3]; imageinfo.heightStride = yuv_shape_[3];
imageinfo.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420; imageinfo.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
ResourceInfo resource; ResourceInfo resource;
resource.aclConfigPath = "";
resource.deviceIds.insert(0); resource.deviceIds.insert(0);
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance(); std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
APP_ERROR ret = instance->InitResource(resource); APP_ERROR ret = instance->InitResource(resource);

View File

@ -85,22 +85,11 @@ APP_ERROR ResourceManager::InitResource(ResourceInfo &resourceInfo) {
MS_LOG(INFO) << "Acl has been initialized, skip."; MS_LOG(INFO) << "Acl has been initialized, skip.";
return APP_ERR_OK; return APP_ERR_OK;
} }
std::string &aclConfigPath = resourceInfo.aclConfigPath;
APP_ERROR ret = APP_ERR_OK; APP_ERROR ret = APP_ERR_OK;
if (aclConfigPath.length() == 0) { acl_env_ = mindspore::AclEnvGuard::GetAclEnv();
// Init acl without aclconfig
acl_env_ = mindspore::AclEnvGuard::GetAclEnv("");
} else {
ret = ExistFile(aclConfigPath);
if (ret != APP_ERR_OK) {
MS_LOG(ERROR) << "Acl config file not exist, ret = " << ret << ".";
return ret;
}
acl_env_ = mindspore::AclEnvGuard::GetAclEnv(aclConfigPath);
}
if (acl_env_ == nullptr) { if (acl_env_ == nullptr) {
MS_LOG(ERROR) << "Failed to init acl."; MS_LOG(ERROR) << "Failed to init acl.";
return ret; return APP_ERR_COMM_FAILURE;
} }
std::copy(resourceInfo.deviceIds.begin(), resourceInfo.deviceIds.end(), std::back_inserter(deviceIds_)); std::copy(resourceInfo.deviceIds.begin(), resourceInfo.deviceIds.end(), std::back_inserter(deviceIds_));
MS_LOG(INFO) << "Initialized acl successfully."; MS_LOG(INFO) << "Initialized acl successfully.";

View File

@ -53,7 +53,6 @@ struct DeviceResInfo {
struct ResourceInfo { struct ResourceInfo {
std::set<int> deviceIds; std::set<int> deviceIds;
std::string aclConfigPath;
std::string singleOpFolderPath; std::string singleOpFolderPath;
std::unordered_map<int, DeviceResInfo> deviceResInfos; // map <deviceId, deviceResourceInfo> std::unordered_map<int, DeviceResInfo> deviceResInfos; // map <deviceId, deviceResourceInfo>
}; };

View File

@ -1872,7 +1872,6 @@ OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) {
MS_LOG(WARNING) << "set attr value for const failed"; MS_LOG(WARNING) << "set attr value for const failed";
} }
#if (defined ENABLE_GE)
auto const_op = std::static_pointer_cast<Constant>(op); auto const_op = std::static_pointer_cast<Constant>(op);
if (const_op == nullptr) { if (const_op == nullptr) {
MS_LOG(ERROR) << "Get Constant operator failed"; MS_LOG(ERROR) << "Get Constant operator failed";
@ -1881,7 +1880,6 @@ OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) {
auto ge_tensor = const_op->get_attr_value(); auto ge_tensor = const_op->get_attr_value();
auto ge_desc = ge_tensor.GetTensorDesc(); auto ge_desc = ge_tensor.GetTensorDesc();
(void)const_op->update_output_desc_y(ge_desc); (void)const_op->update_output_desc_y(ge_desc);
#endif
op_cache_[node.get()] = op; op_cache_[node.get()] = op;
return op_cache_[node.get()]; return op_cache_[node.get()];

View File

@ -317,11 +317,6 @@ size_t OpAdapterImpl::GetCustomOpOutputSize(const CusOperatorPtr &cus_op) {
std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type, std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type,
const std::string &format) { const std::string &format) {
if (shape_ptr == nullptr) {
MS_LOG(ERROR) << "Shape ptr is nullptr";
return nullptr;
}
if (type == nullptr) { if (type == nullptr) {
MS_LOG(ERROR) << "Type ptr is nullptr"; MS_LOG(ERROR) << "Type ptr is nullptr";
return nullptr; return nullptr;
@ -331,8 +326,8 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateOutputDesc(const abstract::Sh
if (kObjectTypeTensorType == me_type) { if (kObjectTypeTensorType == me_type) {
me_type = dyn_cast<TensorType>(type)->element()->type_id(); me_type = dyn_cast<TensorType>(type)->element()->type_id();
} }
auto desc = TransformUtil::GetGeTensorDesc(shape_ptr->shape(), me_type, format);
return desc; return TransformUtil::GetGeTensorDesc((shape_ptr == nullptr) ? ShapeVector{} : shape_ptr->shape(), me_type, format);
} }
Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
@ -472,7 +467,7 @@ void OpAdapterImpl::updateOutputDesc(const OperatorPtr &op, const abstract::Base
return; return;
} }
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_LOG(INFO) << "Op name is " << op->GetName(); MS_LOG(INFO) << "Op name is " << op->GetName() << " anf is " << node->DebugString();
auto normal_shape_ptr = dyn_cast<abstract::Shape>(shp); auto normal_shape_ptr = dyn_cast<abstract::Shape>(shp);
auto no_shape_ptr = dyn_cast<abstract::NoShape>(shp); auto no_shape_ptr = dyn_cast<abstract::NoShape>(shp);

View File

@ -78,7 +78,7 @@ class MsBackend : public Backend {
~MsBackend() override = default; ~MsBackend() override = default;
LinConvertResult MsConvert(const GraphSegmentPtr &segment, const std::string &target = ""); LinConvertResult MsConvert(const GraphSegmentPtr &segment, const std::string &target = "");
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = ""); virtual VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");
VectorRef MsSimuRunGraph(const GraphId &g); VectorRef MsSimuRunGraph(const GraphId &g);
GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override; GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
@ -90,7 +90,7 @@ class MsBackend : public Backend {
void SetDebugger() override; void SetDebugger() override;
#endif #endif
private: protected:
session::SessionPtr target_sess_; session::SessionPtr target_sess_;
session::SessionPtr other_sess_; session::SessionPtr other_sess_;
std::string target_device_; std::string target_device_;

View File

@ -145,6 +145,22 @@ void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
} }
} }
void CompileGraph::PushInputs(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> parameters = graph->parameters();
for (size_t i = parameters.size(); i != 0; i--) {
MS_EXCEPTION_IF_NULL(parameters[i - 1]);
auto param = parameters[i - 1]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
if (param->has_default()) {
MS_LOG(DEBUG) << "Parameter " << (i - 1) << ": " << param->DebugString() << " has default value, skip.";
continue;
}
Push(param);
MS_LOG(DEBUG) << "Push parameter " << (i - 1) << ": " << param->DebugString(true);
}
}
int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) { int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) {
MS_EXCEPTION_IF_NULL(segment); MS_EXCEPTION_IF_NULL(segment);
MS_LOG(DEBUG) << "LinConvert start"; MS_LOG(DEBUG) << "LinConvert start";
@ -257,11 +273,16 @@ bool CompileGraph::Compile(const FuncGraphPtr &graph) {
return true; return true;
} }
InstSet CompileGraph::Run(const FuncGraphPtr &graph) { InstSet CompileGraph::Run(const FuncGraphPtr &graph, bool push_weight) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
Reset(); Reset();
PushParameters(graph); if (push_weight) {
PushParameters(graph);
} else {
PushInputs(graph);
}
int64_t param_height = height_; int64_t param_height = height_;
MS_EXCEPTION_IF_NULL(graph->get_return()); MS_EXCEPTION_IF_NULL(graph->get_return());
MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true); MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);

View File

@ -53,14 +53,14 @@ class CompileGraph {
public: public:
explicit CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops); explicit CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
~CompileGraph() = default; virtual ~CompileGraph() = default;
InstSet Run(const FuncGraphPtr &func_graph); InstSet Run(const FuncGraphPtr &func_graph, bool push_weight = true);
bool IsCut(const AnfNodePtr &node); bool IsCut(const AnfNodePtr &node);
void Push(const AnfNodePtr &node); void Push(const AnfNodePtr &node);
void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; } void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
void Ret(int64_t nargs); void Ret(int64_t nargs);
int64_t Ref(const AnfNodePtr &node); virtual int64_t Ref(const AnfNodePtr &node);
void set_height(int64_t h) { void set_height(int64_t h) {
height_ = h; height_ = h;
@ -76,8 +76,9 @@ class CompileGraph {
inst_.clear(); inst_.clear();
} }
private: protected:
void PushParameters(const FuncGraphPtr &func_graph); void PushParameters(const FuncGraphPtr &func_graph);
void PushInputs(const FuncGraphPtr &graph);
bool Compile(const FuncGraphPtr &func_graph); bool Compile(const FuncGraphPtr &func_graph);
int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = ""); int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = "");
int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
@ -114,18 +115,18 @@ class CompileGraphs {
public: public:
explicit CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops); explicit CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
~CompileGraphs() = default; virtual ~CompileGraphs() = default;
void Reset() { void Reset() {
insts_.clear(); insts_.clear();
mapping_.clear(); mapping_.clear();
} }
void Compile(const FuncGraphPtr &func_graph); virtual void Compile(const FuncGraphPtr &func_graph);
FinalVMPtr Link(); FinalVMPtr Link();
FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph); FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
private: protected:
InstSet insts_; InstSet insts_;
std::unordered_map<FuncGraphPtr, int64_t> mapping_; std::unordered_map<FuncGraphPtr, int64_t> mapping_;
CompileGraphPtr transform_; CompileGraphPtr transform_;

View File

@ -424,17 +424,11 @@ AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &pri
ShapeVector y_shape_max = y->shape()->max_shape().empty() ? y_shape : y->shape()->max_shape(); ShapeVector y_shape_max = y->shape()->max_shape().empty() ? y_shape : y->shape()->max_shape();
auto out_shape = BroadcastShape(x_shape, y_shape); auto out_shape = BroadcastShape(x_shape, y_shape);
if (out_shape.empty()) {
MS_LOG(EXCEPTION) << "Less op BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
<< args_spec_list[1]->ToString();
}
auto out_shape_min = BroadcastShape(x_shape_min, y_shape_min); auto out_shape_min = BroadcastShape(x_shape_min, y_shape_min);
auto out_shape_max = BroadcastShape(x_shape_max, y_shape_max); auto out_shape_max = BroadcastShape(x_shape_max, y_shape_max);
auto output_type = std::make_shared<Bool>(); auto output_type = std::make_shared<Bool>();
auto ret = return std::make_shared<AbstractTensor>(output_type,
std::make_shared<AbstractTensor>(output_type, std::make_shared<Shape>(out_shape, out_shape_min, out_shape_max)); std::make_shared<Shape>(out_shape, out_shape_min, out_shape_max));
return ret;
} }
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,503 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "load_mindir/infer_mindir.h"
#include <deque>
#include <set>
#include <map>
#include <memory>
#include <algorithm>
#include <string>
#include "ir/func_graph.h"
#include "abstract/abstract_function.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
namespace mindspore {
namespace {
class MindIREngine {
public:
explicit MindIREngine(const FuncGraphPtr &root) : func_graph_(root), nodeuser_map_(root->manager()->node_users()) {}
~MindIREngine() = default;
MindIREngine(const MindIREngine &) = delete;
MindIREngine &operator=(const MindIREngine &) = delete;
bool InferShape(const AbstractBasePtrList &args);
private:
using AbstractBasePtrListPtr = std::shared_ptr<AbstractBasePtrList>;
void Init(const AbstractBasePtrList &args);
static AbstractBasePtr InferPrimitiveShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
void EvalCommonPrimitive(const PrimitivePtr &prim, const CNodePtr &node, const AbstractBasePtrListPtr &args);
void EvalPartialPrimitive(const PrimitivePtr &prim, const CNodePtr &node, const AbstractBasePtrListPtr &args);
void EvalReturnPrimitive(const PrimitivePtr &prim, const CNodePtr &node);
void InferParameter(const AnfNodePtr &node);
void InferValueNode(const AnfNodePtr &node);
void InferCNode(const AnfNodePtr &node);
void EvalAbstractFunction(const abstract::AbstractFuncAtomPtr &abstractFunc, const CNodePtr &node,
const AbstractBasePtrListPtr &args);
void EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr &func, const CNodePtr &node,
const AbstractBasePtrListPtr &args);
void EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr &func, const CNodePtr &node,
const AbstractBasePtrListPtr &args);
void EvalPartialAbastract(const abstract::PartialAbstractClosurePtr &func, const CNodePtr &node,
const AbstractBasePtrListPtr &args);
bool CheckCNodeNotReady(const CNodePtr &node);
void UpdateReady(const AnfNodePtr &node);
void SaveNodeInferResult(const AnfNodePtr &node, const AbstractBasePtr &result);
AbstractBasePtr GetCNodeOperatorAbstract(const AnfNodePtr &node);
FuncGraphPtr func_graph_;
std::map<AnfNodePtr, int> node_input_depends_;
std::map<AnfNodePtr, AbstractBasePtr> infer_resut_;
std::map<std::string, AbstractBasePtr> func_graph_result_;
std::map<std::string, std::set<AnfNodePtr>> func_graph_visited_;
std::deque<AnfNodePtr> ready_;
std::set<AnfNodePtr> todo_;
NodeUsersMap nodeuser_map_;
};
// Infer the root function graph.
bool MindIREngine::InferShape(const AbstractBasePtrList &args) {
Init(args);
while (!ready_.empty()) {
auto current = ready_.front();
MS_EXCEPTION_IF_NULL(current);
ready_.pop_front();
if (current->isa<CNode>()) {
InferCNode(current);
} else if (current->isa<ValueNode>()) {
InferValueNode(current);
} else if (current->isa<Parameter>()) {
InferParameter(current);
} else {
MS_LOG(WARNING) << " There is something changed. Please check the code.";
}
}
// Set abstract of node.
for (const auto &item : infer_resut_) {
item.first->set_abstract(item.second);
}
if (todo_.empty()) {
MS_LOG(DEBUG) << "Finish to Infere.";
return true;
}
MS_LOG(WARNING) << "Not finished to infer: " << todo_.size();
for (const auto &node : todo_) {
MS_LOG(DEBUG) << "Node uninfered: " << node->DebugString();
}
return false;
}
void MindIREngine::Init(const AbstractBasePtrList &args) {
MS_EXCEPTION_IF_NULL(func_graph_);
auto manager = func_graph_->manager();
MS_EXCEPTION_IF_NULL(manager);
for (const auto &node : manager->all_nodes()) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
todo_.insert(node);
node_input_depends_[node] = cnode->inputs().size();
} else if (node->isa<Parameter>()) {
auto param = node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
if (param->has_default()) {
node_input_depends_[node] = 0;
infer_resut_[node] = param->default_param()->ToAbstract();
ready_.push_back(node);
} else {
node_input_depends_[node] = 1;
todo_.insert(node);
}
} else {
// Value Node
node_input_depends_[node] = 0;
ready_.push_back(node);
}
}
auto inputs = func_graph_->get_inputs();
if (inputs.size() != args.size()) {
MS_LOG(EXCEPTION) << "The input parameters is not Compatible. mindir:" << inputs.size()
<< " inputs: " << args.size() << " FuncGraph:" << func_graph_->ToString();
}
// Root Func Parameters
for (size_t i = 0; i < args.size(); ++i) {
this->SaveNodeInferResult(inputs[i], args[i]);
}
MS_LOG(DEBUG) << "Finish init. Size of nodes:" << manager->all_nodes().size();
}
// Infer primitive using C++ implement.
AbstractBasePtr MindIREngine::InferPrimitiveShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(prim);
try {
static auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
auto ret = prim_eval_implement_map.find(prim);
if (ret != prim_eval_implement_map.end()) {
// fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr
MS_EXCEPTION_IF_NULL(ret->second.infer_shape_impl_);
return ret->second.infer_shape_impl_(nullptr, prim, args_spec_list);
} else {
// if the infer function has been not founded in the front infer map find it in the backend infer map instead
static auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
auto ret_backend = prim_backend_eval_impl_map.find(prim);
if (ret_backend != prim_backend_eval_impl_map.end()) {
MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_);
return ret_backend->second.infer_shape_impl_(nullptr, prim, args_spec_list);
}
}
MS_LOG(WARNING) << "Get infer shape function failed, primitive name:" << prim->name()
<< " primitive type:" << prim->type_name() << " It will keep the prevalue witch danger.";
} catch (const std::exception &ex) {
MS_LOG(WARNING) << "Catch primitive:" << prim->ToString() << " InferPrimitiveShape exception:" << ex.what()
<< " It will keep the prevalue witch danger.";
}
return nullptr;
}
void MindIREngine::EvalCommonPrimitive(const PrimitivePtr &prim, const CNodePtr &node,
const AbstractBasePtrListPtr &args) {
AbstractBasePtrList args_spec_list;
// Args has been resolved by partial
if (args != nullptr) {
args_spec_list.insert(args_spec_list.end(), args->begin(), args->end());
} else {
(void)std::transform(node->inputs().begin() + 1, node->inputs().end(), std::back_inserter(args_spec_list),
[this](const AnfNodePtr &arg) { return infer_resut_[arg]; });
}
// Call C++ infer
auto result = InferPrimitiveShape(prim, args_spec_list);
if (result == nullptr) {
MS_LOG(INFO) << node->ToString()
<< " can't be inferred shape. It will keep the prevalue witch danger. Prim: " << prim->ToString();
result = node->abstract();
}
SaveNodeInferResult(node, result);
}
void MindIREngine::EvalReturnPrimitive(const PrimitivePtr &prim, const CNodePtr &node) {
if (node->inputs().size() < 2) {
MS_LOG(EXCEPTION) << node->DebugString() << " input size < 2";
}
auto result = infer_resut_[node->inputs()[1]];
auto funcName = node->func_graph()->ToString();
auto it = func_graph_result_.find(funcName);
if (it != func_graph_result_.end()) {
result = result->Join(it->second);
}
this->func_graph_result_[funcName] = result;
SaveNodeInferResult(node, result);
MS_LOG(DEBUG) << funcName << " result: " << result->ToString();
// Set the result of the node whose Operator is this funcGraph
for (const auto &item : func_graph_visited_[funcName]) {
SaveNodeInferResult(item, result);
}
}
void MindIREngine::EvalPartialPrimitive(const PrimitivePtr &prim, const CNodePtr &node,
const AbstractBasePtrListPtr &args) {
// Args has been resolved
if (args != nullptr) {
if (args->size() < 2) {
MS_LOG(EXCEPTION) << node->DebugString() << " input size < 2";
}
auto real_func = (*args)[0]->cast<abstract::AbstractFuncAtomPtr>();
if (real_func == nullptr) {
MS_LOG(EXCEPTION) << (*args)[0]->ToString() << " is not a function abstract.";
}
AbstractBasePtrList partial_args_list;
partial_args_list.insert(partial_args_list.end(), args->begin() + 1, args->end());
auto partial_func = std::make_shared<abstract::PartialAbstractClosure>(real_func, partial_args_list, node);
SaveNodeInferResult(node, partial_func);
return;
}
// Not Resolved.
if (node->inputs().size() < 2) {
MS_LOG(EXCEPTION) << node->DebugString() << " input size < 2";
}
auto &func = infer_resut_[node->inputs()[1]];
auto real_func = func->cast<abstract::AbstractFuncAtomPtr>();
if (real_func == nullptr) {
MS_LOG(EXCEPTION) << func->ToString() << " is not a function abstract.";
}
AbstractBasePtrList partial_args_list;
(void)std::transform(node->inputs().begin() + 2, node->inputs().end(), std::back_inserter(partial_args_list),
[this](const AnfNodePtr &arg) { return infer_resut_[arg]; });
auto partial_func = std::make_shared<abstract::PartialAbstractClosure>(real_func, partial_args_list, node);
SaveNodeInferResult(node, partial_func);
}
void MindIREngine::EvalPartialAbastract(const abstract::PartialAbstractClosurePtr &func, const CNodePtr &node,
const AbstractBasePtrListPtr &args) {
AbstractBasePtrListPtr partial_args_list = std::make_shared<AbstractBasePtrList>();
// Join arguments in partial and the rest arguments from args_conf_list.
auto func_args = func->args();
partial_args_list->insert(partial_args_list->end(), func_args.begin(), func_args.end());
if (args == nullptr) {
// Not Recursive
(void)std::transform(node->inputs().begin() + 1, node->inputs().end(), std::back_inserter(*partial_args_list),
[this](const AnfNodePtr &arg) { return infer_resut_[arg]; });
} else {
// Recursive
partial_args_list->insert(partial_args_list->end(), args->begin(), args->end());
}
// Get real function
abstract::AbstractFuncAtomPtrList abstractFuncList;
auto build_fuction = [&abstractFuncList](const abstract::AbstractFuncAtomPtr &poss) {
abstractFuncList.push_back(poss);
};
func->fn()->Visit(build_fuction);
for (const auto &abstractFunc : abstractFuncList) {
EvalAbstractFunction(abstractFunc, node, partial_args_list);
}
}
void MindIREngine::SaveNodeInferResult(const AnfNodePtr &node, const AbstractBasePtr &result) {
auto answer = result;
auto it = infer_resut_.find(node);
if (it != infer_resut_.end()) {
MS_LOG(DEBUG) << node->ToString() << " result: " << it->second->ToString();
answer = result->Join(it->second);
if (*answer == *(it->second)) {
MS_LOG(DEBUG) << node->ToString() << " The value is not changed.";
return;
}
}
MS_LOG(DEBUG) << node->ToString() << " result: " << answer->ToString();
infer_resut_[node] = answer;
UpdateReady(node);
}
void MindIREngine::EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr &func, const CNodePtr &node,
const AbstractBasePtrListPtr &args) {
auto prim = func->prim();
// Return Primitive
if (prim->name() == prim::kPrimReturn->name()) {
EvalReturnPrimitive(prim, node);
return;
}
// Partial Primitive
if (prim->name() == prim::kPrimPartial->name()) {
EvalPartialPrimitive(prim, node, args);
return;
}
// common Primitive
EvalCommonPrimitive(prim, node, args);
}
bool MindIREngine::CheckCNodeNotReady(const CNodePtr &node) {
int depend = 0;
for (const auto &input : node->inputs()) {
depend += infer_resut_.find(input) != infer_resut_.end() ? 0 : 1;
}
this->node_input_depends_[node] = depend;
return depend;
}
void MindIREngine::EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr &func, const CNodePtr &node,
const AbstractBasePtrListPtr &args) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(func);
MS_EXCEPTION_IF_NULL(func->func_graph());
// Has Processd
MS_LOG(DEBUG) << node->ToString() << " FuncGraph: " << func->ToString();
auto funcName = func->func_graph()->ToString();
auto it = func_graph_result_.find(funcName);
if (it != func_graph_result_.end()) {
MS_LOG(DEBUG) << "The abstract of " << node->ToString() << " = abstract of " << func->ToString();
SaveNodeInferResult(node, it->second);
// Process only one return valueNode function graph
auto func_inputs = func->func_graph()->get_inputs();
// args has been resolved in partial.
if (args != nullptr) {
if (func_inputs.size() != args->size()) {
MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
<< " CNode:" << node->DebugString() << " input size:" << args->size();
}
for (size_t i = 0; i < func_inputs.size(); ++i) {
infer_resut_[func_inputs[i]] =
(*args)[i]; // Not use SaveNodeInferResult because this function has been evaluated.
todo_.erase(func_inputs[i]);
}
return;
}
// args is not resolved.
auto &cnode_inputs = node->inputs();
if (func_inputs.size() != cnode_inputs.size() - 1) {
MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
<< " CNode:" << node->DebugString() << " input size:" << cnode_inputs.size();
}
for (size_t i = 0; i < func_inputs.size(); ++i) {
infer_resut_[func_inputs[i]] = infer_resut_[cnode_inputs[i + 1]];
todo_.erase(func_inputs[i]);
}
return;
}
// Be handling
auto visitIt = func_graph_visited_.find(funcName);
if (visitIt != func_graph_visited_.end()) {
visitIt->second.insert(node);
return;
}
func_graph_visited_[funcName] = std::set<AnfNodePtr>({node});
// Call the funcGraph
auto func_inputs = func->func_graph()->get_inputs();
// args has been resolved in partial.
if (args != nullptr) {
if (func_inputs.size() != args->size()) {
MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
<< " CNode:" << node->DebugString() << " input size:" << args->size();
}
for (size_t i = 0; i < func_inputs.size(); ++i) {
SaveNodeInferResult(func_inputs[i], (*args)[i]);
}
return;
}
// args is not resolved.
auto &cnode_inputs = node->inputs();
if (func_inputs.size() != cnode_inputs.size() - 1) {
MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
<< " CNode:" << node->DebugString() << " input size:" << cnode_inputs.size();
}
for (size_t i = 0; i < func_inputs.size(); ++i) {
SaveNodeInferResult(func_inputs[i], infer_resut_[cnode_inputs[i + 1]]);
}
}
void MindIREngine::InferParameter(const AnfNodePtr &node) { UpdateReady(node); }
void MindIREngine::InferValueNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = GetValueNode(node);
MS_EXCEPTION_IF_NULL(value);
AbstractBasePtr result;
if (value->isa<FuncGraph>()) {
auto func_graph = value->cast<FuncGraphPtr>();
auto temp_context = abstract::AnalysisContext::DummyContext();
result = std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, temp_context, node);
} else if (value->isa<Primitive>()) {
auto prim = value->cast<PrimitivePtr>();
result = std::make_shared<abstract::PrimitiveAbstractClosure>(prim, node);
} else {
result = value->ToAbstract();
}
if (result->isa<abstract::AbstractTensor>()) {
result = result->Broaden();
}
SaveNodeInferResult(node, result);
}
AbstractBasePtr MindIREngine::GetCNodeOperatorAbstract(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto op = cnode->inputs()[0];
auto it = infer_resut_.find(op);
if (it != infer_resut_.end()) {
return it->second;
}
MS_LOG(EXCEPTION) << "Can't get the abstract of Node:" << op->DebugString();
}
// If args is nullPtr, it is called by InferCNode, else it is called recursively by EvalPartialAbastract.
void MindIREngine::EvalAbstractFunction(const abstract::AbstractFuncAtomPtr &func, const CNodePtr &node,
const AbstractBasePtrListPtr &args) {
MS_EXCEPTION_IF_NULL(func);
if (func->isa<abstract::PrimitiveAbstractClosure>()) {
// C++ Primitive
auto prim = func->cast<abstract::PrimitiveAbstractClosurePtr>();
EvalPrimitiveAbastract(prim, node, args);
} else if (func->isa<abstract::FuncGraphAbstractClosure>()) {
// FuncGraph
auto funcGraph = func->cast<abstract::FuncGraphAbstractClosurePtr>();
EvalFuncGraphAbastract(funcGraph, node, args);
} else if (func->isa<abstract::PartialAbstractClosure>()) {
// Partial
auto partialPrim = func->cast<abstract::PartialAbstractClosurePtr>();
EvalPartialAbastract(partialPrim, node, args);
} else {
MS_LOG(EXCEPTION) << "MindIR can't process the abstractFunc: " << func->DumpText();
}
}
void MindIREngine::UpdateReady(const AnfNodePtr &node) {
todo_.erase(node);
auto it = nodeuser_map_.find(node);
if (it == nodeuser_map_.end()) {
return;
}
const auto &users = it->second;
MS_LOG(DEBUG) << node->ToString() << " has users: " << users.size();
for (const auto &user : users) {
int count = node_input_depends_[user.first];
node_input_depends_[user.first] = count - 1;
if (count <= 1) {
ready_.push_back(user.first);
MS_LOG(DEBUG) << "Node:" << user.first->ToString() << " is ready.";
if (count < 1) {
MS_LOG(INFO) << " There is something to do. Node:" << node->ToString() << " user:" << user.first->DebugString();
}
}
}
}
void MindIREngine::InferCNode(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (CheckCNodeNotReady(cnode)) {
MS_LOG(INFO) << "The node is not ready: " << cnode->DebugString();
return;
}
AbstractBasePtr possible_func = GetCNodeOperatorAbstract(cnode);
if (possible_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
MS_LOG(EXCEPTION) << "EvalCNode eval Undetermined";
}
abstract::AbstractFunctionPtr func = dyn_cast<abstract::AbstractFunction>(possible_func);
if (func == nullptr) {
MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << possible_func->ToString() << ".";
MS_EXCEPTION(ValueError) << "This may be not defined, and it can't be a operator. Please check code.";
}
abstract::AbstractFuncAtomPtrList abstractFuncList;
auto build_fuction = [&abstractFuncList](const abstract::AbstractFuncAtomPtr &poss) {
abstractFuncList.push_back(poss);
};
func->Visit(build_fuction);
for (const auto &abstractFunc : abstractFuncList) {
EvalAbstractFunction(abstractFunc, cnode, nullptr);
}
}
} // namespace
bool InferMindir(const FuncGraphPtr &root, const AbstractBasePtrList &args) {
auto engine = std::make_shared<MindIREngine>(root);
return engine->InferShape(args);
}
} // namespace mindspore

View File

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_LOAD_MINDIR_INFER_MINDIR_H
#define MINDSPORE_CORE_LOAD_MINDIR_INFER_MINDIR_H
#include "base/base.h"
#include "ir/anf.h"
namespace mindspore {
bool InferMindir(const FuncGraphPtr &root, const AbstractBasePtrList &args);
} // namespace mindspore
#endif // MINDSPORE_CORE_LOAD_MINDIR_INFER_MINDIR_H

View File

@ -51,6 +51,7 @@ const char kCPUDevice[] = "CPU";
const char kGPUDevice[] = "GPU"; const char kGPUDevice[] = "GPU";
const char kAscendDevice[] = "Ascend"; const char kAscendDevice[] = "Ascend";
const char kDavinciInferenceDevice[] = "AscendInference"; const char kDavinciInferenceDevice[] = "AscendInference";
const char kDavinciMultiGraphInferenceDevice[] = "AscendMultiGraphInference";
const char kGpuInferenceDevice[] = "GpuInference"; const char kGpuInferenceDevice[] = "GpuInference";
const char kDavinciDevice[] = "Davinci"; const char kDavinciDevice[] = "Davinci";
const char KNpuLog[] = "_npu_log"; const char KNpuLog[] = "_npu_log";

View File

@ -36,7 +36,6 @@ constexpr auto kModelOptionProvider = "mindspore.option.provider";
constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device"; constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id"; constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID; constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionAscend310DumpCfgPath = "mindspore.option.ascend310.dump_config_file_path";
constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path"; constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path";
constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format"; constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format";
constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map"; constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
@ -330,23 +329,6 @@ uint32_t Ascend310DeviceInfo::GetDeviceID() const {
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID); return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
} }
void Ascend310DeviceInfo::SetDumpConfigPath(const std::vector<char> &cfg_path) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscend310DumpCfgPath] = CharToString(cfg_path);
}
std::vector<char> Ascend310DeviceInfo::GetDumpConfigPathChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DumpCfgPath);
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) { void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";

View File

@ -0,0 +1,405 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <vector>
#include "common/common_test.h"
#include "include/api/model.h"
#include "include/api/serialization.h"
#include "include/api/context.h"
using namespace mindspore;
static constexpr char kIfbyIfFile[] = "/home/workspace/mindspore_dataset/mindir/control/ifbyif.mindir";
static constexpr char kSimpleWhileFile[] = "/home/workspace/mindspore_dataset/mindir/control/simple_while.mindir";
static constexpr char kMixIfWhileFile[] = "/home/workspace/mindspore_dataset/mindir/control/mix_while_if.mindir";
static constexpr char kRecursiveFile[] = "/home/workspace/mindspore_dataset/mindir/control/fibonacci.mindir";
static constexpr char kSingleForFile[] = "/home/workspace/mindspore_dataset/mindir/control/single_for.mindir";
static constexpr char kSingleOrFile[] = "/home/workspace/mindspore_dataset/mindir/control/single_or.mindir";
static constexpr char kSingleSwitchFile[] = "/home/workspace/mindspore_dataset/mindir/control/switch_layer_net.mindir";
static constexpr float kConstValue = 0.1234;
static const std::vector<float> input_data(2 * 3 * 4 * 5, kConstValue);
class TestControl : public ST::Common {
public:
TestControl() {}
};
TEST_F(TestControl, InferIfbyIf) {
auto context = ContextAutoSet();
Graph graph;
ASSERT_TRUE(Serialization::Load(kIfbyIfFile, ModelType::kMindIR, &graph));
Model control_model;
ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
// assert inputs
std::vector<MSTensor> inputs_before = control_model.GetInputs();
ASSERT_EQ(5, inputs_before.size());
EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeFloat32);
EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeBool);
EXPECT_EQ(inputs_before[3].DataType(), DataType::kNumberTypeBool);
EXPECT_EQ(inputs_before[4].DataType(), DataType::kNumberTypeFloat32);
ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float));
ASSERT_EQ(inputs_before[1].DataSize(), sizeof(float));
ASSERT_EQ(inputs_before[2].DataSize(), sizeof(bool));
ASSERT_EQ(inputs_before[3].DataSize(), sizeof(bool));
ASSERT_EQ(inputs_before[4].DataSize(), sizeof(float) * input_data.size());
ASSERT_EQ(inputs_before[0].Shape().size(), 1);
EXPECT_EQ(inputs_before[0].Shape()[0], 1);
ASSERT_EQ(inputs_before[1].Shape().size(), 1);
EXPECT_EQ(inputs_before[1].Shape()[0], 1);
ASSERT_EQ(inputs_before[2].Shape().size(), 1);
EXPECT_EQ(inputs_before[2].Shape()[0], 1);
ASSERT_EQ(inputs_before[3].Shape().size(), 1);
EXPECT_EQ(inputs_before[3].Shape()[0], 1);
ASSERT_EQ(inputs_before[4].Shape().size(), 4);
EXPECT_EQ(inputs_before[4].Shape()[0], 2);
EXPECT_EQ(inputs_before[4].Shape()[1], 3);
EXPECT_EQ(inputs_before[4].Shape()[2], 4);
EXPECT_EQ(inputs_before[4].Shape()[3], 5);
// prepare input
std::vector<MSTensor> outputs;
std::vector<MSTensor> inputs;
float x = 2.345678, y = 1.234567;
bool cond1 = true, cond2 = false;
inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
sizeof(float));
inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
sizeof(float));
inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &cond1,
sizeof(bool));
inputs.emplace_back(inputs_before[3].Name(), inputs_before[3].DataType(), inputs_before[3].Shape(), &cond2,
sizeof(bool));
inputs.emplace_back(inputs_before[4].Name(), inputs_before[4].DataType(), inputs_before[4].Shape(), input_data.data(),
sizeof(float) * input_data.size());
// infer
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
// assert output
ASSERT_TRUE(outputs.size() == 1);
auto out = outputs[0];
ASSERT_TRUE(out.DataSize() == sizeof(float) * input_data.size());
auto out_data = out.Data();
auto p = reinterpret_cast<const float *>(out_data.get());
for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
ASSERT_LE(std::abs(p[i] - kConstValue * 24), 1e-3);
}
}
TEST_F(TestControl, InferSimpleWhile) {
auto context = ContextAutoSet();
Graph graph;
ASSERT_TRUE(Serialization::Load(kSimpleWhileFile, ModelType::kMindIR, &graph));
Model control_model;
ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
// assert inputs
std::vector<MSTensor> inputs_before = control_model.GetInputs();
ASSERT_EQ(3, inputs_before.size());
EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeBool);
EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeBool);
EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeFloat32);
ASSERT_EQ(inputs_before[0].DataSize(), sizeof(bool));
ASSERT_EQ(inputs_before[1].DataSize(), sizeof(bool));
ASSERT_EQ(inputs_before[2].DataSize(), sizeof(float) * input_data.size());
ASSERT_EQ(inputs_before[0].Shape().size(), 1);
EXPECT_EQ(inputs_before[0].Shape()[0], 1);
ASSERT_EQ(inputs_before[1].Shape().size(), 1);
EXPECT_EQ(inputs_before[1].Shape()[0], 1);
ASSERT_EQ(inputs_before[2].Shape().size(), 4);
EXPECT_EQ(inputs_before[2].Shape()[0], 2);
EXPECT_EQ(inputs_before[2].Shape()[1], 3);
EXPECT_EQ(inputs_before[2].Shape()[2], 4);
EXPECT_EQ(inputs_before[2].Shape()[3], 5);
// prepare input
std::vector<MSTensor> outputs;
std::vector<MSTensor> inputs;
{
bool x = true, y = false;
inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
sizeof(bool));
inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
sizeof(bool));
inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(),
input_data.data(), sizeof(float) * input_data.size());
}
// infer
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
// assert output
ASSERT_TRUE(outputs.size() == 1);
auto out = outputs[0];
ASSERT_TRUE(out.DataSize() == sizeof(float) * input_data.size());
auto out_data = out.Data();
auto p = reinterpret_cast<const float *>(out_data.get());
for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
ASSERT_LE(std::abs(p[i] - kConstValue * 3), 1e-3);
}
}
TEST_F(TestControl, InferRecursive) {
auto context = ContextAutoSet();
Graph graph;
ASSERT_TRUE(Serialization::Load(kRecursiveFile, ModelType::kMindIR, &graph));
Model control_model;
ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
// assert inputs
std::vector<MSTensor> inputs_before = control_model.GetInputs();
ASSERT_EQ(1, inputs_before.size());
EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32);
ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t));
ASSERT_EQ(inputs_before[0].Shape().size(), 1);
EXPECT_EQ(inputs_before[0].Shape()[0], 1);
// prepare input
std::vector<MSTensor> outputs;
std::vector<MSTensor> inputs;
{
int32_t x = 7;
inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
sizeof(int32_t));
}
// infer
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
// assert output
ASSERT_TRUE(outputs.size() == 1);
auto out = outputs[0];
ASSERT_TRUE(out.DataSize() == sizeof(int32_t));
auto out_data = out.Data();
auto p = reinterpret_cast<const int32_t *>(out_data.get());
ASSERT_EQ(*p, 21);
}
TEST_F(TestControl, InferMixedWhileIf) {
auto context = ContextAutoSet();
Graph graph;
ASSERT_TRUE(Serialization::Load(kMixIfWhileFile, ModelType::kMindIR, &graph));
Model control_model;
ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
// assert inputs
std::vector<MSTensor> inputs_before = control_model.GetInputs();
ASSERT_EQ(inputs_before.size(), 5);
EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32);
EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32);
EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32);
EXPECT_EQ(inputs_before[3].DataType(), DataType::kNumberTypeInt32);
EXPECT_EQ(inputs_before[4].DataType(), DataType::kNumberTypeInt32);
ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t));
ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t));
ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t));
ASSERT_EQ(inputs_before[3].DataSize(), sizeof(int32_t));
ASSERT_EQ(inputs_before[4].DataSize(), sizeof(int32_t));
ASSERT_EQ(inputs_before[0].Shape().size(), 1);
EXPECT_EQ(inputs_before[0].Shape()[0], 1);
ASSERT_EQ(inputs_before[1].Shape().size(), 1);
EXPECT_EQ(inputs_before[1].Shape()[0], 1);
ASSERT_EQ(inputs_before[2].Shape().size(), 1);
EXPECT_EQ(inputs_before[2].Shape()[0], 1);
ASSERT_EQ(inputs_before[3].Shape().size(), 1);
EXPECT_EQ(inputs_before[3].Shape()[0], 1);
ASSERT_EQ(inputs_before[4].Shape().size(), 1);
EXPECT_EQ(inputs_before[4].Shape()[0], 1);
// prepare input
std::vector<MSTensor> outputs;
std::vector<MSTensor> inputs;
{
int32_t x = 2, y = 14, z = 1, c2 = 14, c4 = 0;
inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
sizeof(int32_t));
inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
sizeof(int32_t));
inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &z,
sizeof(int32_t));
inputs.emplace_back(inputs_before[3].Name(), inputs_before[3].DataType(), inputs_before[3].Shape(), &c2,
sizeof(int32_t));
inputs.emplace_back(inputs_before[4].Name(), inputs_before[4].DataType(), inputs_before[4].Shape(), &c4,
sizeof(int32_t));
}
// infer
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
// assert output
ASSERT_TRUE(outputs.size() == 1);
auto out = outputs[0];
ASSERT_TRUE(out.DataSize() == sizeof(int32_t));
auto out_data = out.Data();
auto p = reinterpret_cast<const int32_t *>(out_data.get());
ASSERT_EQ(*p, 350);
}
TEST_F(TestControl, InferSingleFor) {
auto context = ContextAutoSet();
Graph graph;
ASSERT_TRUE(Serialization::Load(kSingleForFile, ModelType::kMindIR, &graph));
Model control_model;
ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
// assert inputs
std::vector<MSTensor> inputs_before = control_model.GetInputs();
ASSERT_EQ(inputs_before.size(), 3);
EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32);
EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32);
EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32);
ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t));
ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t));
ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t));
ASSERT_EQ(inputs_before[0].Shape().size(), 1);
EXPECT_EQ(inputs_before[0].Shape()[0], 1);
ASSERT_EQ(inputs_before[1].Shape().size(), 1);
EXPECT_EQ(inputs_before[1].Shape()[0], 1);
ASSERT_EQ(inputs_before[2].Shape().size(), 1);
EXPECT_EQ(inputs_before[2].Shape()[0], 1);
// prepare input
std::vector<MSTensor> outputs;
std::vector<MSTensor> inputs;
{
int32_t x = 2, y = 5, z = 4;
inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
sizeof(int32_t));
inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
sizeof(int32_t));
inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &z,
sizeof(int32_t));
}
// infer
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
// assert output
ASSERT_TRUE(outputs.size() == 1);
auto out = outputs[0];
ASSERT_TRUE(out.DataSize() == sizeof(int32_t));
auto out_data = out.Data();
auto p = reinterpret_cast<const int32_t *>(out_data.get());
ASSERT_EQ(*p, 125);
}
TEST_F(TestControl, InferSingleOr) {
auto context = ContextAutoSet();
Graph graph;
ASSERT_TRUE(Serialization::Load(kSingleOrFile, ModelType::kMindIR, &graph));
Model control_model;
ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
// assert inputs
std::vector<MSTensor> inputs_before = control_model.GetInputs();
ASSERT_EQ(inputs_before.size(), 2);
EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeFloat32);
ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float) * 2);
ASSERT_EQ(inputs_before[1].DataSize(), sizeof(float) * 2);
ASSERT_EQ(inputs_before[0].Shape().size(), 1);
EXPECT_EQ(inputs_before[0].Shape()[0], 2);
ASSERT_EQ(inputs_before[1].Shape().size(), 1);
EXPECT_EQ(inputs_before[1].Shape()[0], 2);
// prepare input
std::vector<MSTensor> outputs;
std::vector<MSTensor> inputs;
{
static const std::vector<float> input_data1 = {0, 1};
static const std::vector<float> input_data2 = {0, 0};
inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(),
input_data1.data(), sizeof(float) * input_data1.size());
inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(),
input_data2.data(), sizeof(int32_t) * input_data2.size());
}
// infer
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
// assert output
ASSERT_TRUE(outputs.size() == 1);
auto out = outputs[0];
ASSERT_TRUE(out.DataSize() == sizeof(float));
auto out_data = out.Data();
auto p = reinterpret_cast<const float *>(out_data.get());
ASSERT_EQ(*p, 1);
}
TEST_F(TestControl, InferSingleSwitch) {
auto context = ContextAutoSet();
Graph graph;
ASSERT_TRUE(Serialization::Load(kSingleSwitchFile, ModelType::kMindIR, &graph));
Model control_model;
ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
// assert inputs
std::vector<MSTensor> inputs_before = control_model.GetInputs();
ASSERT_EQ(inputs_before.size(), 3);
EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32);
EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32);
ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float) * 224 * 224);
ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t));
ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t));
ASSERT_EQ(inputs_before[0].Shape().size(), 4);
EXPECT_EQ(inputs_before[0].Shape()[0], 1);
EXPECT_EQ(inputs_before[0].Shape()[1], 1);
EXPECT_EQ(inputs_before[0].Shape()[2], 224);
EXPECT_EQ(inputs_before[0].Shape()[3], 224);
ASSERT_EQ(inputs_before[1].Shape().size(), 1);
EXPECT_EQ(inputs_before[1].Shape()[0], 1);
ASSERT_EQ(inputs_before[2].Shape().size(), 1);
EXPECT_EQ(inputs_before[2].Shape()[0], 1);
// prepare input
std::vector<MSTensor> outputs;
std::vector<MSTensor> inputs;
{
static const std::vector<float> input_data1(1 * 1 * 224 * 224, 1);
int32_t index1 = 0;
int32_t index2 = -1;
inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(),
input_data1.data(), sizeof(float) * input_data1.size());
inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &index1,
sizeof(int32_t));
inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &index2,
sizeof(int32_t));
}
// infer
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
// assert output
ASSERT_TRUE(outputs.size() == 1);
auto out = outputs[0];
ASSERT_TRUE(out.DataSize() == sizeof(float) * 224 * 224);
auto out_data = out.Data();
auto p = reinterpret_cast<const float *>(out_data.get());
for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
ASSERT_EQ(p[i], 1);
}
}