forked from mindspore-Ecosystem/mindspore
!22915 310 support cond graph
Merge pull request !22915 from zhoufeng/310-support-cond-graph
This commit is contained in:
commit
516a74f985
|
@ -292,9 +292,6 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
|
||||||
/// \return The device id.
|
/// \return The device id.
|
||||||
uint32_t GetDeviceID() const;
|
uint32_t GetDeviceID() const;
|
||||||
|
|
||||||
inline void SetDumpConfigPath(const std::string &cfg_path);
|
|
||||||
inline std::string GetDumpConfigPath() const;
|
|
||||||
|
|
||||||
/// \brief Set AIPP configuration file path.
|
/// \brief Set AIPP configuration file path.
|
||||||
///
|
///
|
||||||
/// \param[in] cfg_path AIPP configuration file path.
|
/// \param[in] cfg_path AIPP configuration file path.
|
||||||
|
@ -379,9 +376,6 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
|
||||||
inline std::string GetBufferOptimizeMode() const;
|
inline std::string GetBufferOptimizeMode() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void SetDumpConfigPath(const std::vector<char> &cfg_path);
|
|
||||||
std::vector<char> GetDumpConfigPathChar() const;
|
|
||||||
|
|
||||||
void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
|
void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
|
||||||
std::vector<char> GetInsertOpConfigPathChar() const;
|
std::vector<char> GetInsertOpConfigPathChar() const;
|
||||||
|
|
||||||
|
@ -406,9 +400,6 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
|
||||||
std::vector<char> GetBufferOptimizeModeChar() const;
|
std::vector<char> GetBufferOptimizeModeChar() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetDumpConfigPath(const std::string &cfg_path) { SetDumpConfigPath(StringToChar(cfg_path)); }
|
|
||||||
std::string Ascend310DeviceInfo::GetDumpConfigPath() const { return CharToString(GetDumpConfigPathChar()); }
|
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
|
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
|
||||||
SetInsertOpConfigPath(StringToChar(cfg_path));
|
SetInsertOpConfigPath(StringToChar(cfg_path));
|
||||||
}
|
}
|
||||||
|
|
|
@ -1121,7 +1121,8 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
|
||||||
return new_parameter;
|
return new_parameter;
|
||||||
}
|
}
|
||||||
|
|
||||||
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
|
||||||
|
bool common_opt) {
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
|
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
|
||||||
auto graph = NewKernelGraph();
|
auto graph = NewKernelGraph();
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
@ -1161,7 +1162,9 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
|
||||||
// Update Graph Dynamic Shape Attr
|
// Update Graph Dynamic Shape Attr
|
||||||
UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
|
UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
|
||||||
UpdateGraphAquireGilAttr(NOT_NULL(graph));
|
UpdateGraphAquireGilAttr(NOT_NULL(graph));
|
||||||
opt::BackendCommonOptimization(graph);
|
if (common_opt) {
|
||||||
|
opt::BackendCommonOptimization(graph);
|
||||||
|
}
|
||||||
graph->SetInputNodes();
|
graph->SetInputNodes();
|
||||||
SetInputNodeUsage(graph, manager);
|
SetInputNodeUsage(graph, manager);
|
||||||
graph->SetOptimizerFlag();
|
graph->SetOptimizerFlag();
|
||||||
|
|
|
@ -116,7 +116,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
||||||
|
|
||||||
bool CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph);
|
bool CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph);
|
||||||
|
|
||||||
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
|
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
|
||||||
|
bool common_opt = true);
|
||||||
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
|
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
|
||||||
std::vector<KernelGraphPtr> *all_out_graph);
|
std::vector<KernelGraphPtr> *all_out_graph);
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,7 @@ if(ENABLE_D OR ENABLE_ACL)
|
||||||
|
|
||||||
if(NOT ENABLE_D)
|
if(NOT ENABLE_D)
|
||||||
list(APPEND API_ACL_SRC $<TARGET_OBJECTS:_mindspore_transform_graph_ir_obj>)
|
list(APPEND API_ACL_SRC $<TARGET_OBJECTS:_mindspore_transform_graph_ir_obj>)
|
||||||
|
list(APPEND API_ACL_SRC $<TARGET_OBJECTS:_mindspore_vm_obj>)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,6 @@ constexpr auto kModelOptionGPUTrtInferMode = "mindspore.option.gpu.trt_infer_mod
|
||||||
constexpr auto kModelOptionGPUPrecisionMode = "mindspore.option.gpu.precision_mode";
|
constexpr auto kModelOptionGPUPrecisionMode = "mindspore.option.gpu.precision_mode";
|
||||||
constexpr auto kModelOptionAscend910DeviceID = kModelOptionDeviceID;
|
constexpr auto kModelOptionAscend910DeviceID = kModelOptionDeviceID;
|
||||||
constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
|
constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
|
||||||
constexpr auto kModelOptionAscend310DumpCfgPath = "mindspore.option.ascend310.dump_config_file_path";
|
|
||||||
constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path";
|
constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path";
|
||||||
constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format";
|
constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format";
|
||||||
constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
|
constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
|
||||||
|
@ -193,16 +192,6 @@ uint32_t Ascend310DeviceInfo::GetDeviceID() const {
|
||||||
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
|
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetDumpConfigPath(const std::vector<char> &cfg_path) {
|
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
|
||||||
data_->params[kModelOptionAscend310DumpCfgPath] = CharToString(cfg_path);
|
|
||||||
}
|
|
||||||
std::vector<char> Ascend310DeviceInfo::GetDumpConfigPathChar() const {
|
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
|
||||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DumpCfgPath);
|
|
||||||
return StringToChar(ref);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
|
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
|
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
|
||||||
|
|
|
@ -21,8 +21,8 @@ namespace mindspore {
|
||||||
std::shared_ptr<AclEnvGuard> AclEnvGuard::global_acl_env_;
|
std::shared_ptr<AclEnvGuard> AclEnvGuard::global_acl_env_;
|
||||||
std::mutex AclEnvGuard::global_acl_env_mutex_;
|
std::mutex AclEnvGuard::global_acl_env_mutex_;
|
||||||
|
|
||||||
AclEnvGuard::AclEnvGuard(std::string_view cfg_file) {
|
AclEnvGuard::AclEnvGuard() {
|
||||||
errno_ = aclInit(cfg_file.data());
|
errno_ = aclInit(nullptr);
|
||||||
if (errno_ != ACL_ERROR_NONE && errno_ != ACL_ERROR_REPEAT_INITIALIZE) {
|
if (errno_ != ACL_ERROR_NONE && errno_ != ACL_ERROR_REPEAT_INITIALIZE) {
|
||||||
MS_LOG(ERROR) << "Execute aclInit Failed";
|
MS_LOG(ERROR) << "Execute aclInit Failed";
|
||||||
return;
|
return;
|
||||||
|
@ -38,18 +38,15 @@ AclEnvGuard::~AclEnvGuard() {
|
||||||
MS_LOG(INFO) << "Acl finalize success";
|
MS_LOG(INFO) << "Acl finalize success";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv(std::string_view cfg_file) {
|
std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv() {
|
||||||
std::shared_ptr<AclEnvGuard> acl_env;
|
std::shared_ptr<AclEnvGuard> acl_env;
|
||||||
|
|
||||||
std::lock_guard<std::mutex> lock(global_acl_env_mutex_);
|
std::lock_guard<std::mutex> lock(global_acl_env_mutex_);
|
||||||
acl_env = global_acl_env_;
|
acl_env = global_acl_env_;
|
||||||
if (acl_env != nullptr) {
|
if (acl_env != nullptr) {
|
||||||
MS_LOG(INFO) << "Acl has been initialized, skip.";
|
MS_LOG(INFO) << "Acl has been initialized, skip.";
|
||||||
if (!cfg_file.empty()) {
|
|
||||||
MS_LOG(WARNING) << "Dump config file option " << cfg_file << " is ignored.";
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
acl_env = std::make_shared<AclEnvGuard>(cfg_file);
|
acl_env = std::make_shared<AclEnvGuard>();
|
||||||
aclError ret = acl_env->GetErrno();
|
aclError ret = acl_env->GetErrno();
|
||||||
if (ret != ACL_ERROR_NONE && ret != ACL_ERROR_REPEAT_INITIALIZE) {
|
if (ret != ACL_ERROR_NONE && ret != ACL_ERROR_REPEAT_INITIALIZE) {
|
||||||
MS_LOG(ERROR) << "Execute aclInit Failed";
|
MS_LOG(ERROR) << "Execute aclInit Failed";
|
||||||
|
|
|
@ -23,10 +23,10 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class __attribute__((visibility("default"))) AclEnvGuard {
|
class __attribute__((visibility("default"))) AclEnvGuard {
|
||||||
public:
|
public:
|
||||||
explicit AclEnvGuard(std::string_view cfg_file);
|
explicit AclEnvGuard();
|
||||||
~AclEnvGuard();
|
~AclEnvGuard();
|
||||||
aclError GetErrno() const { return errno_; }
|
aclError GetErrno() const { return errno_; }
|
||||||
static std::shared_ptr<AclEnvGuard> GetAclEnv(std::string_view cfg_file);
|
static std::shared_ptr<AclEnvGuard> GetAclEnv();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static std::shared_ptr<AclEnvGuard> global_acl_env_;
|
static std::shared_ptr<AclEnvGuard> global_acl_env_;
|
||||||
|
|
|
@ -91,7 +91,7 @@ Status AclGraphImpl::InitEnv() {
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
acl_env_ = AclEnvGuard::GetAclEnv("");
|
acl_env_ = AclEnvGuard::GetAclEnv();
|
||||||
if (acl_env_ == nullptr) {
|
if (acl_env_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Acl init failed.";
|
MS_LOG(ERROR) << "Acl init failed.";
|
||||||
return kMCDeviceError;
|
return kMCDeviceError;
|
||||||
|
|
|
@ -165,7 +165,7 @@ Status ModelProcess::InitInputsBuffer() {
|
||||||
aclDataType data_type = aclmdlGetInputDataType(model_desc_, i);
|
aclDataType data_type = aclmdlGetInputDataType(model_desc_, i);
|
||||||
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
|
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
|
||||||
const char *input_name_char = aclmdlGetInputNameByIndex(model_desc_, i);
|
const char *input_name_char = aclmdlGetInputNameByIndex(model_desc_, i);
|
||||||
std::string input_name = (input_name_char == nullptr) ? input_name_char : std::string();
|
std::string input_name = (input_name_char != nullptr) ? input_name_char : std::string();
|
||||||
if (input_name.empty()) {
|
if (input_name.empty()) {
|
||||||
MS_LOG(WARNING) << "Get name of input " << i << " failed.";
|
MS_LOG(WARNING) << "Get name of input " << i << " failed.";
|
||||||
}
|
}
|
||||||
|
@ -249,7 +249,7 @@ Status ModelProcess::InitOutputsBuffer() {
|
||||||
aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i);
|
aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i);
|
||||||
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
|
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
|
||||||
const char *output_name_char = aclmdlGetOutputNameByIndex(model_desc_, i);
|
const char *output_name_char = aclmdlGetOutputNameByIndex(model_desc_, i);
|
||||||
std::string output_name = (output_name_char == nullptr) ? output_name_char : std::string();
|
std::string output_name = (output_name_char != nullptr) ? output_name_char : std::string();
|
||||||
if (output_name.empty()) {
|
if (output_name.empty()) {
|
||||||
MS_LOG(WARNING) << "Get name of output " << i << " failed.";
|
MS_LOG(WARNING) << "Get name of output " << i << " failed.";
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,8 +24,6 @@
|
||||||
#include "acl/acl_base.h"
|
#include "acl/acl_base.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
API_FACTORY_REG(ModelImpl, Ascend310, AclModel);
|
|
||||||
|
|
||||||
Status AclModel::Build() {
|
Status AclModel::Build() {
|
||||||
MS_LOG(INFO) << "Start build model.";
|
MS_LOG(INFO) << "Start build model.";
|
||||||
MS_EXCEPTION_IF_NULL(graph_);
|
MS_EXCEPTION_IF_NULL(graph_);
|
||||||
|
@ -42,13 +40,8 @@ Status AclModel::Build() {
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<AclModelOptions> options = std::make_unique<AclModelOptions>(model_context_);
|
std::shared_ptr<AclModelOptions> options = std::make_shared<AclModelOptions>(model_context_);
|
||||||
MS_EXCEPTION_IF_NULL(options);
|
MS_EXCEPTION_IF_NULL(options);
|
||||||
std::string dump_cfg = options->GetDumpCfgPath();
|
|
||||||
if (!dump_cfg.empty()) {
|
|
||||||
MS_LOG(INFO) << "Options dump config file path " << dump_cfg;
|
|
||||||
(void)AclEnvGuard::GetAclEnv(dump_cfg);
|
|
||||||
}
|
|
||||||
std::string options_key = options->GenAclOptionsKey();
|
std::string options_key = options->GenAclOptionsKey();
|
||||||
std::shared_ptr<Graph> graph;
|
std::shared_ptr<Graph> graph;
|
||||||
if (auto iter = dynamic_size_graph_map_.find(options_key); iter != dynamic_size_graph_map_.end()) {
|
if (auto iter = dynamic_size_graph_map_.find(options_key); iter != dynamic_size_graph_map_.end()) {
|
||||||
|
@ -72,7 +65,7 @@ Status AclModel::Build() {
|
||||||
}
|
}
|
||||||
options->RenameInput(input_names);
|
options->RenameInput(input_names);
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
model_converter_.set_options(options.get());
|
model_converter_.set_options(options);
|
||||||
auto om_data = model_converter_.LoadMindIR(func_graph);
|
auto om_data = model_converter_.LoadMindIR(func_graph);
|
||||||
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
|
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
|
||||||
MS_LOG(ERROR) << "Load MindIR failed.";
|
MS_LOG(ERROR) << "Load MindIR failed.";
|
||||||
|
|
|
@ -47,7 +47,7 @@ class AclModel : public ModelImpl {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ModelConverter model_converter_;
|
ModelConverter model_converter_;
|
||||||
std::unique_ptr<AclModelOptions> options_;
|
std::shared_ptr<AclModelOptions> options_;
|
||||||
std::map<std::string, std::shared_ptr<Graph>> dynamic_size_graph_map_;
|
std::map<std::string, std::shared_ptr<Graph>> dynamic_size_graph_map_;
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,475 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cxx_api/model/acl/acl_model_multi.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "backend/session/session_basic.h"
|
||||||
|
#include "backend/session/session_factory.h"
|
||||||
|
#include "cxx_api/factory.h"
|
||||||
|
#include "vm/backend.h"
|
||||||
|
#include "vm/transform.h"
|
||||||
|
#include "acl/acl_rt.h"
|
||||||
|
#include "mindspore/core/load_mindir/infer_mindir.h"
|
||||||
|
#include "debug/trace.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
API_FACTORY_REG(ModelImpl, Ascend310, AclModelMulti);
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class MSTensorRef : public BaseRef {
|
||||||
|
public:
|
||||||
|
static VectorRef Convert(const std::vector<MSTensor> &tensors) {
|
||||||
|
VectorRef res;
|
||||||
|
std::transform(tensors.begin(), tensors.end(), std::back_inserter(res),
|
||||||
|
[](const MSTensor &t) { return MSTensorRef(t); });
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::vector<MSTensor> Convert(const BaseRef &args) {
|
||||||
|
std::vector<MSTensor> res;
|
||||||
|
if (utils::isa<VectorRef>(args)) {
|
||||||
|
VectorRef args_vec = utils::cast<VectorRef>(args);
|
||||||
|
for (size_t i = 0; i < args_vec.size(); ++i) {
|
||||||
|
const auto &item = args_vec[i];
|
||||||
|
if (!utils::isa<MSTensorRef>(item)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid item " << item.ToString() << " at index " << i;
|
||||||
|
}
|
||||||
|
auto wrapper = utils::cast<MSTensorRef>(item);
|
||||||
|
res.push_back(wrapper.ms_tensor_);
|
||||||
|
}
|
||||||
|
} else if (utils::isa<MSTensorRef>(args)) {
|
||||||
|
auto wrapper = utils::cast<MSTensorRef>(args);
|
||||||
|
res.push_back(wrapper.ms_tensor_);
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid BaseRef " << args.ToString() << " must be MSTensorRef or VectorRef{MSTensorRef...}";
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_DECLARE_PARENT(MSTensorRef, BaseRef);
|
||||||
|
explicit MSTensorRef(const MSTensor &tensor) : ms_tensor_(tensor) {}
|
||||||
|
~MSTensorRef() override = default;
|
||||||
|
|
||||||
|
const MSTensor &GetTensor() const { return ms_tensor_; }
|
||||||
|
std::shared_ptr<Base> copy() const override {
|
||||||
|
MSTensor *tensor = ms_tensor_.Clone();
|
||||||
|
auto res = std::make_shared<MSTensorRef>(static_cast<const MSTensor &>(*tensor));
|
||||||
|
MSTensor::DestroyTensorPtr(tensor);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t type() const override { return tid(); }
|
||||||
|
std::string ToString() const override { return ms_tensor_.Name(); }
|
||||||
|
bool operator==(const BaseRef &other) const override {
|
||||||
|
if (!utils::isa<MSTensorRef>(other)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return *this == utils::cast<MSTensorRef>(other);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator==(MSTensorRef &other) {
|
||||||
|
return (ms_tensor_.Name() == other.ms_tensor_.Name()) && (ms_tensor_.Shape() == other.ms_tensor_.Shape()) &&
|
||||||
|
(ms_tensor_.MutableData() == other.ms_tensor_.MutableData()) &&
|
||||||
|
(ms_tensor_.DataSize() == other.ms_tensor_.DataSize()) &&
|
||||||
|
(ms_tensor_.DataType() == other.ms_tensor_.DataType());
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
MSTensor ms_tensor_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MultiGraphAclSession : public session::SessionBasic {
|
||||||
|
public:
|
||||||
|
MultiGraphAclSession() = default;
|
||||||
|
~MultiGraphAclSession() override = default;
|
||||||
|
void Init(uint32_t device_id) override;
|
||||||
|
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
|
||||||
|
void RunGraph(GraphId graph_id, const std::vector<MSTensor> &inputs, VectorRef *outputs);
|
||||||
|
void SetOptions(const std::shared_ptr<AclModelOptions> &options) { options_ = options; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::map<GraphId, GraphCell> graphs_ = {};
|
||||||
|
std::shared_ptr<AclModelOptions> options_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
void MultiGraphAclSession::Init(uint32_t device_id) { InitExecutor(kDavinciMultiGraphInferenceDevice, device_id); }
|
||||||
|
|
||||||
|
GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||||
|
class FirstGraphModeGuard {
|
||||||
|
public:
|
||||||
|
explicit FirstGraphModeGuard(const std::shared_ptr<AclModelOptions> &options) : options_(options) {
|
||||||
|
if (options_ != nullptr) {
|
||||||
|
options_->SetFirstGraph(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
~FirstGraphModeGuard() {
|
||||||
|
if (options_ != nullptr) {
|
||||||
|
options_->SetFirstGraph(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<AclModelOptions> options_;
|
||||||
|
};
|
||||||
|
MS_LOG(INFO) << "Start MultiGraph Compile.";
|
||||||
|
auto kernel_graph = ConstructKernelGraph(lst, outputs, false);
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
|
ModelConverter model_converter_;
|
||||||
|
model_converter_.set_options(options_);
|
||||||
|
FirstGraphModeGuard guard(options_);
|
||||||
|
auto om_data = model_converter_.LoadMindIR(kernel_graph);
|
||||||
|
if (om_data.Data() == nullptr || om_data.DataSize() == 0) {
|
||||||
|
MS_LOG(ERROR) << "Load MindIR failed.";
|
||||||
|
return kMCFailed;
|
||||||
|
}
|
||||||
|
std::shared_ptr<Graph> graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
graphs_[kernel_graph->graph_id()] = GraphCell(graph);
|
||||||
|
MS_LOG(INFO) << "Mulit graph compile success, graph id " << kernel_graph->graph_id();
|
||||||
|
return kernel_graph->graph_id();
|
||||||
|
}
|
||||||
|
|
||||||
|
void MultiGraphAclSession::RunGraph(GraphId graph_id, const std::vector<MSTensor> &inputs, VectorRef *outputs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(outputs);
|
||||||
|
MS_LOG(INFO) << "Start run graph " << graph_id;
|
||||||
|
auto iter = graphs_.find(graph_id);
|
||||||
|
if (iter == graphs_.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " not found.";
|
||||||
|
}
|
||||||
|
std::vector<MSTensor> out_tensors;
|
||||||
|
auto ret = iter->second.Run(inputs, &out_tensors);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " run failed.";
|
||||||
|
}
|
||||||
|
(*outputs) = MSTensorRef::Convert(out_tensors);
|
||||||
|
}
|
||||||
|
|
||||||
|
class AclBackend : public compile::MsBackend {
|
||||||
|
public:
|
||||||
|
AclBackend(const std::string &name, const std::string &target, const std::shared_ptr<AclModelOptions> &options)
|
||||||
|
: MsBackend(name, target, options->GetDeviceID()) {
|
||||||
|
auto session = std::dynamic_pointer_cast<MultiGraphAclSession>(MsBackend::target_sess_);
|
||||||
|
MS_EXCEPTION_IF_NULL(session);
|
||||||
|
session->SetOptions(options);
|
||||||
|
}
|
||||||
|
|
||||||
|
~AclBackend() override = default;
|
||||||
|
|
||||||
|
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) override {
|
||||||
|
std::vector<MSTensor> inputs;
|
||||||
|
for (const auto &arg : args) {
|
||||||
|
if (!utils::isa<MSTensorRef>(arg)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid item " << arg.ToString();
|
||||||
|
}
|
||||||
|
auto wrapper = utils::cast<MSTensorRef>(arg);
|
||||||
|
inputs.emplace_back(wrapper.GetTensor());
|
||||||
|
}
|
||||||
|
|
||||||
|
VectorRef outputs;
|
||||||
|
MS_EXCEPTION_IF_NULL(target_sess_);
|
||||||
|
auto exec_sess = std::dynamic_pointer_cast<MultiGraphAclSession>(target_sess_);
|
||||||
|
MS_EXCEPTION_IF_NULL(exec_sess);
|
||||||
|
exec_sess->RunGraph(g, inputs, &outputs);
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GetCond(const BaseRef &c, bool *value) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
|
if (!utils::isa<MSTensorRef>(c)) {
|
||||||
|
MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto wrapper = utils::cast<MSTensorRef>(c);
|
||||||
|
if (wrapper.GetTensor().DataType() != DataType::kNumberTypeBool) {
|
||||||
|
MS_LOG(ERROR) << "Invalid data type " << wrapper.GetTensor().DataType() << " must be bool.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto data = wrapper.GetTensor().Data();
|
||||||
|
if (data == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
(*value) = *reinterpret_cast<const bool *>(data.get());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GetIndex(const BaseRef &c, int64_t *value) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
|
if (!utils::isa<MSTensorRef>(c)) {
|
||||||
|
MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto wrapper = utils::cast<MSTensorRef>(c);
|
||||||
|
if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt32) {
|
||||||
|
auto data = wrapper.GetTensor().Data();
|
||||||
|
if (data == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto value_int32 = *reinterpret_cast<const int32_t *>(data.get());
|
||||||
|
(*value) = static_cast<int64_t>(value_int32);
|
||||||
|
return true;
|
||||||
|
} else if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt64) {
|
||||||
|
auto data = wrapper.GetTensor().Data();
|
||||||
|
if (data == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
(*value) = *reinterpret_cast<const int64_t *>(data.get());
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Index must be Int type.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class AclCompileGraph : public compile::CompileGraph {
|
||||||
|
public:
|
||||||
|
explicit AclCompileGraph(const std::shared_ptr<compile::MsBackend> &backend,
|
||||||
|
const std::vector<PrimitivePtr> &cut_list)
|
||||||
|
: CompileGraph(backend, cut_list) {}
|
||||||
|
~AclCompileGraph() override = default;
|
||||||
|
|
||||||
|
void AddInst(const compile::Instruction &inst, const MSTensorRef &arg) {
|
||||||
|
VectorRef args;
|
||||||
|
args.push_back(arg);
|
||||||
|
compile::CompileGraph::AddInst(inst, args);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Ref(const AnfNodePtr &node) override {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
|
||||||
|
if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
|
||||||
|
if (IsValueNode<FuncGraph>(node)) {
|
||||||
|
MS_LOG(DEBUG) << "Push graph.";
|
||||||
|
compile::CompileGraph::AddInst(compile::Instruction::kGraph, GetValueNode(node));
|
||||||
|
} else {
|
||||||
|
MS_LOG(DEBUG) << "Push.";
|
||||||
|
if (IsValueNode<Primitive>(node)) {
|
||||||
|
MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||||
|
} else if (IsValueNode<tensor::Tensor>(node)) {
|
||||||
|
auto tensor_node = std::dynamic_pointer_cast<tensor::Tensor>(node->cast<ValueNodePtr>()->value());
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_node);
|
||||||
|
std::string name = "";
|
||||||
|
std::vector<int64_t> shape = tensor_node->shape_c();
|
||||||
|
DataType type = static_cast<DataType>(tensor_node->data_type_c());
|
||||||
|
auto mstensor_node = MSTensor::CreateRefTensor(name, type, shape, tensor_node->data_c(), tensor_node->Size());
|
||||||
|
MSTensorRef mstensor_ref(*mstensor_node);
|
||||||
|
AddInst(compile::Instruction::kPush, mstensor_ref);
|
||||||
|
MSTensor::DestroyTensorPtr(mstensor_node);
|
||||||
|
} else {
|
||||||
|
compile::CompileGraph::AddInst(compile::Instruction::kPush, GetValueNode(node));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Push(node);
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
|
||||||
|
<< ", return: " << slots_[node] - height_;
|
||||||
|
return slots_[node] - height_;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class AclCompileGraphs : public compile::CompileGraphs {
|
||||||
|
public:
|
||||||
|
explicit AclCompileGraphs(const std::shared_ptr<compile::MsBackend> &backend,
|
||||||
|
const std::vector<PrimitivePtr> &cut_list)
|
||||||
|
: CompileGraphs(backend, cut_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(backend);
|
||||||
|
MS_LOG(DEBUG) << "Start vm: " << backend->name();
|
||||||
|
transform_ = std::make_shared<AclCompileGraph>(backend, cut_list);
|
||||||
|
Reset();
|
||||||
|
}
|
||||||
|
~AclCompileGraphs() override = default;
|
||||||
|
void Compile(const FuncGraphPtr &graph) override {
|
||||||
|
MS_LOG(DEBUG) << "Start";
|
||||||
|
mapping_[graph] = SizeToLong(insts_.size());
|
||||||
|
if (transform_ != nullptr) {
|
||||||
|
auto insts = transform_->Run(graph, false);
|
||||||
|
if (!insts.empty()) {
|
||||||
|
(void)insts_.insert(insts_.end(), insts.begin(), insts.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "End";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::shared_ptr<compile::MsBackend> CreateBackend(const std::shared_ptr<AclModelOptions> &options) {
|
||||||
|
MS_EXCEPTION_IF_NULL(options);
|
||||||
|
return std::make_shared<AclBackend>(kMsConvert, kDavinciMultiGraphInferenceDevice, options);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HasMultiGraph(const FuncGraphPtr &fg) {
|
||||||
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
|
std::vector<AnfNodePtr> all_nodes = TopoSort(fg->get_return());
|
||||||
|
for (const auto &node : all_nodes) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (IsValueNode<FuncGraph>(node)) {
|
||||||
|
MS_LOG(INFO) << fg->ToString() << " has FuncGraph node " << node->DebugString() << " is multi graph.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
Status AclModelMulti::Build() {
|
||||||
|
if (!is_multi_graph_.has_value()) {
|
||||||
|
is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is_multi_graph_.value()) {
|
||||||
|
return AclModel::Build();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vm_ != nullptr) {
|
||||||
|
MS_LOG(INFO) << "Multi graph model has been built, skip.";
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Start build multi graph model.";
|
||||||
|
// perpare func graph
|
||||||
|
auto manager = MakeManager();
|
||||||
|
manager->AddFuncGraph(ModelImpl::GetFuncGraph());
|
||||||
|
ModelImpl::GetFuncGraph()->set_manager(manager);
|
||||||
|
// set inputs
|
||||||
|
SetInputs();
|
||||||
|
// infer mindir
|
||||||
|
abstract::AbstractBasePtrList broaded_args;
|
||||||
|
auto fg = ModelImpl::GetFuncGraph();
|
||||||
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
|
const auto &inputs = fg->get_inputs();
|
||||||
|
(void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(broaded_args),
|
||||||
|
[](const AnfNodePtr &n) -> AbstractBasePtr {
|
||||||
|
MS_EXCEPTION_IF_NULL(n);
|
||||||
|
auto abstract = n->abstract();
|
||||||
|
MS_EXCEPTION_IF_NULL(abstract);
|
||||||
|
if (abstract->GetValueTrack() != kAnyValue) {
|
||||||
|
return abstract->Broaden();
|
||||||
|
}
|
||||||
|
return abstract;
|
||||||
|
});
|
||||||
|
(void)InferMindir(ModelImpl::GetFuncGraph(), broaded_args);
|
||||||
|
// create vm
|
||||||
|
auto backend = CreateBackend(std::make_shared<AclModelOptions>(model_context_));
|
||||||
|
auto context_ptr = MsContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||||
|
backend->set_is_multi_graph_sink(false);
|
||||||
|
context_ptr->set_param<std::string>(MS_CTX_DEVICE_TARGET, kDavinciMultiGraphInferenceDevice);
|
||||||
|
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
|
||||||
|
context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
|
||||||
|
auto compile = std::make_shared<AclCompileGraphs>(backend, compile::GetMsNonlinearOps());
|
||||||
|
|
||||||
|
vm_ = compile->CompileAndLink(ModelImpl::GetFuncGraph());
|
||||||
|
backend_ = std::move(backend);
|
||||||
|
MS_LOG(INFO) << "Build multi graph model success.";
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status AclModelMulti::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
||||||
|
if (!is_multi_graph_.has_value()) {
|
||||||
|
is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is_multi_graph_.value()) {
|
||||||
|
return AclModel::Predict(inputs, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
Build();
|
||||||
|
MS_LOG(INFO) << "Start predict multi graph model.";
|
||||||
|
MS_EXCEPTION_IF_NULL(vm_);
|
||||||
|
MS_EXCEPTION_IF_NULL(outputs);
|
||||||
|
try {
|
||||||
|
(*outputs) = MSTensorRef::Convert(vm_->Eval(MSTensorRef::Convert(inputs)));
|
||||||
|
} catch (const std::exception &ex) {
|
||||||
|
MS_LOG(ERROR) << "Predict Failed, error: " << ex.what();
|
||||||
|
return kMCFailed;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inputs.size() != inputs_.size() && !inputs_.empty() != 0) {
|
||||||
|
MS_LOG(ERROR) << "Input Size is wrong.";
|
||||||
|
return kMCFailed;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inputs_.empty()) {
|
||||||
|
inputs_ = inputs;
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < inputs_.size(); ++i) {
|
||||||
|
auto input_tensor = MSTensor::CreateTensor(inputs_[i].Name(), inputs_[i].DataType(), inputs_[i].Shape(),
|
||||||
|
inputs[i].Data().get(), inputs[i].DataSize());
|
||||||
|
inputs_[i] = (*input_tensor);
|
||||||
|
MSTensor::DestroyTensorPtr(input_tensor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
outputs_ = *outputs;
|
||||||
|
MS_LOG(INFO) << "Predict multi graph model success.";
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AclModelMulti::SetInputs() {
|
||||||
|
if (inputs_.empty()) {
|
||||||
|
auto fg = ModelImpl::GetFuncGraph();
|
||||||
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
|
const auto &inputs = fg->get_inputs();
|
||||||
|
for (const auto &in : inputs) {
|
||||||
|
auto input_param = std::dynamic_pointer_cast<Parameter>(in);
|
||||||
|
MS_EXCEPTION_IF_NULL(input_param);
|
||||||
|
MS_EXCEPTION_IF_NULL(input_param->abstract());
|
||||||
|
auto input_value = input_param->abstract()->GetValueTrack();
|
||||||
|
auto tensor = input_value->cast<tensor::TensorPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor);
|
||||||
|
|
||||||
|
std::vector<int64_t> shape = tensor->shape_c();
|
||||||
|
auto input_tensor = MSTensor::CreateTensor(input_param->name(), static_cast<DataType>(tensor->data_type_c()),
|
||||||
|
shape, nullptr, tensor->Size());
|
||||||
|
inputs_.emplace_back(*input_tensor);
|
||||||
|
MSTensor::DestroyTensorPtr(input_tensor);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(DEBUG) << "inputs_ has been set.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<MSTensor> AclModelMulti::GetInputs() {
|
||||||
|
if (!is_multi_graph_.has_value()) {
|
||||||
|
is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is_multi_graph_.value()) {
|
||||||
|
return AclModel::GetInputs();
|
||||||
|
}
|
||||||
|
|
||||||
|
return inputs_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<MSTensor> AclModelMulti::GetOutputs() {
|
||||||
|
if (!is_multi_graph_.has_value()) {
|
||||||
|
is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is_multi_graph_.value()) {
|
||||||
|
return AclModel::GetOutputs();
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs_;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace session {
|
||||||
|
MS_REG_SESSION(kDavinciMultiGraphInferenceDevice, MultiGraphAclSession);
|
||||||
|
} // namespace session
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_CXX_API_ACL_MODEL_MULTI_H
|
||||||
|
#define MINDSPORE_CCSRC_CXX_API_ACL_MODEL_MULTI_H
|
||||||
|
|
||||||
|
#include "cxx_api/model/acl/acl_model.h"
|
||||||
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace compile {
|
||||||
|
class MsBackend;
|
||||||
|
class FinalVM;
|
||||||
|
} // namespace compile
|
||||||
|
|
||||||
|
class AclModelMulti : public AclModel {
|
||||||
|
public:
|
||||||
|
AclModelMulti() : AclModel(), is_multi_graph_(std::nullopt) {}
|
||||||
|
~AclModelMulti() = default;
|
||||||
|
|
||||||
|
Status Build() override;
|
||||||
|
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
|
||||||
|
|
||||||
|
std::vector<MSTensor> GetInputs() override;
|
||||||
|
std::vector<MSTensor> GetOutputs() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void SetInputs();
|
||||||
|
|
||||||
|
std::optional<bool> is_multi_graph_;
|
||||||
|
std::shared_ptr<compile::MsBackend> backend_;
|
||||||
|
std::shared_ptr<compile::FinalVM> vm_;
|
||||||
|
std::vector<MSTensor> inputs_ = {};
|
||||||
|
std::vector<MSTensor> outputs_ = {};
|
||||||
|
};
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_MULTI_H
|
|
@ -14,6 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "cxx_api/model/acl/acl_model_options.h"
|
#include "cxx_api/model/acl/acl_model_options.h"
|
||||||
|
#include <set>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "external/ge/ge_api_types.h"
|
#include "external/ge/ge_api_types.h"
|
||||||
|
@ -54,7 +55,6 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
|
||||||
op_select_impl_mode_ = ascend310_info->GetOpSelectImplMode();
|
op_select_impl_mode_ = ascend310_info->GetOpSelectImplMode();
|
||||||
fusion_switch_cfg_path_ = ascend310_info->GetFusionSwitchConfigPath();
|
fusion_switch_cfg_path_ = ascend310_info->GetFusionSwitchConfigPath();
|
||||||
device_id_ = ascend310_info->GetDeviceID();
|
device_id_ = ascend310_info->GetDeviceID();
|
||||||
dump_cfg_path_ = ascend310_info->GetDumpConfigPath();
|
|
||||||
buffer_optimize_mode_ = ascend310_info->GetBufferOptimizeMode();
|
buffer_optimize_mode_ = ascend310_info->GetBufferOptimizeMode();
|
||||||
const char *soc_name = aclrtGetSocName();
|
const char *soc_name = aclrtGetSocName();
|
||||||
if (soc_name == nullptr) {
|
if (soc_name == nullptr) {
|
||||||
|
@ -98,6 +98,14 @@ std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string
|
||||||
{&dynamic_batch_size_, ge::ir_option::DYNAMIC_BATCH_SIZE},
|
{&dynamic_batch_size_, ge::ir_option::DYNAMIC_BATCH_SIZE},
|
||||||
{&dynamic_image_size_, ge::ir_option::DYNAMIC_IMAGE_SIZE}};
|
{&dynamic_image_size_, ge::ir_option::DYNAMIC_IMAGE_SIZE}};
|
||||||
|
|
||||||
|
const std::set<std::string> first_graph_options = {
|
||||||
|
ge::ir_option::INSERT_OP_FILE,
|
||||||
|
ge::ir_option::INPUT_FORMAT,
|
||||||
|
ge::ir_option::INPUT_SHAPE,
|
||||||
|
};
|
||||||
|
|
||||||
|
const std::set<std::string> multi_graph_unsupported_options = {ge::ir_option::OUTPUT_TYPE};
|
||||||
|
|
||||||
std::map<std::string, std::string> init_options;
|
std::map<std::string, std::string> init_options;
|
||||||
std::map<std::string, std::string> build_options;
|
std::map<std::string, std::string> build_options;
|
||||||
for (auto [ms_option, acl_option_key] : init_options_map) {
|
for (auto [ms_option, acl_option_key] : init_options_map) {
|
||||||
|
@ -115,6 +123,20 @@ std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string
|
||||||
MS_LOG(INFO) << "Option " << acl_option_key << " : " << *ms_option;
|
MS_LOG(INFO) << "Option " << acl_option_key << " : " << *ms_option;
|
||||||
build_options.emplace(acl_option_key, *ms_option);
|
build_options.emplace(acl_option_key, *ms_option);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// first_graph_flag has value means being multi graph mode
|
||||||
|
if (first_graph_flag_.has_value()) {
|
||||||
|
for (const auto &option : multi_graph_unsupported_options) {
|
||||||
|
build_options.erase(option);
|
||||||
|
}
|
||||||
|
// non-input graph
|
||||||
|
if (!first_graph_flag_) {
|
||||||
|
for (const auto &option : first_graph_options) {
|
||||||
|
build_options.erase(option);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return {init_options, build_options};
|
return {init_options, build_options};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
#include "include/api/types.h"
|
#include "include/api/types.h"
|
||||||
#include "include/api/status.h"
|
#include "include/api/status.h"
|
||||||
#include "include/api/context.h"
|
#include "include/api/context.h"
|
||||||
|
@ -32,11 +33,11 @@ class AclModelOptions {
|
||||||
~AclModelOptions() = default;
|
~AclModelOptions() = default;
|
||||||
std::string GenAclOptionsKey() const;
|
std::string GenAclOptionsKey() const;
|
||||||
uint32_t GetDeviceID() const { return device_id_; }
|
uint32_t GetDeviceID() const { return device_id_; }
|
||||||
std::string GetDumpCfgPath() const { return dump_cfg_path_; }
|
void RenameInput(const std::vector<std::string> &);
|
||||||
void RenameInput(const std::vector<std::string> &name);
|
|
||||||
|
|
||||||
// return tuple<init_options, build_options>
|
// return tuple<init_options, build_options>
|
||||||
std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> GenAclOptions() const;
|
std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> GenAclOptions() const;
|
||||||
|
void SetFirstGraph(bool is_first_graph) { first_graph_flag_ = is_first_graph; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string output_node_; // todo: at convert.cc::BuildGraph(), no atc options
|
std::string output_node_; // todo: at convert.cc::BuildGraph(), no atc options
|
||||||
|
@ -55,7 +56,7 @@ class AclModelOptions {
|
||||||
std::map<int, std::vector<int>> input_shape_map_;
|
std::map<int, std::vector<int>> input_shape_map_;
|
||||||
// other options
|
// other options
|
||||||
uint32_t device_id_;
|
uint32_t device_id_;
|
||||||
std::string dump_cfg_path_;
|
std::optional<bool> first_graph_flag_;
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -226,8 +226,9 @@ Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) {
|
||||||
|
|
||||||
std::map<std::string, std::string> init_options;
|
std::map<std::string, std::string> init_options;
|
||||||
std::map<std::string, std::string> build_options;
|
std::map<std::string, std::string> build_options;
|
||||||
if (options_ != nullptr) {
|
auto option = options_.lock();
|
||||||
std::tie(init_options, build_options) = options_->GenAclOptions();
|
if (option != nullptr) {
|
||||||
|
std::tie(init_options, build_options) = option->GenAclOptions();
|
||||||
}
|
}
|
||||||
|
|
||||||
return BuildAirModel(df_graph, init_options, build_options);
|
return BuildAirModel(df_graph, init_options, build_options);
|
||||||
|
|
|
@ -30,12 +30,12 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class ModelConverter {
|
class ModelConverter {
|
||||||
public:
|
public:
|
||||||
ModelConverter() : options_(nullptr) {}
|
ModelConverter() : options_() {}
|
||||||
~ModelConverter() = default;
|
~ModelConverter() = default;
|
||||||
|
|
||||||
Buffer LoadMindIR(const FuncGraphPtr &func_graph);
|
Buffer LoadMindIR(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
void set_options(AclModelOptions *options) { options_ = options; }
|
void set_options(const std::weak_ptr<AclModelOptions> &options) { options_ = options; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
transform::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph);
|
transform::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph);
|
||||||
|
@ -43,7 +43,7 @@ class ModelConverter {
|
||||||
const std::map<std::string, std::string> &build_options);
|
const std::map<std::string, std::string> &build_options);
|
||||||
Buffer LoadAscendIRInner(const Buffer &model_data);
|
Buffer LoadAscendIRInner(const Buffer &model_data);
|
||||||
|
|
||||||
AclModelOptions *options_;
|
std::weak_ptr<AclModelOptions> options_;
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H
|
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H
|
||||||
|
|
|
@ -24,7 +24,6 @@ namespace dataset {
|
||||||
|
|
||||||
Status AscendResource::InitResource(uint32_t device_id) {
|
Status AscendResource::InitResource(uint32_t device_id) {
|
||||||
ResourceInfo resource;
|
ResourceInfo resource;
|
||||||
resource.aclConfigPath = "";
|
|
||||||
resource.deviceIds.insert(device_id);
|
resource.deviceIds.insert(device_id);
|
||||||
ascend_resource_ = ResourceManager::GetInstance();
|
ascend_resource_ = ResourceManager::GetInstance();
|
||||||
APP_ERROR ret = ascend_resource_->InitResource(resource);
|
APP_ERROR ret = ascend_resource_->InitResource(resource);
|
||||||
|
|
|
@ -77,7 +77,6 @@ Status DvppCropJpegOp::Compute(const std::shared_ptr<Tensor> &input, std::shared
|
||||||
imageinfo.heightStride = yuv_shape_[3];
|
imageinfo.heightStride = yuv_shape_[3];
|
||||||
imageinfo.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
|
imageinfo.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
|
||||||
ResourceInfo resource;
|
ResourceInfo resource;
|
||||||
resource.aclConfigPath = "";
|
|
||||||
resource.deviceIds.insert(0);
|
resource.deviceIds.insert(0);
|
||||||
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
||||||
APP_ERROR ret = instance->InitResource(resource);
|
APP_ERROR ret = instance->InitResource(resource);
|
||||||
|
|
|
@ -70,7 +70,6 @@ Status DvppDecodeJpegOp::Compute(const std::shared_ptr<Tensor> &input, std::shar
|
||||||
imageInfo.lenOfByte = filesize;
|
imageInfo.lenOfByte = filesize;
|
||||||
imageInfo.data = static_cast<void *>(buffer);
|
imageInfo.data = static_cast<void *>(buffer);
|
||||||
ResourceInfo resource;
|
ResourceInfo resource;
|
||||||
resource.aclConfigPath = "";
|
|
||||||
resource.deviceIds.insert(0);
|
resource.deviceIds.insert(0);
|
||||||
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
||||||
APP_ERROR ret = instance->InitResource(resource);
|
APP_ERROR ret = instance->InitResource(resource);
|
||||||
|
|
|
@ -68,7 +68,6 @@ Status DvppDecodePngOp::Compute(const std::shared_ptr<Tensor> &input, std::share
|
||||||
imageInfo.lenOfByte = filesize;
|
imageInfo.lenOfByte = filesize;
|
||||||
imageInfo.data = static_cast<void *>(buffer);
|
imageInfo.data = static_cast<void *>(buffer);
|
||||||
ResourceInfo resource;
|
ResourceInfo resource;
|
||||||
resource.aclConfigPath = "";
|
|
||||||
resource.deviceIds.insert(0);
|
resource.deviceIds.insert(0);
|
||||||
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
||||||
APP_ERROR ret = instance->InitResource(resource);
|
APP_ERROR ret = instance->InitResource(resource);
|
||||||
|
|
|
@ -69,7 +69,6 @@ Status DvppDecodeResizeCropJpegOp::Compute(const std::shared_ptr<Tensor> &input,
|
||||||
imageInfo.lenOfByte = filesize;
|
imageInfo.lenOfByte = filesize;
|
||||||
imageInfo.data = static_cast<void *>(buffer);
|
imageInfo.data = static_cast<void *>(buffer);
|
||||||
ResourceInfo resource;
|
ResourceInfo resource;
|
||||||
resource.aclConfigPath = "";
|
|
||||||
resource.deviceIds.insert(0);
|
resource.deviceIds.insert(0);
|
||||||
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
||||||
APP_ERROR ret = instance->InitResource(resource);
|
APP_ERROR ret = instance->InitResource(resource);
|
||||||
|
|
|
@ -68,7 +68,6 @@ Status DvppDecodeResizeJpegOp::Compute(const std::shared_ptr<Tensor> &input, std
|
||||||
imageInfo.lenOfByte = filesize;
|
imageInfo.lenOfByte = filesize;
|
||||||
imageInfo.data = static_cast<void *>(buffer);
|
imageInfo.data = static_cast<void *>(buffer);
|
||||||
ResourceInfo resource;
|
ResourceInfo resource;
|
||||||
resource.aclConfigPath = "";
|
|
||||||
resource.deviceIds.insert(0);
|
resource.deviceIds.insert(0);
|
||||||
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
||||||
APP_ERROR ret = instance->InitResource(resource);
|
APP_ERROR ret = instance->InitResource(resource);
|
||||||
|
|
|
@ -78,7 +78,6 @@ Status DvppResizeJpegOp::Compute(const std::shared_ptr<Tensor> &input, std::shar
|
||||||
imageinfo.heightStride = yuv_shape_[3];
|
imageinfo.heightStride = yuv_shape_[3];
|
||||||
imageinfo.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
|
imageinfo.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
|
||||||
ResourceInfo resource;
|
ResourceInfo resource;
|
||||||
resource.aclConfigPath = "";
|
|
||||||
resource.deviceIds.insert(0);
|
resource.deviceIds.insert(0);
|
||||||
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
||||||
APP_ERROR ret = instance->InitResource(resource);
|
APP_ERROR ret = instance->InitResource(resource);
|
||||||
|
|
|
@ -85,22 +85,11 @@ APP_ERROR ResourceManager::InitResource(ResourceInfo &resourceInfo) {
|
||||||
MS_LOG(INFO) << "Acl has been initialized, skip.";
|
MS_LOG(INFO) << "Acl has been initialized, skip.";
|
||||||
return APP_ERR_OK;
|
return APP_ERR_OK;
|
||||||
}
|
}
|
||||||
std::string &aclConfigPath = resourceInfo.aclConfigPath;
|
|
||||||
APP_ERROR ret = APP_ERR_OK;
|
APP_ERROR ret = APP_ERR_OK;
|
||||||
if (aclConfigPath.length() == 0) {
|
acl_env_ = mindspore::AclEnvGuard::GetAclEnv();
|
||||||
// Init acl without aclconfig
|
|
||||||
acl_env_ = mindspore::AclEnvGuard::GetAclEnv("");
|
|
||||||
} else {
|
|
||||||
ret = ExistFile(aclConfigPath);
|
|
||||||
if (ret != APP_ERR_OK) {
|
|
||||||
MS_LOG(ERROR) << "Acl config file not exist, ret = " << ret << ".";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
acl_env_ = mindspore::AclEnvGuard::GetAclEnv(aclConfigPath);
|
|
||||||
}
|
|
||||||
if (acl_env_ == nullptr) {
|
if (acl_env_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Failed to init acl.";
|
MS_LOG(ERROR) << "Failed to init acl.";
|
||||||
return ret;
|
return APP_ERR_COMM_FAILURE;
|
||||||
}
|
}
|
||||||
std::copy(resourceInfo.deviceIds.begin(), resourceInfo.deviceIds.end(), std::back_inserter(deviceIds_));
|
std::copy(resourceInfo.deviceIds.begin(), resourceInfo.deviceIds.end(), std::back_inserter(deviceIds_));
|
||||||
MS_LOG(INFO) << "Initialized acl successfully.";
|
MS_LOG(INFO) << "Initialized acl successfully.";
|
||||||
|
|
|
@ -53,7 +53,6 @@ struct DeviceResInfo {
|
||||||
|
|
||||||
struct ResourceInfo {
|
struct ResourceInfo {
|
||||||
std::set<int> deviceIds;
|
std::set<int> deviceIds;
|
||||||
std::string aclConfigPath;
|
|
||||||
std::string singleOpFolderPath;
|
std::string singleOpFolderPath;
|
||||||
std::unordered_map<int, DeviceResInfo> deviceResInfos; // map <deviceId, deviceResourceInfo>
|
std::unordered_map<int, DeviceResInfo> deviceResInfos; // map <deviceId, deviceResourceInfo>
|
||||||
};
|
};
|
||||||
|
|
|
@ -1872,7 +1872,6 @@ OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) {
|
||||||
MS_LOG(WARNING) << "set attr value for const failed";
|
MS_LOG(WARNING) << "set attr value for const failed";
|
||||||
}
|
}
|
||||||
|
|
||||||
#if (defined ENABLE_GE)
|
|
||||||
auto const_op = std::static_pointer_cast<Constant>(op);
|
auto const_op = std::static_pointer_cast<Constant>(op);
|
||||||
if (const_op == nullptr) {
|
if (const_op == nullptr) {
|
||||||
MS_LOG(ERROR) << "Get Constant operator failed";
|
MS_LOG(ERROR) << "Get Constant operator failed";
|
||||||
|
@ -1881,7 +1880,6 @@ OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) {
|
||||||
auto ge_tensor = const_op->get_attr_value();
|
auto ge_tensor = const_op->get_attr_value();
|
||||||
auto ge_desc = ge_tensor.GetTensorDesc();
|
auto ge_desc = ge_tensor.GetTensorDesc();
|
||||||
(void)const_op->update_output_desc_y(ge_desc);
|
(void)const_op->update_output_desc_y(ge_desc);
|
||||||
#endif
|
|
||||||
|
|
||||||
op_cache_[node.get()] = op;
|
op_cache_[node.get()] = op;
|
||||||
return op_cache_[node.get()];
|
return op_cache_[node.get()];
|
||||||
|
|
|
@ -317,11 +317,6 @@ size_t OpAdapterImpl::GetCustomOpOutputSize(const CusOperatorPtr &cus_op) {
|
||||||
|
|
||||||
std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type,
|
std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type,
|
||||||
const std::string &format) {
|
const std::string &format) {
|
||||||
if (shape_ptr == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Shape ptr is nullptr";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type == nullptr) {
|
if (type == nullptr) {
|
||||||
MS_LOG(ERROR) << "Type ptr is nullptr";
|
MS_LOG(ERROR) << "Type ptr is nullptr";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -331,8 +326,8 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateOutputDesc(const abstract::Sh
|
||||||
if (kObjectTypeTensorType == me_type) {
|
if (kObjectTypeTensorType == me_type) {
|
||||||
me_type = dyn_cast<TensorType>(type)->element()->type_id();
|
me_type = dyn_cast<TensorType>(type)->element()->type_id();
|
||||||
}
|
}
|
||||||
auto desc = TransformUtil::GetGeTensorDesc(shape_ptr->shape(), me_type, format);
|
|
||||||
return desc;
|
return TransformUtil::GetGeTensorDesc((shape_ptr == nullptr) ? ShapeVector{} : shape_ptr->shape(), me_type, format);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
|
Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
|
||||||
|
@ -472,7 +467,7 @@ void OpAdapterImpl::updateOutputDesc(const OperatorPtr &op, const abstract::Base
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_LOG(INFO) << "Op name is " << op->GetName();
|
MS_LOG(INFO) << "Op name is " << op->GetName() << " anf is " << node->DebugString();
|
||||||
|
|
||||||
auto normal_shape_ptr = dyn_cast<abstract::Shape>(shp);
|
auto normal_shape_ptr = dyn_cast<abstract::Shape>(shp);
|
||||||
auto no_shape_ptr = dyn_cast<abstract::NoShape>(shp);
|
auto no_shape_ptr = dyn_cast<abstract::NoShape>(shp);
|
||||||
|
|
|
@ -78,7 +78,7 @@ class MsBackend : public Backend {
|
||||||
~MsBackend() override = default;
|
~MsBackend() override = default;
|
||||||
|
|
||||||
LinConvertResult MsConvert(const GraphSegmentPtr &segment, const std::string &target = "");
|
LinConvertResult MsConvert(const GraphSegmentPtr &segment, const std::string &target = "");
|
||||||
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");
|
virtual VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");
|
||||||
|
|
||||||
VectorRef MsSimuRunGraph(const GraphId &g);
|
VectorRef MsSimuRunGraph(const GraphId &g);
|
||||||
GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
|
GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
|
||||||
|
@ -90,7 +90,7 @@ class MsBackend : public Backend {
|
||||||
void SetDebugger() override;
|
void SetDebugger() override;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
session::SessionPtr target_sess_;
|
session::SessionPtr target_sess_;
|
||||||
session::SessionPtr other_sess_;
|
session::SessionPtr other_sess_;
|
||||||
std::string target_device_;
|
std::string target_device_;
|
||||||
|
|
|
@ -145,6 +145,22 @@ void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CompileGraph::PushInputs(const FuncGraphPtr &graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
std::vector<AnfNodePtr> parameters = graph->parameters();
|
||||||
|
for (size_t i = parameters.size(); i != 0; i--) {
|
||||||
|
MS_EXCEPTION_IF_NULL(parameters[i - 1]);
|
||||||
|
auto param = parameters[i - 1]->cast<ParameterPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(param);
|
||||||
|
if (param->has_default()) {
|
||||||
|
MS_LOG(DEBUG) << "Parameter " << (i - 1) << ": " << param->DebugString() << " has default value, skip.";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Push(param);
|
||||||
|
MS_LOG(DEBUG) << "Push parameter " << (i - 1) << ": " << param->DebugString(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) {
|
int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) {
|
||||||
MS_EXCEPTION_IF_NULL(segment);
|
MS_EXCEPTION_IF_NULL(segment);
|
||||||
MS_LOG(DEBUG) << "LinConvert start";
|
MS_LOG(DEBUG) << "LinConvert start";
|
||||||
|
@ -257,11 +273,16 @@ bool CompileGraph::Compile(const FuncGraphPtr &graph) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
|
InstSet CompileGraph::Run(const FuncGraphPtr &graph, bool push_weight) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
|
||||||
Reset();
|
Reset();
|
||||||
PushParameters(graph);
|
if (push_weight) {
|
||||||
|
PushParameters(graph);
|
||||||
|
} else {
|
||||||
|
PushInputs(graph);
|
||||||
|
}
|
||||||
|
|
||||||
int64_t param_height = height_;
|
int64_t param_height = height_;
|
||||||
MS_EXCEPTION_IF_NULL(graph->get_return());
|
MS_EXCEPTION_IF_NULL(graph->get_return());
|
||||||
MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);
|
MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);
|
||||||
|
|
|
@ -53,14 +53,14 @@ class CompileGraph {
|
||||||
public:
|
public:
|
||||||
explicit CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
|
explicit CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
|
||||||
|
|
||||||
~CompileGraph() = default;
|
virtual ~CompileGraph() = default;
|
||||||
|
|
||||||
InstSet Run(const FuncGraphPtr &func_graph);
|
InstSet Run(const FuncGraphPtr &func_graph, bool push_weight = true);
|
||||||
bool IsCut(const AnfNodePtr &node);
|
bool IsCut(const AnfNodePtr &node);
|
||||||
void Push(const AnfNodePtr &node);
|
void Push(const AnfNodePtr &node);
|
||||||
void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
|
void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
|
||||||
void Ret(int64_t nargs);
|
void Ret(int64_t nargs);
|
||||||
int64_t Ref(const AnfNodePtr &node);
|
virtual int64_t Ref(const AnfNodePtr &node);
|
||||||
|
|
||||||
void set_height(int64_t h) {
|
void set_height(int64_t h) {
|
||||||
height_ = h;
|
height_ = h;
|
||||||
|
@ -76,8 +76,9 @@ class CompileGraph {
|
||||||
inst_.clear();
|
inst_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
void PushParameters(const FuncGraphPtr &func_graph);
|
void PushParameters(const FuncGraphPtr &func_graph);
|
||||||
|
void PushInputs(const FuncGraphPtr &graph);
|
||||||
bool Compile(const FuncGraphPtr &func_graph);
|
bool Compile(const FuncGraphPtr &func_graph);
|
||||||
int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = "");
|
int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = "");
|
||||||
int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
|
int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
|
||||||
|
@ -114,18 +115,18 @@ class CompileGraphs {
|
||||||
public:
|
public:
|
||||||
explicit CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
|
explicit CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
|
||||||
|
|
||||||
~CompileGraphs() = default;
|
virtual ~CompileGraphs() = default;
|
||||||
|
|
||||||
void Reset() {
|
void Reset() {
|
||||||
insts_.clear();
|
insts_.clear();
|
||||||
mapping_.clear();
|
mapping_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compile(const FuncGraphPtr &func_graph);
|
virtual void Compile(const FuncGraphPtr &func_graph);
|
||||||
FinalVMPtr Link();
|
FinalVMPtr Link();
|
||||||
FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
|
FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
InstSet insts_;
|
InstSet insts_;
|
||||||
std::unordered_map<FuncGraphPtr, int64_t> mapping_;
|
std::unordered_map<FuncGraphPtr, int64_t> mapping_;
|
||||||
CompileGraphPtr transform_;
|
CompileGraphPtr transform_;
|
||||||
|
|
|
@ -424,17 +424,11 @@ AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &pri
|
||||||
ShapeVector y_shape_max = y->shape()->max_shape().empty() ? y_shape : y->shape()->max_shape();
|
ShapeVector y_shape_max = y->shape()->max_shape().empty() ? y_shape : y->shape()->max_shape();
|
||||||
|
|
||||||
auto out_shape = BroadcastShape(x_shape, y_shape);
|
auto out_shape = BroadcastShape(x_shape, y_shape);
|
||||||
if (out_shape.empty()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Less op BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
|
|
||||||
<< args_spec_list[1]->ToString();
|
|
||||||
}
|
|
||||||
auto out_shape_min = BroadcastShape(x_shape_min, y_shape_min);
|
auto out_shape_min = BroadcastShape(x_shape_min, y_shape_min);
|
||||||
auto out_shape_max = BroadcastShape(x_shape_max, y_shape_max);
|
auto out_shape_max = BroadcastShape(x_shape_max, y_shape_max);
|
||||||
|
|
||||||
auto output_type = std::make_shared<Bool>();
|
auto output_type = std::make_shared<Bool>();
|
||||||
auto ret =
|
return std::make_shared<AbstractTensor>(output_type,
|
||||||
std::make_shared<AbstractTensor>(output_type, std::make_shared<Shape>(out_shape, out_shape_min, out_shape_max));
|
std::make_shared<Shape>(out_shape, out_shape_min, out_shape_max));
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,503 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "load_mindir/infer_mindir.h"
|
||||||
|
#include <deque>
|
||||||
|
#include <set>
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <string>
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
#include "abstract/abstract_function.h"
|
||||||
|
#include "ops/primitive_c.h"
|
||||||
|
#include "abstract/abstract_value.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace {
|
||||||
|
class MindIREngine {
|
||||||
|
public:
|
||||||
|
explicit MindIREngine(const FuncGraphPtr &root) : func_graph_(root), nodeuser_map_(root->manager()->node_users()) {}
|
||||||
|
~MindIREngine() = default;
|
||||||
|
MindIREngine(const MindIREngine &) = delete;
|
||||||
|
MindIREngine &operator=(const MindIREngine &) = delete;
|
||||||
|
|
||||||
|
bool InferShape(const AbstractBasePtrList &args);
|
||||||
|
|
||||||
|
private:
|
||||||
|
using AbstractBasePtrListPtr = std::shared_ptr<AbstractBasePtrList>;
|
||||||
|
|
||||||
|
void Init(const AbstractBasePtrList &args);
|
||||||
|
static AbstractBasePtr InferPrimitiveShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
|
||||||
|
void EvalCommonPrimitive(const PrimitivePtr &prim, const CNodePtr &node, const AbstractBasePtrListPtr &args);
|
||||||
|
void EvalPartialPrimitive(const PrimitivePtr &prim, const CNodePtr &node, const AbstractBasePtrListPtr &args);
|
||||||
|
void EvalReturnPrimitive(const PrimitivePtr &prim, const CNodePtr &node);
|
||||||
|
void InferParameter(const AnfNodePtr &node);
|
||||||
|
void InferValueNode(const AnfNodePtr &node);
|
||||||
|
void InferCNode(const AnfNodePtr &node);
|
||||||
|
void EvalAbstractFunction(const abstract::AbstractFuncAtomPtr &abstractFunc, const CNodePtr &node,
|
||||||
|
const AbstractBasePtrListPtr &args);
|
||||||
|
void EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr &func, const CNodePtr &node,
|
||||||
|
const AbstractBasePtrListPtr &args);
|
||||||
|
void EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr &func, const CNodePtr &node,
|
||||||
|
const AbstractBasePtrListPtr &args);
|
||||||
|
void EvalPartialAbastract(const abstract::PartialAbstractClosurePtr &func, const CNodePtr &node,
|
||||||
|
const AbstractBasePtrListPtr &args);
|
||||||
|
bool CheckCNodeNotReady(const CNodePtr &node);
|
||||||
|
void UpdateReady(const AnfNodePtr &node);
|
||||||
|
void SaveNodeInferResult(const AnfNodePtr &node, const AbstractBasePtr &result);
|
||||||
|
AbstractBasePtr GetCNodeOperatorAbstract(const AnfNodePtr &node);
|
||||||
|
|
||||||
|
FuncGraphPtr func_graph_;
|
||||||
|
std::map<AnfNodePtr, int> node_input_depends_;
|
||||||
|
std::map<AnfNodePtr, AbstractBasePtr> infer_resut_;
|
||||||
|
std::map<std::string, AbstractBasePtr> func_graph_result_;
|
||||||
|
std::map<std::string, std::set<AnfNodePtr>> func_graph_visited_;
|
||||||
|
std::deque<AnfNodePtr> ready_;
|
||||||
|
std::set<AnfNodePtr> todo_;
|
||||||
|
NodeUsersMap nodeuser_map_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Infer the root function graph.
|
||||||
|
bool MindIREngine::InferShape(const AbstractBasePtrList &args) {
|
||||||
|
Init(args);
|
||||||
|
while (!ready_.empty()) {
|
||||||
|
auto current = ready_.front();
|
||||||
|
MS_EXCEPTION_IF_NULL(current);
|
||||||
|
ready_.pop_front();
|
||||||
|
if (current->isa<CNode>()) {
|
||||||
|
InferCNode(current);
|
||||||
|
} else if (current->isa<ValueNode>()) {
|
||||||
|
InferValueNode(current);
|
||||||
|
} else if (current->isa<Parameter>()) {
|
||||||
|
InferParameter(current);
|
||||||
|
} else {
|
||||||
|
MS_LOG(WARNING) << " There is something changed. Please check the code.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set abstract of node.
|
||||||
|
for (const auto &item : infer_resut_) {
|
||||||
|
item.first->set_abstract(item.second);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (todo_.empty()) {
|
||||||
|
MS_LOG(DEBUG) << "Finish to Infere.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
MS_LOG(WARNING) << "Not finished to infer: " << todo_.size();
|
||||||
|
for (const auto &node : todo_) {
|
||||||
|
MS_LOG(DEBUG) << "Node uninfered: " << node->DebugString();
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void MindIREngine::Init(const AbstractBasePtrList &args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||||
|
auto manager = func_graph_->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
|
for (const auto &node : manager->all_nodes()) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (node->isa<CNode>()) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
todo_.insert(node);
|
||||||
|
node_input_depends_[node] = cnode->inputs().size();
|
||||||
|
} else if (node->isa<Parameter>()) {
|
||||||
|
auto param = node->cast<ParameterPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(param);
|
||||||
|
if (param->has_default()) {
|
||||||
|
node_input_depends_[node] = 0;
|
||||||
|
infer_resut_[node] = param->default_param()->ToAbstract();
|
||||||
|
ready_.push_back(node);
|
||||||
|
} else {
|
||||||
|
node_input_depends_[node] = 1;
|
||||||
|
todo_.insert(node);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Value Node
|
||||||
|
node_input_depends_[node] = 0;
|
||||||
|
ready_.push_back(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inputs = func_graph_->get_inputs();
|
||||||
|
if (inputs.size() != args.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "The input parameters is not Compatible. mindir:" << inputs.size()
|
||||||
|
<< " inputs: " << args.size() << " FuncGraph:" << func_graph_->ToString();
|
||||||
|
}
|
||||||
|
// Root Func Parameters
|
||||||
|
for (size_t i = 0; i < args.size(); ++i) {
|
||||||
|
this->SaveNodeInferResult(inputs[i], args[i]);
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "Finish init. Size of nodes:" << manager->all_nodes().size();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Infer primitive using C++ implement.
|
||||||
|
AbstractBasePtr MindIREngine::InferPrimitiveShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
try {
|
||||||
|
static auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
|
||||||
|
auto ret = prim_eval_implement_map.find(prim);
|
||||||
|
if (ret != prim_eval_implement_map.end()) {
|
||||||
|
// fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr
|
||||||
|
MS_EXCEPTION_IF_NULL(ret->second.infer_shape_impl_);
|
||||||
|
return ret->second.infer_shape_impl_(nullptr, prim, args_spec_list);
|
||||||
|
} else {
|
||||||
|
// if the infer function has been not founded in the front infer map find it in the backend infer map instead
|
||||||
|
static auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
|
||||||
|
auto ret_backend = prim_backend_eval_impl_map.find(prim);
|
||||||
|
if (ret_backend != prim_backend_eval_impl_map.end()) {
|
||||||
|
MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_);
|
||||||
|
return ret_backend->second.infer_shape_impl_(nullptr, prim, args_spec_list);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(WARNING) << "Get infer shape function failed, primitive name:" << prim->name()
|
||||||
|
<< " primitive type:" << prim->type_name() << " It will keep the prevalue witch danger.";
|
||||||
|
} catch (const std::exception &ex) {
|
||||||
|
MS_LOG(WARNING) << "Catch primitive:" << prim->ToString() << " InferPrimitiveShape exception:" << ex.what()
|
||||||
|
<< " It will keep the prevalue witch danger.";
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void MindIREngine::EvalCommonPrimitive(const PrimitivePtr &prim, const CNodePtr &node,
|
||||||
|
const AbstractBasePtrListPtr &args) {
|
||||||
|
AbstractBasePtrList args_spec_list;
|
||||||
|
// Args has been resolved by partial
|
||||||
|
if (args != nullptr) {
|
||||||
|
args_spec_list.insert(args_spec_list.end(), args->begin(), args->end());
|
||||||
|
} else {
|
||||||
|
(void)std::transform(node->inputs().begin() + 1, node->inputs().end(), std::back_inserter(args_spec_list),
|
||||||
|
[this](const AnfNodePtr &arg) { return infer_resut_[arg]; });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call C++ infer
|
||||||
|
auto result = InferPrimitiveShape(prim, args_spec_list);
|
||||||
|
if (result == nullptr) {
|
||||||
|
MS_LOG(INFO) << node->ToString()
|
||||||
|
<< " can't be inferred shape. It will keep the prevalue witch danger. Prim: " << prim->ToString();
|
||||||
|
result = node->abstract();
|
||||||
|
}
|
||||||
|
SaveNodeInferResult(node, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MindIREngine::EvalReturnPrimitive(const PrimitivePtr &prim, const CNodePtr &node) {
|
||||||
|
if (node->inputs().size() < 2) {
|
||||||
|
MS_LOG(EXCEPTION) << node->DebugString() << " input size < 2";
|
||||||
|
}
|
||||||
|
auto result = infer_resut_[node->inputs()[1]];
|
||||||
|
auto funcName = node->func_graph()->ToString();
|
||||||
|
auto it = func_graph_result_.find(funcName);
|
||||||
|
if (it != func_graph_result_.end()) {
|
||||||
|
result = result->Join(it->second);
|
||||||
|
}
|
||||||
|
this->func_graph_result_[funcName] = result;
|
||||||
|
SaveNodeInferResult(node, result);
|
||||||
|
MS_LOG(DEBUG) << funcName << " result: " << result->ToString();
|
||||||
|
|
||||||
|
// Set the result of the node whose Operator is this funcGraph
|
||||||
|
for (const auto &item : func_graph_visited_[funcName]) {
|
||||||
|
SaveNodeInferResult(item, result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MindIREngine::EvalPartialPrimitive(const PrimitivePtr &prim, const CNodePtr &node,
|
||||||
|
const AbstractBasePtrListPtr &args) {
|
||||||
|
// Args has been resolved
|
||||||
|
if (args != nullptr) {
|
||||||
|
if (args->size() < 2) {
|
||||||
|
MS_LOG(EXCEPTION) << node->DebugString() << " input size < 2";
|
||||||
|
}
|
||||||
|
auto real_func = (*args)[0]->cast<abstract::AbstractFuncAtomPtr>();
|
||||||
|
if (real_func == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << (*args)[0]->ToString() << " is not a function abstract.";
|
||||||
|
}
|
||||||
|
AbstractBasePtrList partial_args_list;
|
||||||
|
partial_args_list.insert(partial_args_list.end(), args->begin() + 1, args->end());
|
||||||
|
auto partial_func = std::make_shared<abstract::PartialAbstractClosure>(real_func, partial_args_list, node);
|
||||||
|
SaveNodeInferResult(node, partial_func);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Not Resolved.
|
||||||
|
if (node->inputs().size() < 2) {
|
||||||
|
MS_LOG(EXCEPTION) << node->DebugString() << " input size < 2";
|
||||||
|
}
|
||||||
|
auto &func = infer_resut_[node->inputs()[1]];
|
||||||
|
auto real_func = func->cast<abstract::AbstractFuncAtomPtr>();
|
||||||
|
if (real_func == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << func->ToString() << " is not a function abstract.";
|
||||||
|
}
|
||||||
|
AbstractBasePtrList partial_args_list;
|
||||||
|
(void)std::transform(node->inputs().begin() + 2, node->inputs().end(), std::back_inserter(partial_args_list),
|
||||||
|
[this](const AnfNodePtr &arg) { return infer_resut_[arg]; });
|
||||||
|
auto partial_func = std::make_shared<abstract::PartialAbstractClosure>(real_func, partial_args_list, node);
|
||||||
|
SaveNodeInferResult(node, partial_func);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MindIREngine::EvalPartialAbastract(const abstract::PartialAbstractClosurePtr &func, const CNodePtr &node,
|
||||||
|
const AbstractBasePtrListPtr &args) {
|
||||||
|
AbstractBasePtrListPtr partial_args_list = std::make_shared<AbstractBasePtrList>();
|
||||||
|
// Join arguments in partial and the rest arguments from args_conf_list.
|
||||||
|
auto func_args = func->args();
|
||||||
|
partial_args_list->insert(partial_args_list->end(), func_args.begin(), func_args.end());
|
||||||
|
if (args == nullptr) {
|
||||||
|
// Not Recursive
|
||||||
|
(void)std::transform(node->inputs().begin() + 1, node->inputs().end(), std::back_inserter(*partial_args_list),
|
||||||
|
[this](const AnfNodePtr &arg) { return infer_resut_[arg]; });
|
||||||
|
} else {
|
||||||
|
// Recursive
|
||||||
|
partial_args_list->insert(partial_args_list->end(), args->begin(), args->end());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get real function
|
||||||
|
abstract::AbstractFuncAtomPtrList abstractFuncList;
|
||||||
|
auto build_fuction = [&abstractFuncList](const abstract::AbstractFuncAtomPtr &poss) {
|
||||||
|
abstractFuncList.push_back(poss);
|
||||||
|
};
|
||||||
|
func->fn()->Visit(build_fuction);
|
||||||
|
for (const auto &abstractFunc : abstractFuncList) {
|
||||||
|
EvalAbstractFunction(abstractFunc, node, partial_args_list);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MindIREngine::SaveNodeInferResult(const AnfNodePtr &node, const AbstractBasePtr &result) {
|
||||||
|
auto answer = result;
|
||||||
|
auto it = infer_resut_.find(node);
|
||||||
|
if (it != infer_resut_.end()) {
|
||||||
|
MS_LOG(DEBUG) << node->ToString() << " result: " << it->second->ToString();
|
||||||
|
answer = result->Join(it->second);
|
||||||
|
if (*answer == *(it->second)) {
|
||||||
|
MS_LOG(DEBUG) << node->ToString() << " The value is not changed.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << node->ToString() << " result: " << answer->ToString();
|
||||||
|
infer_resut_[node] = answer;
|
||||||
|
UpdateReady(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MindIREngine::EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr &func, const CNodePtr &node,
|
||||||
|
const AbstractBasePtrListPtr &args) {
|
||||||
|
auto prim = func->prim();
|
||||||
|
// Return Primitive
|
||||||
|
if (prim->name() == prim::kPrimReturn->name()) {
|
||||||
|
EvalReturnPrimitive(prim, node);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Partial Primitive
|
||||||
|
if (prim->name() == prim::kPrimPartial->name()) {
|
||||||
|
EvalPartialPrimitive(prim, node, args);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// common Primitive
|
||||||
|
EvalCommonPrimitive(prim, node, args);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MindIREngine::CheckCNodeNotReady(const CNodePtr &node) {
|
||||||
|
int depend = 0;
|
||||||
|
for (const auto &input : node->inputs()) {
|
||||||
|
depend += infer_resut_.find(input) != infer_resut_.end() ? 0 : 1;
|
||||||
|
}
|
||||||
|
this->node_input_depends_[node] = depend;
|
||||||
|
return depend;
|
||||||
|
}
|
||||||
|
|
||||||
|
void MindIREngine::EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr &func, const CNodePtr &node,
|
||||||
|
const AbstractBasePtrListPtr &args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
MS_EXCEPTION_IF_NULL(func);
|
||||||
|
MS_EXCEPTION_IF_NULL(func->func_graph());
|
||||||
|
// Has Processd
|
||||||
|
MS_LOG(DEBUG) << node->ToString() << " FuncGraph: " << func->ToString();
|
||||||
|
auto funcName = func->func_graph()->ToString();
|
||||||
|
auto it = func_graph_result_.find(funcName);
|
||||||
|
if (it != func_graph_result_.end()) {
|
||||||
|
MS_LOG(DEBUG) << "The abstract of " << node->ToString() << " = abstract of " << func->ToString();
|
||||||
|
SaveNodeInferResult(node, it->second);
|
||||||
|
|
||||||
|
// Process only one return valueNode function graph
|
||||||
|
auto func_inputs = func->func_graph()->get_inputs();
|
||||||
|
// args has been resolved in partial.
|
||||||
|
if (args != nullptr) {
|
||||||
|
if (func_inputs.size() != args->size()) {
|
||||||
|
MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
|
||||||
|
<< " CNode:" << node->DebugString() << " input size:" << args->size();
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < func_inputs.size(); ++i) {
|
||||||
|
infer_resut_[func_inputs[i]] =
|
||||||
|
(*args)[i]; // Not use SaveNodeInferResult because this function has been evaluated.
|
||||||
|
todo_.erase(func_inputs[i]);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// args is not resolved.
|
||||||
|
auto &cnode_inputs = node->inputs();
|
||||||
|
if (func_inputs.size() != cnode_inputs.size() - 1) {
|
||||||
|
MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
|
||||||
|
<< " CNode:" << node->DebugString() << " input size:" << cnode_inputs.size();
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < func_inputs.size(); ++i) {
|
||||||
|
infer_resut_[func_inputs[i]] = infer_resut_[cnode_inputs[i + 1]];
|
||||||
|
todo_.erase(func_inputs[i]);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Be handling
|
||||||
|
auto visitIt = func_graph_visited_.find(funcName);
|
||||||
|
if (visitIt != func_graph_visited_.end()) {
|
||||||
|
visitIt->second.insert(node);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
func_graph_visited_[funcName] = std::set<AnfNodePtr>({node});
|
||||||
|
|
||||||
|
// Call the funcGraph
|
||||||
|
auto func_inputs = func->func_graph()->get_inputs();
|
||||||
|
|
||||||
|
// args has been resolved in partial.
|
||||||
|
if (args != nullptr) {
|
||||||
|
if (func_inputs.size() != args->size()) {
|
||||||
|
MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
|
||||||
|
<< " CNode:" << node->DebugString() << " input size:" << args->size();
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < func_inputs.size(); ++i) {
|
||||||
|
SaveNodeInferResult(func_inputs[i], (*args)[i]);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// args is not resolved.
|
||||||
|
auto &cnode_inputs = node->inputs();
|
||||||
|
if (func_inputs.size() != cnode_inputs.size() - 1) {
|
||||||
|
MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
|
||||||
|
<< " CNode:" << node->DebugString() << " input size:" << cnode_inputs.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < func_inputs.size(); ++i) {
|
||||||
|
SaveNodeInferResult(func_inputs[i], infer_resut_[cnode_inputs[i + 1]]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MindIREngine::InferParameter(const AnfNodePtr &node) { UpdateReady(node); }
|
||||||
|
|
||||||
|
void MindIREngine::InferValueNode(const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto value_node = node->cast<ValueNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
|
auto value = GetValueNode(node);
|
||||||
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
|
AbstractBasePtr result;
|
||||||
|
if (value->isa<FuncGraph>()) {
|
||||||
|
auto func_graph = value->cast<FuncGraphPtr>();
|
||||||
|
auto temp_context = abstract::AnalysisContext::DummyContext();
|
||||||
|
result = std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, temp_context, node);
|
||||||
|
} else if (value->isa<Primitive>()) {
|
||||||
|
auto prim = value->cast<PrimitivePtr>();
|
||||||
|
result = std::make_shared<abstract::PrimitiveAbstractClosure>(prim, node);
|
||||||
|
} else {
|
||||||
|
result = value->ToAbstract();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (result->isa<abstract::AbstractTensor>()) {
|
||||||
|
result = result->Broaden();
|
||||||
|
}
|
||||||
|
SaveNodeInferResult(node, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr MindIREngine::GetCNodeOperatorAbstract(const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
auto op = cnode->inputs()[0];
|
||||||
|
auto it = infer_resut_.find(op);
|
||||||
|
if (it != infer_resut_.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
MS_LOG(EXCEPTION) << "Can't get the abstract of Node:" << op->DebugString();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If args is nullPtr, it is called by InferCNode, else it is called recursively by EvalPartialAbastract.
|
||||||
|
void MindIREngine::EvalAbstractFunction(const abstract::AbstractFuncAtomPtr &func, const CNodePtr &node,
|
||||||
|
const AbstractBasePtrListPtr &args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func);
|
||||||
|
if (func->isa<abstract::PrimitiveAbstractClosure>()) {
|
||||||
|
// C++ Primitive
|
||||||
|
auto prim = func->cast<abstract::PrimitiveAbstractClosurePtr>();
|
||||||
|
EvalPrimitiveAbastract(prim, node, args);
|
||||||
|
} else if (func->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||||
|
// FuncGraph
|
||||||
|
auto funcGraph = func->cast<abstract::FuncGraphAbstractClosurePtr>();
|
||||||
|
EvalFuncGraphAbastract(funcGraph, node, args);
|
||||||
|
} else if (func->isa<abstract::PartialAbstractClosure>()) {
|
||||||
|
// Partial
|
||||||
|
auto partialPrim = func->cast<abstract::PartialAbstractClosurePtr>();
|
||||||
|
EvalPartialAbastract(partialPrim, node, args);
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "MindIR can't process the abstractFunc: " << func->DumpText();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MindIREngine::UpdateReady(const AnfNodePtr &node) {
|
||||||
|
todo_.erase(node);
|
||||||
|
auto it = nodeuser_map_.find(node);
|
||||||
|
if (it == nodeuser_map_.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const auto &users = it->second;
|
||||||
|
MS_LOG(DEBUG) << node->ToString() << " has users: " << users.size();
|
||||||
|
for (const auto &user : users) {
|
||||||
|
int count = node_input_depends_[user.first];
|
||||||
|
node_input_depends_[user.first] = count - 1;
|
||||||
|
if (count <= 1) {
|
||||||
|
ready_.push_back(user.first);
|
||||||
|
MS_LOG(DEBUG) << "Node:" << user.first->ToString() << " is ready.";
|
||||||
|
if (count < 1) {
|
||||||
|
MS_LOG(INFO) << " There is something to do. Node:" << node->ToString() << " user:" << user.first->DebugString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MindIREngine::InferCNode(const AnfNodePtr &node) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
if (CheckCNodeNotReady(cnode)) {
|
||||||
|
MS_LOG(INFO) << "The node is not ready: " << cnode->DebugString();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
AbstractBasePtr possible_func = GetCNodeOperatorAbstract(cnode);
|
||||||
|
if (possible_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
|
||||||
|
MS_LOG(EXCEPTION) << "EvalCNode eval Undetermined";
|
||||||
|
}
|
||||||
|
abstract::AbstractFunctionPtr func = dyn_cast<abstract::AbstractFunction>(possible_func);
|
||||||
|
if (func == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << possible_func->ToString() << ".";
|
||||||
|
MS_EXCEPTION(ValueError) << "This may be not defined, and it can't be a operator. Please check code.";
|
||||||
|
}
|
||||||
|
abstract::AbstractFuncAtomPtrList abstractFuncList;
|
||||||
|
auto build_fuction = [&abstractFuncList](const abstract::AbstractFuncAtomPtr &poss) {
|
||||||
|
abstractFuncList.push_back(poss);
|
||||||
|
};
|
||||||
|
func->Visit(build_fuction);
|
||||||
|
for (const auto &abstractFunc : abstractFuncList) {
|
||||||
|
EvalAbstractFunction(abstractFunc, cnode, nullptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
bool InferMindir(const FuncGraphPtr &root, const AbstractBasePtrList &args) {
|
||||||
|
auto engine = std::make_shared<MindIREngine>(root);
|
||||||
|
return engine->InferShape(args);
|
||||||
|
}
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,26 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_LOAD_MINDIR_INFER_MINDIR_H
|
||||||
|
#define MINDSPORE_CORE_LOAD_MINDIR_INFER_MINDIR_H
|
||||||
|
#include "base/base.h"
|
||||||
|
#include "ir/anf.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
bool InferMindir(const FuncGraphPtr &root, const AbstractBasePtrList &args);
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_LOAD_MINDIR_INFER_MINDIR_H
|
|
@ -51,6 +51,7 @@ const char kCPUDevice[] = "CPU";
|
||||||
const char kGPUDevice[] = "GPU";
|
const char kGPUDevice[] = "GPU";
|
||||||
const char kAscendDevice[] = "Ascend";
|
const char kAscendDevice[] = "Ascend";
|
||||||
const char kDavinciInferenceDevice[] = "AscendInference";
|
const char kDavinciInferenceDevice[] = "AscendInference";
|
||||||
|
const char kDavinciMultiGraphInferenceDevice[] = "AscendMultiGraphInference";
|
||||||
const char kGpuInferenceDevice[] = "GpuInference";
|
const char kGpuInferenceDevice[] = "GpuInference";
|
||||||
const char kDavinciDevice[] = "Davinci";
|
const char kDavinciDevice[] = "Davinci";
|
||||||
const char KNpuLog[] = "_npu_log";
|
const char KNpuLog[] = "_npu_log";
|
||||||
|
|
|
@ -36,7 +36,6 @@ constexpr auto kModelOptionProvider = "mindspore.option.provider";
|
||||||
constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
|
constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
|
||||||
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
|
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
|
||||||
constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
|
constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
|
||||||
constexpr auto kModelOptionAscend310DumpCfgPath = "mindspore.option.ascend310.dump_config_file_path";
|
|
||||||
constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path";
|
constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path";
|
||||||
constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format";
|
constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format";
|
||||||
constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
|
constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
|
||||||
|
@ -330,23 +329,6 @@ uint32_t Ascend310DeviceInfo::GetDeviceID() const {
|
||||||
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
|
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetDumpConfigPath(const std::vector<char> &cfg_path) {
|
|
||||||
if (data_ == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
data_->params[kModelOptionAscend310DumpCfgPath] = CharToString(cfg_path);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<char> Ascend310DeviceInfo::GetDumpConfigPathChar() const {
|
|
||||||
if (data_ == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
|
||||||
return std::vector<char>();
|
|
||||||
}
|
|
||||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DumpCfgPath);
|
|
||||||
return StringToChar(ref);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
|
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
|
|
@ -0,0 +1,405 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "common/common_test.h"
|
||||||
|
#include "include/api/model.h"
|
||||||
|
#include "include/api/serialization.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
|
|
||||||
|
using namespace mindspore;
|
||||||
|
|
||||||
|
static constexpr char kIfbyIfFile[] = "/home/workspace/mindspore_dataset/mindir/control/ifbyif.mindir";
|
||||||
|
static constexpr char kSimpleWhileFile[] = "/home/workspace/mindspore_dataset/mindir/control/simple_while.mindir";
|
||||||
|
static constexpr char kMixIfWhileFile[] = "/home/workspace/mindspore_dataset/mindir/control/mix_while_if.mindir";
|
||||||
|
static constexpr char kRecursiveFile[] = "/home/workspace/mindspore_dataset/mindir/control/fibonacci.mindir";
|
||||||
|
static constexpr char kSingleForFile[] = "/home/workspace/mindspore_dataset/mindir/control/single_for.mindir";
|
||||||
|
static constexpr char kSingleOrFile[] = "/home/workspace/mindspore_dataset/mindir/control/single_or.mindir";
|
||||||
|
static constexpr char kSingleSwitchFile[] = "/home/workspace/mindspore_dataset/mindir/control/switch_layer_net.mindir";
|
||||||
|
static constexpr float kConstValue = 0.1234;
|
||||||
|
static const std::vector<float> input_data(2 * 3 * 4 * 5, kConstValue);
|
||||||
|
|
||||||
|
class TestControl : public ST::Common {
|
||||||
|
public:
|
||||||
|
TestControl() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestControl, InferIfbyIf) {
|
||||||
|
auto context = ContextAutoSet();
|
||||||
|
|
||||||
|
Graph graph;
|
||||||
|
ASSERT_TRUE(Serialization::Load(kIfbyIfFile, ModelType::kMindIR, &graph));
|
||||||
|
Model control_model;
|
||||||
|
ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
|
||||||
|
|
||||||
|
// assert inputs
|
||||||
|
std::vector<MSTensor> inputs_before = control_model.GetInputs();
|
||||||
|
ASSERT_EQ(5, inputs_before.size());
|
||||||
|
EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
|
||||||
|
EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeFloat32);
|
||||||
|
EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeBool);
|
||||||
|
EXPECT_EQ(inputs_before[3].DataType(), DataType::kNumberTypeBool);
|
||||||
|
EXPECT_EQ(inputs_before[4].DataType(), DataType::kNumberTypeFloat32);
|
||||||
|
ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float));
|
||||||
|
ASSERT_EQ(inputs_before[1].DataSize(), sizeof(float));
|
||||||
|
ASSERT_EQ(inputs_before[2].DataSize(), sizeof(bool));
|
||||||
|
ASSERT_EQ(inputs_before[3].DataSize(), sizeof(bool));
|
||||||
|
ASSERT_EQ(inputs_before[4].DataSize(), sizeof(float) * input_data.size());
|
||||||
|
ASSERT_EQ(inputs_before[0].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[0].Shape()[0], 1);
|
||||||
|
ASSERT_EQ(inputs_before[1].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[1].Shape()[0], 1);
|
||||||
|
ASSERT_EQ(inputs_before[2].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[2].Shape()[0], 1);
|
||||||
|
ASSERT_EQ(inputs_before[3].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[3].Shape()[0], 1);
|
||||||
|
ASSERT_EQ(inputs_before[4].Shape().size(), 4);
|
||||||
|
EXPECT_EQ(inputs_before[4].Shape()[0], 2);
|
||||||
|
EXPECT_EQ(inputs_before[4].Shape()[1], 3);
|
||||||
|
EXPECT_EQ(inputs_before[4].Shape()[2], 4);
|
||||||
|
EXPECT_EQ(inputs_before[4].Shape()[3], 5);
|
||||||
|
|
||||||
|
// prepare input
|
||||||
|
std::vector<MSTensor> outputs;
|
||||||
|
std::vector<MSTensor> inputs;
|
||||||
|
|
||||||
|
float x = 2.345678, y = 1.234567;
|
||||||
|
bool cond1 = true, cond2 = false;
|
||||||
|
inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
|
||||||
|
sizeof(float));
|
||||||
|
inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
|
||||||
|
sizeof(float));
|
||||||
|
inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &cond1,
|
||||||
|
sizeof(bool));
|
||||||
|
inputs.emplace_back(inputs_before[3].Name(), inputs_before[3].DataType(), inputs_before[3].Shape(), &cond2,
|
||||||
|
sizeof(bool));
|
||||||
|
inputs.emplace_back(inputs_before[4].Name(), inputs_before[4].DataType(), inputs_before[4].Shape(), input_data.data(),
|
||||||
|
sizeof(float) * input_data.size());
|
||||||
|
|
||||||
|
// infer
|
||||||
|
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
|
||||||
|
|
||||||
|
// assert output
|
||||||
|
ASSERT_TRUE(outputs.size() == 1);
|
||||||
|
auto out = outputs[0];
|
||||||
|
ASSERT_TRUE(out.DataSize() == sizeof(float) * input_data.size());
|
||||||
|
auto out_data = out.Data();
|
||||||
|
auto p = reinterpret_cast<const float *>(out_data.get());
|
||||||
|
for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
|
||||||
|
ASSERT_LE(std::abs(p[i] - kConstValue * 24), 1e-3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestControl, InferSimpleWhile) {
|
||||||
|
auto context = ContextAutoSet();
|
||||||
|
|
||||||
|
Graph graph;
|
||||||
|
ASSERT_TRUE(Serialization::Load(kSimpleWhileFile, ModelType::kMindIR, &graph));
|
||||||
|
Model control_model;
|
||||||
|
ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
|
||||||
|
|
||||||
|
// assert inputs
|
||||||
|
std::vector<MSTensor> inputs_before = control_model.GetInputs();
|
||||||
|
ASSERT_EQ(3, inputs_before.size());
|
||||||
|
EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeBool);
|
||||||
|
EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeBool);
|
||||||
|
EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeFloat32);
|
||||||
|
ASSERT_EQ(inputs_before[0].DataSize(), sizeof(bool));
|
||||||
|
ASSERT_EQ(inputs_before[1].DataSize(), sizeof(bool));
|
||||||
|
ASSERT_EQ(inputs_before[2].DataSize(), sizeof(float) * input_data.size());
|
||||||
|
ASSERT_EQ(inputs_before[0].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[0].Shape()[0], 1);
|
||||||
|
ASSERT_EQ(inputs_before[1].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[1].Shape()[0], 1);
|
||||||
|
ASSERT_EQ(inputs_before[2].Shape().size(), 4);
|
||||||
|
EXPECT_EQ(inputs_before[2].Shape()[0], 2);
|
||||||
|
EXPECT_EQ(inputs_before[2].Shape()[1], 3);
|
||||||
|
EXPECT_EQ(inputs_before[2].Shape()[2], 4);
|
||||||
|
EXPECT_EQ(inputs_before[2].Shape()[3], 5);
|
||||||
|
|
||||||
|
// prepare input
|
||||||
|
std::vector<MSTensor> outputs;
|
||||||
|
std::vector<MSTensor> inputs;
|
||||||
|
{
|
||||||
|
bool x = true, y = false;
|
||||||
|
inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
|
||||||
|
sizeof(bool));
|
||||||
|
inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
|
||||||
|
sizeof(bool));
|
||||||
|
inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(),
|
||||||
|
input_data.data(), sizeof(float) * input_data.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
// infer
|
||||||
|
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
|
||||||
|
|
||||||
|
// assert output
|
||||||
|
ASSERT_TRUE(outputs.size() == 1);
|
||||||
|
auto out = outputs[0];
|
||||||
|
ASSERT_TRUE(out.DataSize() == sizeof(float) * input_data.size());
|
||||||
|
auto out_data = out.Data();
|
||||||
|
auto p = reinterpret_cast<const float *>(out_data.get());
|
||||||
|
for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
|
||||||
|
ASSERT_LE(std::abs(p[i] - kConstValue * 3), 1e-3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestControl, InferRecursive) {
|
||||||
|
auto context = ContextAutoSet();
|
||||||
|
|
||||||
|
Graph graph;
|
||||||
|
ASSERT_TRUE(Serialization::Load(kRecursiveFile, ModelType::kMindIR, &graph));
|
||||||
|
Model control_model;
|
||||||
|
ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
|
||||||
|
|
||||||
|
// assert inputs
|
||||||
|
std::vector<MSTensor> inputs_before = control_model.GetInputs();
|
||||||
|
ASSERT_EQ(1, inputs_before.size());
|
||||||
|
EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32);
|
||||||
|
ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t));
|
||||||
|
ASSERT_EQ(inputs_before[0].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[0].Shape()[0], 1);
|
||||||
|
|
||||||
|
// prepare input
|
||||||
|
std::vector<MSTensor> outputs;
|
||||||
|
std::vector<MSTensor> inputs;
|
||||||
|
{
|
||||||
|
int32_t x = 7;
|
||||||
|
inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
|
||||||
|
sizeof(int32_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
// infer
|
||||||
|
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
|
||||||
|
|
||||||
|
// assert output
|
||||||
|
ASSERT_TRUE(outputs.size() == 1);
|
||||||
|
auto out = outputs[0];
|
||||||
|
ASSERT_TRUE(out.DataSize() == sizeof(int32_t));
|
||||||
|
auto out_data = out.Data();
|
||||||
|
auto p = reinterpret_cast<const int32_t *>(out_data.get());
|
||||||
|
ASSERT_EQ(*p, 21);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestControl, InferMixedWhileIf) {
|
||||||
|
auto context = ContextAutoSet();
|
||||||
|
|
||||||
|
Graph graph;
|
||||||
|
ASSERT_TRUE(Serialization::Load(kMixIfWhileFile, ModelType::kMindIR, &graph));
|
||||||
|
Model control_model;
|
||||||
|
ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
|
||||||
|
|
||||||
|
// assert inputs
|
||||||
|
std::vector<MSTensor> inputs_before = control_model.GetInputs();
|
||||||
|
ASSERT_EQ(inputs_before.size(), 5);
|
||||||
|
EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32);
|
||||||
|
EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32);
|
||||||
|
EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32);
|
||||||
|
EXPECT_EQ(inputs_before[3].DataType(), DataType::kNumberTypeInt32);
|
||||||
|
EXPECT_EQ(inputs_before[4].DataType(), DataType::kNumberTypeInt32);
|
||||||
|
ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t));
|
||||||
|
ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t));
|
||||||
|
ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t));
|
||||||
|
ASSERT_EQ(inputs_before[3].DataSize(), sizeof(int32_t));
|
||||||
|
ASSERT_EQ(inputs_before[4].DataSize(), sizeof(int32_t));
|
||||||
|
ASSERT_EQ(inputs_before[0].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[0].Shape()[0], 1);
|
||||||
|
ASSERT_EQ(inputs_before[1].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[1].Shape()[0], 1);
|
||||||
|
ASSERT_EQ(inputs_before[2].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[2].Shape()[0], 1);
|
||||||
|
ASSERT_EQ(inputs_before[3].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[3].Shape()[0], 1);
|
||||||
|
ASSERT_EQ(inputs_before[4].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[4].Shape()[0], 1);
|
||||||
|
|
||||||
|
// prepare input
|
||||||
|
std::vector<MSTensor> outputs;
|
||||||
|
std::vector<MSTensor> inputs;
|
||||||
|
{
|
||||||
|
int32_t x = 2, y = 14, z = 1, c2 = 14, c4 = 0;
|
||||||
|
inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
|
||||||
|
sizeof(int32_t));
|
||||||
|
inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
|
||||||
|
sizeof(int32_t));
|
||||||
|
inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &z,
|
||||||
|
sizeof(int32_t));
|
||||||
|
inputs.emplace_back(inputs_before[3].Name(), inputs_before[3].DataType(), inputs_before[3].Shape(), &c2,
|
||||||
|
sizeof(int32_t));
|
||||||
|
inputs.emplace_back(inputs_before[4].Name(), inputs_before[4].DataType(), inputs_before[4].Shape(), &c4,
|
||||||
|
sizeof(int32_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
// infer
|
||||||
|
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
|
||||||
|
|
||||||
|
// assert output
|
||||||
|
ASSERT_TRUE(outputs.size() == 1);
|
||||||
|
auto out = outputs[0];
|
||||||
|
ASSERT_TRUE(out.DataSize() == sizeof(int32_t));
|
||||||
|
auto out_data = out.Data();
|
||||||
|
auto p = reinterpret_cast<const int32_t *>(out_data.get());
|
||||||
|
ASSERT_EQ(*p, 350);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestControl, InferSingleFor) {
|
||||||
|
auto context = ContextAutoSet();
|
||||||
|
|
||||||
|
Graph graph;
|
||||||
|
ASSERT_TRUE(Serialization::Load(kSingleForFile, ModelType::kMindIR, &graph));
|
||||||
|
Model control_model;
|
||||||
|
ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
|
||||||
|
|
||||||
|
// assert inputs
|
||||||
|
std::vector<MSTensor> inputs_before = control_model.GetInputs();
|
||||||
|
ASSERT_EQ(inputs_before.size(), 3);
|
||||||
|
EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32);
|
||||||
|
EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32);
|
||||||
|
EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32);
|
||||||
|
ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t));
|
||||||
|
ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t));
|
||||||
|
ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t));
|
||||||
|
ASSERT_EQ(inputs_before[0].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[0].Shape()[0], 1);
|
||||||
|
ASSERT_EQ(inputs_before[1].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[1].Shape()[0], 1);
|
||||||
|
ASSERT_EQ(inputs_before[2].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[2].Shape()[0], 1);
|
||||||
|
|
||||||
|
// prepare input
|
||||||
|
std::vector<MSTensor> outputs;
|
||||||
|
std::vector<MSTensor> inputs;
|
||||||
|
{
|
||||||
|
int32_t x = 2, y = 5, z = 4;
|
||||||
|
inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
|
||||||
|
sizeof(int32_t));
|
||||||
|
inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
|
||||||
|
sizeof(int32_t));
|
||||||
|
inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &z,
|
||||||
|
sizeof(int32_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
// infer
|
||||||
|
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
|
||||||
|
|
||||||
|
// assert output
|
||||||
|
ASSERT_TRUE(outputs.size() == 1);
|
||||||
|
auto out = outputs[0];
|
||||||
|
ASSERT_TRUE(out.DataSize() == sizeof(int32_t));
|
||||||
|
auto out_data = out.Data();
|
||||||
|
auto p = reinterpret_cast<const int32_t *>(out_data.get());
|
||||||
|
ASSERT_EQ(*p, 125);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestControl, InferSingleOr) {
|
||||||
|
auto context = ContextAutoSet();
|
||||||
|
|
||||||
|
Graph graph;
|
||||||
|
ASSERT_TRUE(Serialization::Load(kSingleOrFile, ModelType::kMindIR, &graph));
|
||||||
|
Model control_model;
|
||||||
|
ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
|
||||||
|
|
||||||
|
// assert inputs
|
||||||
|
std::vector<MSTensor> inputs_before = control_model.GetInputs();
|
||||||
|
ASSERT_EQ(inputs_before.size(), 2);
|
||||||
|
EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
|
||||||
|
EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeFloat32);
|
||||||
|
ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float) * 2);
|
||||||
|
ASSERT_EQ(inputs_before[1].DataSize(), sizeof(float) * 2);
|
||||||
|
ASSERT_EQ(inputs_before[0].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[0].Shape()[0], 2);
|
||||||
|
ASSERT_EQ(inputs_before[1].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[1].Shape()[0], 2);
|
||||||
|
|
||||||
|
// prepare input
|
||||||
|
std::vector<MSTensor> outputs;
|
||||||
|
std::vector<MSTensor> inputs;
|
||||||
|
{
|
||||||
|
static const std::vector<float> input_data1 = {0, 1};
|
||||||
|
static const std::vector<float> input_data2 = {0, 0};
|
||||||
|
inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(),
|
||||||
|
input_data1.data(), sizeof(float) * input_data1.size());
|
||||||
|
inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(),
|
||||||
|
input_data2.data(), sizeof(int32_t) * input_data2.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
// infer
|
||||||
|
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
|
||||||
|
|
||||||
|
// assert output
|
||||||
|
ASSERT_TRUE(outputs.size() == 1);
|
||||||
|
auto out = outputs[0];
|
||||||
|
ASSERT_TRUE(out.DataSize() == sizeof(float));
|
||||||
|
auto out_data = out.Data();
|
||||||
|
auto p = reinterpret_cast<const float *>(out_data.get());
|
||||||
|
ASSERT_EQ(*p, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestControl, InferSingleSwitch) {
|
||||||
|
auto context = ContextAutoSet();
|
||||||
|
|
||||||
|
Graph graph;
|
||||||
|
ASSERT_TRUE(Serialization::Load(kSingleSwitchFile, ModelType::kMindIR, &graph));
|
||||||
|
Model control_model;
|
||||||
|
ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
|
||||||
|
|
||||||
|
// assert inputs
|
||||||
|
std::vector<MSTensor> inputs_before = control_model.GetInputs();
|
||||||
|
ASSERT_EQ(inputs_before.size(), 3);
|
||||||
|
EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
|
||||||
|
EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32);
|
||||||
|
EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32);
|
||||||
|
ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float) * 224 * 224);
|
||||||
|
ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t));
|
||||||
|
ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t));
|
||||||
|
ASSERT_EQ(inputs_before[0].Shape().size(), 4);
|
||||||
|
EXPECT_EQ(inputs_before[0].Shape()[0], 1);
|
||||||
|
EXPECT_EQ(inputs_before[0].Shape()[1], 1);
|
||||||
|
EXPECT_EQ(inputs_before[0].Shape()[2], 224);
|
||||||
|
EXPECT_EQ(inputs_before[0].Shape()[3], 224);
|
||||||
|
ASSERT_EQ(inputs_before[1].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[1].Shape()[0], 1);
|
||||||
|
ASSERT_EQ(inputs_before[2].Shape().size(), 1);
|
||||||
|
EXPECT_EQ(inputs_before[2].Shape()[0], 1);
|
||||||
|
|
||||||
|
// prepare input
|
||||||
|
std::vector<MSTensor> outputs;
|
||||||
|
std::vector<MSTensor> inputs;
|
||||||
|
{
|
||||||
|
static const std::vector<float> input_data1(1 * 1 * 224 * 224, 1);
|
||||||
|
int32_t index1 = 0;
|
||||||
|
int32_t index2 = -1;
|
||||||
|
inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(),
|
||||||
|
input_data1.data(), sizeof(float) * input_data1.size());
|
||||||
|
inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &index1,
|
||||||
|
sizeof(int32_t));
|
||||||
|
inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &index2,
|
||||||
|
sizeof(int32_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
// infer
|
||||||
|
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
|
||||||
|
|
||||||
|
// assert output
|
||||||
|
ASSERT_TRUE(outputs.size() == 1);
|
||||||
|
auto out = outputs[0];
|
||||||
|
ASSERT_TRUE(out.DataSize() == sizeof(float) * 224 * 224);
|
||||||
|
auto out_data = out.Data();
|
||||||
|
auto p = reinterpret_cast<const float *>(out_data.get());
|
||||||
|
for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
|
||||||
|
ASSERT_EQ(p[i], 1);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue