forked from mindspore-Ecosystem/mindspore
!23109 bugfix: get real output of graph and code review
Merge pull request !23109 from zhengyuanhua/master
This commit is contained in:
commit
3232ad7f75
|
@ -21,13 +21,16 @@ if(ENABLE_D OR ENABLE_ACL)
|
|||
include_directories(${CMAKE_BINARY_DIR}/proto/ge)
|
||||
|
||||
file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"model/acl/*.cc"
|
||||
"model/acl/acl_model_options.cc"
|
||||
"model/acl/model_converter.cc"
|
||||
"model/model_converter_utils/*.cc"
|
||||
"graph/acl/*.cc"
|
||||
)
|
||||
|
||||
if(NOT(BUILD_LITE))
|
||||
list(APPEND API_ACL_SRC "akg_kernel_register.cc")
|
||||
list(APPEND API_ACL_SRC "${CMAKE_CURRENT_SOURCE_DIR}/akg_kernel_register.cc"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/model/acl/acl_model_multi.cc"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/model/acl/acl_model.cc")
|
||||
endif()
|
||||
|
||||
if(NOT ENABLE_D)
|
||||
|
@ -50,17 +53,20 @@ endif()
|
|||
set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/context.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cell.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph/graph.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph/graph_data.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/model/model_impl.cc
|
||||
${API_MS_INFER_SRC}
|
||||
${API_ACL_SRC}
|
||||
${API_OPS_SRC}
|
||||
${LOAD_MINDIR_SRC}
|
||||
${MS_UTILS_SRC})
|
||||
|
||||
if(NOT(BUILD_LITE))
|
||||
list(APPEND MSLIB_SRC "${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/model/model_impl.cc")
|
||||
endif()
|
||||
|
||||
add_library(mindspore_shared_lib SHARED ${MSLIB_SRC})
|
||||
if(NOT(BUILD_LITE))
|
||||
set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore)
|
||||
|
|
|
@ -22,8 +22,8 @@
|
|||
namespace mindspore::kernel {
|
||||
namespace acl {
|
||||
typedef struct AclModelOptions {
|
||||
int32_t device_id_;
|
||||
std::string dump_cfg_path_;
|
||||
int32_t device_id;
|
||||
std::string dump_cfg_path;
|
||||
} AclModelOptions;
|
||||
|
||||
} // namespace acl
|
||||
|
|
|
@ -40,7 +40,7 @@ CustomAscend310Kernel::~CustomAscend310Kernel() {
|
|||
|
||||
AclModelOptions CustomAscend310Kernel::GetAclModelOptions(const mindspore::Context *ctx) const {
|
||||
AclModelOptions options;
|
||||
options.device_id_ = 0;
|
||||
options.device_id = 0;
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(WARNING) << "Context is nullptr.";
|
||||
return options;
|
||||
|
@ -61,8 +61,8 @@ AclModelOptions CustomAscend310Kernel::GetAclModelOptions(const mindspore::Conte
|
|||
return options;
|
||||
}
|
||||
|
||||
options.device_id_ = static_cast<int32_t>(ascend31o_info->GetDeviceID());
|
||||
options.dump_cfg_path_ = ascend31o_info->GetDumpConfigPath();
|
||||
options.device_id = static_cast<int32_t>(ascend31o_info->GetDeviceID());
|
||||
options.dump_cfg_path = ascend31o_info->GetDumpConfigPath();
|
||||
return options;
|
||||
}
|
||||
|
||||
|
|
|
@ -36,12 +36,12 @@ STATUS ModelInfer::Init() {
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
acl_env_ = AclEnvGuard::GetAclEnv(options_.dump_cfg_path_);
|
||||
acl_env_ = AclEnvGuard::GetAclEnv(options_.dump_cfg_path);
|
||||
if (acl_env_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Acl init failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
int32_t device_id = options_.device_id_;
|
||||
int32_t device_id = options_.device_id;
|
||||
aclError ret = aclrtSetDevice(device_id);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Acl open device " << device_id << " failed.";
|
||||
|
@ -98,11 +98,11 @@ STATUS ModelInfer::Finalize() {
|
|||
}
|
||||
MS_LOG(INFO) << "End to destroy context.";
|
||||
|
||||
rt_ret = aclrtResetDevice(options_.device_id_);
|
||||
rt_ret = aclrtResetDevice(options_.device_id);
|
||||
if (rt_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Reset device " << options_.device_id_ << " failed.";
|
||||
MS_LOG(ERROR) << "Reset device " << options_.device_id << " failed.";
|
||||
}
|
||||
MS_LOG(INFO) << "End to reset device " << options_.device_id_;
|
||||
MS_LOG(INFO) << "End to reset device " << options_.device_id;
|
||||
init_flag_ = false;
|
||||
load_flag_ = false;
|
||||
return lite::RET_OK;
|
||||
|
|
|
@ -37,7 +37,10 @@ namespace {
|
|||
constexpr auto kMakeTuple = "MakeTuple";
|
||||
constexpr auto kOutputNames = "outputs_names";
|
||||
constexpr auto kCustomPrimTypeACL = "ACL";
|
||||
constexpr auto kCustomNodeName = "Custom_0";
|
||||
constexpr auto kCustomNodeName = "custom_0";
|
||||
constexpr size_t kDependInputNum = 3;
|
||||
constexpr size_t kDependFirstInputIdx = 1;
|
||||
constexpr size_t kTupleGetItemFirstInputIdx = 1;
|
||||
} // namespace
|
||||
|
||||
ParameterPtr AclPass::CreateOmParameter(const FuncGraphPtr &func_graph, const Buffer &om_data) {
|
||||
|
@ -256,58 +259,76 @@ void AclPass::SetAclModelOptions(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(INFO) << "Set acl model options end.";
|
||||
}
|
||||
|
||||
STATUS AclPass::GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph, AnfNodePtrList *graph_outputs,
|
||||
std::vector<std::string> *graph_output_names,
|
||||
std::vector<std::vector<int64_t>> *graph_output_dims) {
|
||||
CHECK_NULL_RETURN(graph_outputs);
|
||||
CHECK_NULL_RETURN(graph_output_names);
|
||||
CHECK_NULL_RETURN(graph_output_dims);
|
||||
AnfNodePtr return_input = func_graph->output();
|
||||
CHECK_NULL_RETURN(return_input);
|
||||
auto input_cnode = return_input->cast<CNodePtr>();
|
||||
CHECK_NULL_RETURN(input_cnode);
|
||||
auto primitive = mindspore::GetValueNode<PrimitivePtr>(input_cnode->input(0));
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "Primitive is nullptr, node: " << input_cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
STATUS AclPass::TraceOutput(const AnfNodePtr &node) {
|
||||
static size_t iter = 0;
|
||||
CHECK_NULL_RETURN(node);
|
||||
AnfNodePtr cur_node = node;
|
||||
AnfNodePtr pre_node = nullptr;
|
||||
while (cur_node->isa<CNode>() && IsPrimitiveCNode(cur_node, prim::kPrimTupleGetItem)) {
|
||||
pre_node = cur_node;
|
||||
auto tmp = cur_node->cast<CNodePtr>();
|
||||
CHECK_NULL_RETURN(tmp);
|
||||
cur_node = tmp->input(kTupleGetItemFirstInputIdx);
|
||||
}
|
||||
// not consider custom op
|
||||
std::string primitive_type = primitive->name();
|
||||
if (primitive_type == kMakeTuple) {
|
||||
for (size_t j = 1; j < input_cnode->inputs().size(); j++) {
|
||||
auto item = input_cnode->input(j);
|
||||
MS_ASSERT(item != nullptr);
|
||||
graph_outputs->emplace_back(item);
|
||||
graph_output_names->emplace_back(item->fullname_with_scope());
|
||||
auto item_cnode = item->cast<CNodePtr>();
|
||||
if (item_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Input of MakeTuple is not a cnode for input_id: " << j;
|
||||
auto cnode = cur_node->cast<CNodePtr>();
|
||||
CHECK_NULL_RETURN(cnode);
|
||||
std::string name = lite::acl::GetCNodeTargetFuncName(cnode);
|
||||
iter++;
|
||||
MS_LOG(INFO) << "Func name of cnode " << name << " ,trace iter: " << iter;
|
||||
if (name == kMakeTuple) {
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
if (TraceOutput(cnode->input(i)) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "The input[ " << i << "]"
|
||||
<< " trace output failed, name: " << name;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
std::vector<int64_t> dims;
|
||||
if (lite::acl::GetShapeVectorFromCNode(item_cnode, &dims) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get node shape failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
graph_output_dims->emplace_back(dims);
|
||||
}
|
||||
} else if (name == prim::kPrimDepend->name()) {
|
||||
if (cnode->inputs().size() < kDependInputNum) {
|
||||
MS_LOG(ERROR) << "Length of inputs is " << cnode->inputs().size() << ", which is less than three.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (TraceOutput(cnode->input(kDependFirstInputIdx)) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Depend node trace output failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
graph_outputs->emplace_back(input_cnode);
|
||||
graph_output_names->emplace_back(input_cnode->fullname_with_scope());
|
||||
MS_LOG(INFO) << "Graph out name: " << cnode->fullname_with_scope();
|
||||
graph_output_names_.emplace_back(cnode->fullname_with_scope());
|
||||
if (pre_node != nullptr && IsPrimitiveCNode(pre_node, prim::kPrimTupleGetItem)) {
|
||||
cnode = pre_node->cast<CNodePtr>();
|
||||
}
|
||||
std::vector<int64_t> dims;
|
||||
if (lite::acl::GetShapeVectorFromCNode(input_cnode, &dims) != lite::RET_OK) {
|
||||
if (lite::acl::GetShapeVectorFromCNode(cnode, &dims) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get node shape failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
graph_output_dims->emplace_back(dims);
|
||||
graph_output_dims_.emplace_back(dims);
|
||||
graph_outputs_.emplace_back(cnode);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS AclPass::GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph) {
|
||||
AnfNodePtr return_input = func_graph->output();
|
||||
CHECK_NULL_RETURN(return_input);
|
||||
if (TraceOutput(return_input) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Trace output failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (graph_outputs_.empty() || graph_outputs_.size() != graph_output_dims_.size()) {
|
||||
MS_LOG(ERROR) << "Graph output size is error, num size: " << graph_outputs_.size()
|
||||
<< " dim size: " << graph_output_dims_.size();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS AclPass::SetMultiOutputs(const CNodePtr &new_cnode, TypeId data_type) {
|
||||
AbstractBasePtrList abstract_list;
|
||||
for (size_t j = 0; j < graph_outputs_.size(); j++) {
|
||||
auto abstract_tensor = lite::CreateTensorAbstract(graph_outputs_dims_[j], data_type);
|
||||
auto abstract_tensor = lite::CreateTensorAbstract(graph_output_dims_[j], data_type);
|
||||
if (abstract_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract tensor is nullptr for output " << j;
|
||||
return lite::RET_ERROR;
|
||||
|
@ -319,21 +340,16 @@ STATUS AclPass::SetMultiOutputs(const CNodePtr &new_cnode, TypeId data_type) {
|
|||
}
|
||||
|
||||
STATUS AclPass::SetCustomOutputs(const FuncGraphPtr &func_graph, const CNodePtr &custom_node) {
|
||||
STATUS ret = GetFuncGraphOutputInfo(func_graph, &graph_outputs_, &graph_output_names_, &graph_outputs_dims_);
|
||||
STATUS ret = GetFuncGraphOutputInfo(func_graph);
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get output info of graph failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (graph_outputs_.empty() || graph_outputs_.size() != graph_outputs_dims_.size()) {
|
||||
MS_LOG(ERROR) << "Graph output size is error, num size: " << graph_outputs_.size()
|
||||
<< " dim size: " << graph_outputs_dims_.size();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
custom_node->AddAttr(kOutputNames, MakeValue(graph_output_names_));
|
||||
|
||||
TypeId type = lite::acl::GetTypeFromNode(graph_outputs_[0]);
|
||||
if (graph_outputs_.size() == 1) {
|
||||
auto abstract_tensor = lite::CreateTensorAbstract(graph_outputs_dims_[0], type);
|
||||
auto abstract_tensor = lite::CreateTensorAbstract(graph_output_dims_[0], type);
|
||||
if (abstract_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract_tensor is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
|
|
|
@ -52,9 +52,8 @@ class AclPass : public Pass {
|
|||
STATUS ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
|
||||
const CNodePtr &custom_node);
|
||||
void SetAclModelOptions(const FuncGraphPtr &func_graph);
|
||||
STATUS GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph, AnfNodePtrList *graph_outputs,
|
||||
std::vector<std::string> *graph_output_names,
|
||||
std::vector<std::vector<int64_t>> *graph_output_dims);
|
||||
STATUS GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph);
|
||||
STATUS TraceOutput(const AnfNodePtr &node);
|
||||
|
||||
FmkType fmk_type_;
|
||||
ParameterPtr om_parameter_ = nullptr;
|
||||
|
@ -62,7 +61,7 @@ class AclPass : public Pass {
|
|||
std::unique_ptr<AclModelOptions> options_;
|
||||
AnfNodePtrList graph_outputs_;
|
||||
std::vector<std::string> graph_output_names_;
|
||||
std::vector<std::vector<int64_t>> graph_outputs_dims_;
|
||||
std::vector<std::vector<int64_t>> graph_output_dims_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -156,6 +156,29 @@ std::vector<int> GetIntParameterData(const ParameterPtr ¶m_ptr) {
|
|||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool IsCaseNode(const CNodePtr node) {
|
||||
if (node->input(0) == nullptr) {
|
||||
MS_LOG(WARNING) << "The input of node is nullptr.";
|
||||
return false;
|
||||
}
|
||||
if (!node->inputs().empty() && node->input(0)->isa<CNode>() &&
|
||||
GetCNodeFuncName(node->input(0)->cast<CNodePtr>()) == "switch_layer") {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string GetCNodeTargetFuncName(const CNodePtr &cnode) {
|
||||
if (IsCaseNode(cnode)) {
|
||||
return string("Case");
|
||||
}
|
||||
auto name = GetCNodeFuncName(cnode);
|
||||
if (name == "switch_layer") {
|
||||
name = "";
|
||||
}
|
||||
return name;
|
||||
}
|
||||
} // namespace acl
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_LITE_TOOLS_CONVERTER_ACL_COMMON_UTILS_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "include/errorcode.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
@ -30,6 +31,8 @@ STATUS GetShapeVectorFromCNode(const mindspore::CNodePtr &cnode, std::vector<int
|
|||
TypeId GetTypeFromNode(const AnfNodePtr &node);
|
||||
|
||||
std::vector<int> GetIntParameterData(const ParameterPtr ¶m_ptr);
|
||||
|
||||
std::string GetCNodeTargetFuncName(const CNodePtr &cnode);
|
||||
} // namespace acl
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -309,6 +309,11 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter:
|
|||
CHECK_NULL_RETURN(optimizer);
|
||||
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
|
||||
CHECK_NULL_RETURN(convert_pm);
|
||||
auto infershape_pass = std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel);
|
||||
CHECK_NULL_RETURN(infershape_pass);
|
||||
convert_pm->AddPass(infershape_pass);
|
||||
auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>();
|
||||
convert_pm->AddPass(update_conv2d_param_pass);
|
||||
convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
|
||||
convert_pm->AddPass(std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel));
|
||||
optimizer->AddPassManager(convert_pm);
|
||||
|
@ -329,11 +334,6 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte
|
|||
if (!config->trainModel) {
|
||||
const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
|
||||
}
|
||||
auto infershape_pass = std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel);
|
||||
CHECK_NULL_RETURN(infershape_pass);
|
||||
const_fold_pm->AddPass(infershape_pass);
|
||||
auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>();
|
||||
const_fold_pm->AddPass(update_conv2d_param_pass);
|
||||
optimizer->AddPassManager(const_fold_pm);
|
||||
if (optimizer->Optimize(old_graph) == nullptr) {
|
||||
MS_LOG(ERROR) << "run const fold failed.";
|
||||
|
|
Loading…
Reference in New Issue