fix tool converter

This commit is contained in:
Fazzie 2021-09-29 11:43:49 +08:00
parent ad88cfcafb
commit 04d45a701c
16 changed files with 300 additions and 141 deletions

View File

@ -55,12 +55,13 @@ constexpr int kMaxDepth = 2048;
std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
std::vector<AnfNodePtr> vecs;
std::vector<AnfNodePtr> vecs{};
if (node == nullptr) {
return vecs;
}
if (node->isa<mindspore::CNode>()) {
auto cnode = node->cast<CNodePtr>();
MS_ASSERT(cnode != nullptr);
auto &inputs = cnode->inputs();
// Check if free variables used.
for (const auto &input : inputs) {
@ -78,7 +79,7 @@ std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
return vecs;
};
std::list<CNodePtr> cnodes;
std::list<CNodePtr> cnodes{};
auto nodes = TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph);
for (const auto &node : nodes) {
auto cnode = dyn_cast<mindspore::CNode>(node);
@ -100,10 +101,7 @@ int AnfExporter::SetPostTrainOutputTensorType(const std::unique_ptr<schema::Meta
first_tensor_output->dataType = kNumberTypeInt8;
} else {
auto primc = primitive->cast<std::shared_ptr<mindspore::ops::QuantDTypeCast>>();
if (primc == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr.";
return RET_ERROR;
}
MS_CHECK_TRUE_MSG(primc != nullptr, RET_ERROR, "cast ptr failed");
if (primc->get_dst_t() != kNumberTypeFloat32) {
first_tensor_output->dataType = kNumberTypeInt8;
}
@ -220,6 +218,7 @@ int AnfExporter::CreateNewTensorForParameter(const std::unique_ptr<schema::MetaG
const AnfNodePtr &input) {
lite::DataInfo data_info;
auto param_node = input->cast<ParameterPtr>();
MS_CHECK_TRUE_MSG(param_node != nullptr, RET_NULL_PTR, "cast ptr failed");
if (FetchFromDefaultParam(param_node, converter::FmkType(meta_graphT->fmkType), &data_info) != RET_OK) {
MS_LOG(ERROR) << "FetchFromDefaultParam failed.";
return RET_ERROR;
@ -242,7 +241,7 @@ int AnfExporter::CreateNewTensorForParameter(const std::unique_ptr<schema::MetaG
int AnfExporter::SetSubGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index) {
auto &subgraph = meta_graphT->subGraph.at(subgraph_index);
FuncGraphPtr fg;
FuncGraphPtr fg = nullptr;
std::for_each(fg_subgraph_map_.begin(), fg_subgraph_map_.end(),
[&subgraph_index, &fg](const std::pair<const FuncGraphPtr, size_t> &it) {
if (it.second == subgraph_index) {
@ -317,6 +316,7 @@ int AnfExporter::ExportPartialNode(const std::unique_ptr<schema::MetaGraphT> &me
const bool &copy_primitive, const CNodePtr &partial_cnode,
const std::unique_ptr<schema::CNodeT> &schema_cnode) {
auto prim = GetValueNode<std::shared_ptr<mindspore::Primitive>>(partial_cnode->input(0));
MS_CHECK_TRUE_MSG(prim != nullptr, RET_NULL_PTR, "GetValueNode failed");
if (prim->name() != mindspore::ops::kNamePartialFusion) {
MS_LOG(INFO) << "not is partial";
return RET_OK;
@ -324,13 +324,10 @@ int AnfExporter::ExportPartialNode(const std::unique_ptr<schema::MetaGraphT> &me
auto partial_fusion_primc = schema_cnode->primitive->value.AsPartialFusion();
auto vnode = partial_cnode->input(kFirstDataIndex)->cast<ValueNodePtr>();
MS_ASSERT(vnode != nullptr);
MS_CHECK_TRUE_MSG(partial_fusion_primc != nullptr, RET_NULL_PTR, "partial_fusion_primc is invalid");
MS_CHECK_TRUE_MSG(vnode != nullptr, RET_NULL_PTR, "vnode is invalid");
auto fg = vnode->value()->cast<FuncGraphPtr>();
if (fg == nullptr) {
MS_LOG(ERROR) << "func graph is nullptr.";
return RET_NULL_PTR;
}
MS_CHECK_TRUE_MSG(fg != nullptr, RET_NULL_PTR, "func graph is nullptr.");
if (fg_subgraph_map_.find(fg) != fg_subgraph_map_.end()) {
partial_fusion_primc->sub_graph_index = fg_subgraph_map_.at(fg);
return RET_OK;
@ -376,10 +373,6 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
for (const auto &cnode : cnodes) {
auto prim = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
std::unique_ptr<schema::PrimitiveT> primT;
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr.";
return RET_ERROR;
}
ret = RemoveIfDepend(cnode);
if (ret != RET_OK) {
@ -513,10 +506,13 @@ FuncGraphPtr GetFinalGraph(const FuncGraphPtr &func_graph, int i) {
auto cnode = call_cnode->input(kFirstDataIndex)->cast<CNodePtr>();
if (opt::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
auto false_cnode = cnode->input(kSwitchFalseIndex)->cast<CNodePtr>();
MS_CHECK_TRUE_MSG(false_cnode != nullptr, nullptr, "cast failed");
auto false_fg = GetValueNode<FuncGraphPtr>(false_cnode->input(kFirstDataIndex));
MS_CHECK_TRUE_MSG(false_fg != nullptr, nullptr, "GetValueNode failed");
return GetFinalGraph(false_fg, i);
} else {
auto fg = GetValueNode<FuncGraphPtr>(cnode->input(kFirstDataIndex));
MS_CHECK_TRUE_MSG(fg != nullptr, nullptr, "GetValueNode failed");
return GetFinalGraph(fg, i);
}
@ -542,10 +538,7 @@ int AnfExporter::SetMetaGraphOutput(const FuncGraphPtr &func_graph,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
int i = 0;
auto final_fg = GetFinalGraph(func_graph, i);
if (final_fg == nullptr) {
MS_LOG(ERROR) << "GetFinalGraph failed.";
return RET_ERROR;
}
MS_CHECK_TRUE_MSG(final_fg != nullptr, RET_ERROR, "GetFinalGraph failed.");
auto final_meta_graph_index = fg_subgraph_map_.at(final_fg);
auto &final_meta_graph = meta_graphT->subGraph.at(final_meta_graph_index);
meta_graphT->outputIndex.assign(final_meta_graph->outputIndices.begin(), final_meta_graph->outputIndices.end());
@ -605,10 +598,7 @@ int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema
}
if (utils::isa<abstract::AbstractTuple>(input_anode->abstract())) {
auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(input_anode->abstract());
if (tuple == nullptr) {
MS_LOG(ERROR) << "tuple is nullptr";
return RET_ERROR;
}
MS_CHECK_TRUE_MSG(tuple != nullptr, RET_ERROR, "tuple is nullptr");
auto elements = tuple->elements();
for (size_t i = 0; i < elements.size(); i++) {
auto key = std::make_pair(input_anode, i);
@ -627,6 +617,7 @@ int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema
int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode) {
auto input_cnode = utils::cast<CNodePtr>(input_anode);
MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_ERROR, "cast ptr failed");
auto input_value_node = input_cnode->input(kPrimIndex)->cast<ValueNodePtr>();
if (input_value_node == nullptr) {
if (!IsCall(input_cnode)) {
@ -635,6 +626,7 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
} else {
auto call_anf_prim_vnode = GetCallAnfPrim();
auto cnode_input = input_cnode->inputs();
MS_CHECK_TRUE_MSG(call_anf_prim_vnode != nullptr, RET_ERROR, "GetCallAnfPrim failed");
cnode_input.insert(cnode_input.begin(), call_anf_prim_vnode);
input_cnode->set_inputs(cnode_input);
}
@ -658,10 +650,7 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
return RET_ERROR;
}
auto value_node = utils::cast<ValueNodePtr>(index_vnode);
if (value_node == nullptr) {
MS_LOG(ERROR) << "cast to ValueNode failed";
return RET_ERROR;
}
MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "cast to ValueNode failed");
auto idx = value_node->value()->type()->number_type() == kNumberTypeInt64 ? GetValue<int64_t>(value_node->value())
: GetValue<int>(value_node->value());
auto key = std::make_pair(get_item_input_cnode, idx);
@ -794,10 +783,7 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<sc
if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
if (tuple == nullptr) {
MS_LOG(ERROR) << "tuple is nullptr";
return RET_ERROR;
}
MS_CHECK_TRUE_MSG(tuple != nullptr, RET_ERROR, "tuple is nullptr");
auto elements = tuple->elements();
for (size_t i = 0; i < lite::GetCNodeOutputsSize(cnode, train_flag_); i++) {
auto ms_tensor = new (std::nothrow) schema::TensorT();
@ -841,6 +827,7 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<sc
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
ms_tensor->dataType = type_ptr->type_id();
meta_graphT->allTensors.emplace_back(ms_tensor);
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
@ -881,8 +868,10 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<sc
CNodePtr AnfExporter::CreateCallCnode(const FuncGraphPtr &fg, const AnfNodePtr &node) {
auto call_anf_prim_vnode = GetCallAnfPrim();
MS_CHECK_TRUE_MSG(call_anf_prim_vnode != nullptr, nullptr, "GetCallAnfPrim failed");
std::vector<AnfNodePtr> inputs{call_anf_prim_vnode, node};
auto cnode = fg->NewCNodeInOrder(inputs);
MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "NewCNode failed");
cnode->set_func_graph(fg);
return cnode;
}
@ -890,19 +879,23 @@ CNodePtr AnfExporter::CreateCallCnode(const FuncGraphPtr &fg, const AnfNodePtr &
CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, const AnfNodePtr &node) {
if (utils::isa<CNodePtr>(node)) {
auto cnode = utils::cast<CNodePtr>(node);
MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "cast ptr failed");
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(kPrimIndex));
if (primitive_c != nullptr) {
return cnode;
}
auto partial_anf_prim_vnode = GetPartialFusionPrim();
auto cnode_input = cnode->inputs();
MS_CHECK_TRUE_MSG(partial_anf_prim_vnode != nullptr, nullptr, "GetPartialFusionPrim failed");
cnode_input.insert(cnode_input.begin(), partial_anf_prim_vnode);
cnode->set_inputs(cnode_input);
return cnode;
} else if (utils::isa<ValueNodePtr>(node)) {
auto partial_anf_prim_vnode = GetPartialFusionPrim();
MS_CHECK_TRUE_MSG(partial_anf_prim_vnode != nullptr, nullptr, "GetPartialFusionPrim failed");
std::vector<AnfNodePtr> inputs{partial_anf_prim_vnode, node};
auto cnode = fg->NewCNode(inputs);
MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "New cnode failed");
return cnode;
} else {
MS_LOG(ERROR) << "failed to create partial cnode.";

View File

@ -91,6 +91,7 @@ STATUS GetDataTypeAndShape(const ParameterPtr &param_node, TypeId *data_type, Sh
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
auto typePtr = abstract_tensor->element()->GetTypeTrack();
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed");
MS_CHECK_TRUE_MSG(typePtr != nullptr, RET_ERROR, "typePtr is nullptr");
*data_type = typePtr->type_id();
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
@ -105,12 +106,14 @@ int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &pri
bool train_flag, DataInfo *data_info) {
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
auto valueAbstract = value_node->abstract();
MS_CHECK_TRUE_MSG(valueAbstract != nullptr, RET_ERROR, "valueAbstract is nullptr");
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
return RET_ERROR;
}
auto typePtr = abstract_tensor->element()->GetTypeTrack();
MS_CHECK_TRUE_MSG(typePtr != nullptr, RET_ERROR, "typePtr is nullptr");
data_info->data_type_ = typePtr->type_id();
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
@ -121,6 +124,7 @@ int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &pri
auto value = value_node->value();
MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
auto data = value->cast<tensor::TensorPtr>();
MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is invalid");
data_info->data_.resize(data->Size());
if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
@ -159,6 +163,7 @@ int FetchFromBoolImmValue(const ValueNodePtr &value_node, const PrimitivePtr &pr
auto value = value_node->value();
MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
auto data = value->cast<mindspore::BoolImmPtr>();
MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is nullptr");
auto data_value = data->value();
if (memcpy_s(data_info->data_.data(), sizeof(bool), &data_value, sizeof(bool)) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
@ -173,6 +178,7 @@ int FetchFromNumberValue(const ValueNodePtr &value_node, const PrimitivePtr &pri
data_info->shape_ = {1};
data_info->data_.resize(sizeof(int));
auto data = value_node->value()->cast<NumberPtr>();
MS_CHECK_TRUE_MSG(data != nullptr, RET_NULL_PTR, "cast NumberPtr failed");
int number_type = data->number_type();
if (TypeToTypeMap.find(number_type) != TypeToTypeMap.end()) {
number_type = TypeToTypeMap.at(number_type);
@ -256,16 +262,13 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
DataInfo *data_info) {
MS_ASSERT(cnode != nullptr && data_info != nullptr);
auto param_node = cnode->input(index)->cast<ParameterPtr>();
if (param_node == nullptr) {
MS_LOG(ERROR) << "input node is not parameter node.";
return RET_ERROR;
}
MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "input node is not parameter node.");
if (FetchFromDefaultParam(param_node, fmk_type, data_info) != RET_OK) {
MS_LOG(ERROR) << "fetch information from default param failed.";
return RET_ERROR;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "GetValueNode failed");
if (prim->GetAttr(ops::kFormat) == nullptr && !param_node->has_default()) {
data_info->format_ = mindspore::NHWC;
}
@ -289,10 +292,8 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
DataInfo *data_info) {
MS_ASSERT(cnode != nullptr && data_info != nullptr);
auto value_node = cnode->input(index)->cast<ValueNodePtr>();
if (value_node == nullptr) {
MS_LOG(ERROR) << "input node is not value node.";
return RET_ERROR;
}
MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "input node is not value node.");
auto value = value_node->value();
int ret = RET_OK;
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
@ -328,6 +329,7 @@ int SetFormatForCnode(const CNodePtr &cnode, size_t index, converter::FmkType fm
DataInfo *data_info) {
data_info->format_ = mindspore::NHWC;
auto input_node_prim = GetValueNode<PrimitivePtr>((cnode->input(index)->cast<CNodePtr>()->input(0)));
MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "GetValueNode failed");
if (input_node_prim->GetAttr(ops::kFormat) != nullptr) {
auto value = input_node_prim->GetAttr(ops::kFormat);
if (value->isa<mindspore::Int64Imm>()) {
@ -367,6 +369,7 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType f
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed");
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
@ -390,6 +393,7 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType f
return RET_ERROR;
}
auto tensor_value = tensor_info->cast<tensor::TensorPtr>();
MS_CHECK_TRUE_MSG(tensor_value != nullptr, RET_ERROR, "cast ptr failed");
if (tensor_value->Size() >= kTensorListMinSize) {
data_info->data_.resize(tensor_value->Size());
if (memcpy_s(data_info->data_.data(), tensor_value->Size(), tensor_value->data_c(), tensor_value->Size()) !=
@ -410,21 +414,21 @@ int RemoveIfDepend(const CNodePtr &cnode) {
inputs.emplace_back(cnode->input(0));
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
AnfNodePtr inputNode = cnode->input(i);
MS_CHECK_TRUE_MSG(inputNode != nullptr, RET_NULL_PTR, "inputNode is nullptr");
if (!inputNode->isa<CNode>()) {
inputs.emplace_back(cnode->input(i));
continue;
}
auto depend_node = utils::cast<CNodePtr>(inputNode);
MS_CHECK_TRUE_MSG(depend_node != nullptr, RET_NULL_PTR, "depend_node is nullptr");
auto value_node = depend_node->input(0)->cast<ValueNodePtr>();
if (value_node == nullptr) {
MS_LOG(ERROR) << "value node is invalid.";
return RET_ERROR;
}
MS_CHECK_TRUE_MSG(value_node != nullptr, RET_NULL_PTR, "value node is invalid.");
if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) {
has_depend = true;
bool mask_out = (depend_node->inputs().size() == opt::kInputSizeThree);
for (size_t j = 1; j < depend_node->inputs().size(); ++j) {
AnfNodePtr depend_input_node = depend_node->input(j);
MS_CHECK_TRUE_MSG(depend_input_node != nullptr, RET_NULL_PTR, "depend_input_node is nullptr");
if (depend_input_node->isa<CNode>()) {
inputs.emplace_back(depend_input_node);
if (mask_out) {
@ -450,16 +454,15 @@ int RemoveIfMakeTuple(const CNodePtr &cnode) {
inputs.emplace_back(cnode->input(0));
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
AnfNodePtr input_node = cnode->input(i);
MS_CHECK_TRUE_MSG(input_node != nullptr, RET_NULL_PTR, "input_node is nullptr");
if (!input_node->isa<CNode>()) {
inputs.emplace_back(cnode->input(i));
continue;
}
auto make_tuple_node = utils::cast<CNodePtr>(input_node);
MS_CHECK_TRUE_MSG(make_tuple_node != nullptr, RET_NULL_PTR, "make_tuple_node is nullptr");
auto value_node = make_tuple_node->input(0)->cast<ValueNodePtr>();
if (value_node == nullptr) {
MS_LOG(ERROR) << "value node is invalid.";
return RET_ERROR;
}
MS_CHECK_TRUE_MSG(value_node != nullptr, RET_NULL_PTR, "value node is invalid.");
if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) ||
opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) {
has_make_tuple = true;

View File

@ -30,7 +30,7 @@ Option<std::string> FlagParser::ParseFlags(int argc, const char *const *argv, bo
}
binName = GetFileName(argv[0]);
std::multimap<std::string, Option<std::string>> keyValues;
std::multimap<std::string, Option<std::string>> keyValues{};
for (int i = 1; i < argc; i++) {
std::string tmp = argv[i];
Trim(&tmp);

View File

@ -20,7 +20,6 @@
#include <vector>
#include <map>
#include <queue>
#include <utility>
#include "src/common/log_adapter.h"
#include "tools/common/node_util.h"
#include "tools/common/graph_util.h"
@ -29,18 +28,43 @@
#include "nnacl/op_base.h"
namespace mindspore::lite {
SubGraph::SubGraph(FuncGraphPtr belong_anf, std::string graph_name, const std::set<CNodePtr> &head_nodes)
: belong_anf_(std::move(belong_anf)), name_(std::move(graph_name)) {
InitSubGraphNode(head_nodes);
InitSubGraphInNode();
InitSubGraphOutNode();
int SubGraph::Init(const std::set<CNodePtr> &head_nodes) {
auto ret = InitSubGraphNode(head_nodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "InitSubGraphNode failed";
return RET_ERROR;
}
ret = InitSubGraphInNode();
if (ret != RET_OK) {
MS_LOG(ERROR) << "InitSubGraphInNode failed";
return RET_ERROR;
}
ret = InitSubGraphOutNode();
if (ret != RET_OK) {
MS_LOG(ERROR) << "InitSubGraphOutNode failed";
return RET_ERROR;
}
return RET_OK;
}
void SubGraph::Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes) {
int SubGraph::Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes) {
this->nodes_ = nodes;
InitSubGraphNode(head_nodes);
InitSubGraphInNode();
InitSubGraphOutNode();
auto ret = InitSubGraphNode(head_nodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "InitSubGraphNode failed";
return RET_ERROR;
}
ret = InitSubGraphInNode();
if (ret != RET_OK) {
MS_LOG(ERROR) << "InitSubGraphInNode failed";
return RET_ERROR;
}
ret = InitSubGraphOutNode();
if (ret != RET_OK) {
MS_LOG(ERROR) << "InitSubGraphOutNode failed";
return RET_ERROR;
}
return RET_OK;
}
std::set<CNodePtr> SubGraph::GetNodes() const { return this->nodes_; }
@ -125,11 +149,11 @@ std::set<CNodePtr> SubGraph::GetOutputCNodes() const {
return outputs;
}
void SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
int SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
MS_ASSERT(belong_anf_ != nullptr);
MS_ASSERT(belong_anf_->manager() != nullptr);
auto node_users = belong_anf_->manager()->node_users();
std::queue<CNodePtr> q;
std::queue<CNodePtr> q{};
for (const auto &head_node : head_nodes) {
if (head_node == nullptr) {
continue;
@ -138,7 +162,7 @@ void SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
}
while (!q.empty()) {
auto cur_node = q.front();
MS_CHECK_PTR_IF_NULL(cur_node);
MS_CHECK_TRUE_MSG(cur_node != nullptr, RET_NULL_PTR, "cur_node is nullptr");
q.pop();
this->nodes_.insert(cur_node);
// check output-cnode of cur-node only depend on cur-node
@ -153,11 +177,12 @@ void SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
continue;
}
auto post_cnode = utils::cast<CNodePtr>(post_node);
MS_CHECK_TRUE_MSG(post_cnode != nullptr, RET_NULL_PTR, "cast failed");
// return-node should not be include into subgraph absolutely // ut
if (opt::CheckPrimitiveType(post_cnode, prim::kPrimReturn)) {
continue;
}
MS_CHECK_PTR_IF_NULL(post_cnode);
MS_CHECK_TRUE_MSG(post_cnode != nullptr, RET_NULL_PTR, "post_cnode is nullptr");
bool non_depend = true;
// check all inputs of output-cnode
for (const auto &input : post_cnode->inputs()) {
@ -167,6 +192,7 @@ void SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
// input cnode is not contained in subgraph
if (utils::isa<CNodePtr>(input)) {
auto input_cnode = utils::cast<CNodePtr>(input);
MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_NULL_PTR, "cast ptr failed");
if (this->nodes_.count(input_cnode) == 0) {
non_depend = false;
break;
@ -175,6 +201,7 @@ void SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
// input parameter is a graph input
if (utils::isa<ParameterPtr>(input)) {
auto input_parameter = utils::cast<ParameterPtr>(input);
MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_NULL_PTR, "cast failed");
if (!input_parameter->has_default()) {
non_depend = false;
break;
@ -186,9 +213,10 @@ void SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
}
}
}
return RET_OK;
}
void SubGraph::InitSubGraphInNode() {
int SubGraph::InitSubGraphInNode() {
MS_ASSERT(belong_anf_ != nullptr);
MS_ASSERT(belong_anf_->manager() != nullptr);
auto node_users = belong_anf_->manager()->node_users();
@ -203,6 +231,7 @@ void SubGraph::InitSubGraphInNode() {
}
if (utils::isa<CNodePtr>(input)) {
auto input_cnode = utils::cast<CNodePtr>(input);
MS_CHECK_TRUE_MSG(input_cnode != nullptr, false, "cast failed");
if (this->nodes_.count(input_cnode) == 0) {
return true;
}
@ -210,6 +239,7 @@ void SubGraph::InitSubGraphInNode() {
// graph input or shared weight input // ut
if (utils::isa<ParameterPtr>(input)) {
auto input_parameter = utils::cast<ParameterPtr>(input);
MS_CHECK_TRUE_MSG(input_parameter != nullptr, false, "cast failed");
if (!input_parameter->has_default()) {
return true;
}
@ -223,9 +253,10 @@ void SubGraph::InitSubGraphInNode() {
in_nodes_.insert(node);
}
}
return RET_OK;
}
void SubGraph::InitSubGraphOutNode() {
int SubGraph::InitSubGraphOutNode() {
MS_ASSERT(belong_anf_ != nullptr);
MS_ASSERT(belong_anf_->manager() != nullptr);
auto node_users = belong_anf_->manager()->node_users();
@ -250,6 +281,7 @@ void SubGraph::InitSubGraphOutNode() {
return true;
}
auto output_cnode = utils::cast<CNodePtr>(output_node);
MS_CHECK_TRUE_MSG(output_cnode != nullptr, false, "cast failed");
if (this->nodes_.count(output_cnode) == 0) {
return true;
}
@ -258,6 +290,7 @@ void SubGraph::InitSubGraphOutNode() {
continue;
out_nodes_.insert(node);
}
return RET_OK;
}
bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) {
@ -272,7 +305,10 @@ bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) {
auto new_nodes2 = subgraph->GetNodes();
new_nodes.insert(new_nodes2.begin(), new_nodes2.end());
new_nodes.insert(common_outputs.begin(), common_outputs.end());
this->Reset(new_nodes, common_outputs);
if (this->Reset(new_nodes, common_outputs) != RET_OK) {
MS_LOG(ERROR) << "Reset failed";
return false;
}
return true;
}
@ -280,7 +316,10 @@ bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) {
auto new_nodes = this->GetNodes();
auto new_nodes2 = subgraph->GetNodes();
new_nodes.insert(new_nodes2.begin(), new_nodes2.end());
this->Reset(new_nodes);
if (this->Reset(new_nodes) != RET_OK) {
MS_LOG(ERROR) << "Reset failed";
return false;
}
return true;
}
return false;
@ -290,8 +329,8 @@ bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) {
SubGraphPtr SubGraph::FindBeforeSubGraphInBelongAnf() const {
MS_ASSERT(belong_anf_ == nullptr);
// find before subgraph's nodes
std::queue<CNodePtr> q;
std::set<CNodePtr> before_nodes;
std::queue<CNodePtr> q{};
std::set<CNodePtr> before_nodes{};
for (const auto &node : this->GetInCNodes()) {
for (const auto &in_cnode : lite::GetInputCNode(node)) {
if (in_cnode == nullptr) {
@ -311,7 +350,11 @@ SubGraphPtr SubGraph::FindBeforeSubGraphInBelongAnf() const {
}
// construct before subgraph
auto before_subgraph = std::make_shared<SubGraph>(belong_anf_, this->name_ + "/before_subgraph");
before_subgraph->Reset(before_nodes);
MS_CHECK_TRUE_MSG(before_subgraph != nullptr, nullptr, "before_subgraph is nullptr");
if (before_subgraph->Reset(before_nodes) != RET_OK) {
MS_LOG(ERROR) << "Reset failed";
return nullptr;
}
return before_subgraph;
}
@ -325,8 +368,8 @@ SubGraphPtr SubGraph::FindAfterSubGraphInBelongAnf() const {
return nullptr;
}
// find after subgraph's nodes
std::queue<CNodePtr> q;
std::set<CNodePtr> after_nodes;
std::queue<CNodePtr> q{};
std::set<CNodePtr> after_nodes{};
auto output_node = belong_anf_->output();
if (!utils::isa<CNodePtr>(output_node)) {
MS_LOG(ERROR) << "Output node of anf should be a cnode: " << output_node->fullname_with_scope();
@ -348,7 +391,12 @@ SubGraphPtr SubGraph::FindAfterSubGraphInBelongAnf() const {
}
// construct before subgraph
auto after_subgraph = std::make_shared<SubGraph>(belong_anf_, this->name_ + "/after_subgraph");
after_subgraph->Reset(after_nodes);
MS_CHECK_TRUE_MSG(after_subgraph != nullptr, nullptr, "after_subgraph is nullptr");
if (after_subgraph->Reset(after_nodes) != RET_OK) {
MS_LOG(ERROR) << "Reset failed";
return nullptr;
}
return after_subgraph;
}
@ -366,6 +414,7 @@ int SubGraph::CreatePartialInBelongAnf() {
}
// create func_graph of partial
FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_NULL_PTR, "func_graph is nullptr");
auto manager = belong_anf_->manager();
manager->AddFuncGraph(func_graph);
func_graph->set_attr("graph_name", MakeValue(graph_name));
@ -373,12 +422,20 @@ int SubGraph::CreatePartialInBelongAnf() {
// create cnode and parameter for func_graph of partial
std::vector<AnfNodePtr> partial_inputs;
std::map<AnfNodePtr, AnfNodePtr> partial_inputs_and_subgraph_input_map;
CreateParameterForPartialSubGraph(func_graph, &partial_inputs, &partial_inputs_and_subgraph_input_map);
CreateCNodeForPartialSubGraph(func_graph, partial_inputs_and_subgraph_input_map);
auto ret = CreateParameterForPartialSubGraph(func_graph, &partial_inputs, &partial_inputs_and_subgraph_input_map);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "CreateParameterForPartialSubGraph failed";
return ret;
}
ret = CreateCNodeForPartialSubGraph(func_graph, partial_inputs_and_subgraph_input_map);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "CreateCNodeForPartialSubGraph failed";
return ret;
}
// add return for func_graph of partial
auto sub_graph_outputs = this->GetOutCNodes();
MS_ASSERT(!sub_graph_outputs.empty());
auto ret = SetFuncGraphOutput(func_graph, sub_graph_outputs);
ret = SetFuncGraphOutput(func_graph, sub_graph_outputs);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "Set subgraph output failed";
return ret;
@ -386,8 +443,11 @@ int SubGraph::CreatePartialInBelongAnf() {
// create partial cnode
auto partial_prim = std::make_shared<mindspore::ops::PartialFusion>();
auto graph_value_node = NewValueNode(func_graph);
MS_CHECK_TRUE_MSG(partial_prim != nullptr, RET_NULL_PTR, "partial_prim is nullptr");
MS_CHECK_TRUE_MSG(graph_value_node != nullptr, RET_NULL_PTR, "graph_value_node is nullptr");
partial_inputs.insert(partial_inputs.begin(), graph_value_node);
auto partial_cnode = belong_anf_->NewCNode(partial_prim, partial_inputs);
MS_CHECK_TRUE_MSG(partial_cnode != nullptr, RET_NULL_PTR, "partial_cnode is nullptr");
partial_cnode->set_fullname_with_scope(graph_name + "/partial");
for (size_t i = 0; i < partial_inputs.size(); ++i) {
const auto &input = partial_inputs.at(i);
@ -396,6 +456,7 @@ int SubGraph::CreatePartialInBelongAnf() {
// create call cnode
std::vector<AnfNodePtr> call_node_inputs{partial_cnode};
auto call_cnode = belong_anf_->NewCNode(call_node_inputs);
MS_CHECK_TRUE_MSG(call_cnode != nullptr, RET_NULL_PTR, "call_cnode is nullptr");
call_cnode->set_fullname_with_scope(graph_name + "/call");
// replace belong-graph's output
auto return_node = belong_anf_->get_return();
@ -412,7 +473,7 @@ int SubGraph::SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set<CNode
return lite::SetFuncGraphOutput(graph, output_nodes);
}
void SubGraph::CreateParameterForPartialSubGraph(
int SubGraph::CreateParameterForPartialSubGraph(
const FuncGraphPtr &sub_graph, std::vector<AnfNodePtr> *partial_inputs,
std::map<AnfNodePtr, AnfNodePtr> *partial_inputs_and_subgraph_input_map) {
MS_ASSERT(sub_graph != nullptr);
@ -436,6 +497,7 @@ void SubGraph::CreateParameterForPartialSubGraph(
// create subgraph input parameter from cnode and record partial inputs
if (utils::isa<CNodePtr>(input)) {
auto input_cnode = utils::cast<CNodePtr>(input);
MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_NULL_PTR, "cast ptr failed");
if (this->GetNodes().count(input_cnode) > 0) {
continue;
}
@ -450,6 +512,7 @@ void SubGraph::CreateParameterForPartialSubGraph(
auto node_users = this->belong_anf_->manager()->node_users();
if (utils::isa<ParameterPtr>(input)) {
auto parameter = utils::cast<ParameterPtr>(input);
MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "cast ptr failed");
// graph input: create a parameter
if (!parameter->has_default()) {
auto new_parameter = sub_graph->add_parameter();
@ -475,9 +538,10 @@ void SubGraph::CreateParameterForPartialSubGraph(
}
}
}
return RET_OK;
}
void SubGraph::CreateCNodeForPartialSubGraph(
int SubGraph::CreateCNodeForPartialSubGraph(
const FuncGraphPtr &sub_graph, const std::map<AnfNodePtr, AnfNodePtr> &partial_inputs_and_subgraph_input_map) {
MS_ASSERT(sub_graph != nullptr);
// move cnode from belong_graph to subgraph
@ -500,6 +564,7 @@ void SubGraph::CreateCNodeForPartialSubGraph(
}
this->belong_anf_->DropNode(node);
}
return RET_OK;
}
int SubGraph::ApplySubGraph() {
@ -546,7 +611,10 @@ int SubGraph::ApplySubGraph() {
MS_ASSERT(after_partial_cnode != nullptr);
subgraph_nodes.insert(after_partial_cnode);
subgraph_nodes.insert(call_cnode);
this->Reset(subgraph_nodes);
if (this->Reset(subgraph_nodes) != RET_OK) {
MS_LOG(ERROR) << "Reset failed";
return false;
}
// create subgraph partial // add partial to main subgraph
ret = this->CreatePartialInBelongAnf();
if (ret != RET_OK) {

View File

@ -22,6 +22,7 @@
#include <vector>
#include <map>
#include <set>
#include <utility>
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "ir/anf.h"
@ -32,9 +33,11 @@ class SubGraph;
using SubGraphPtr = std::shared_ptr<SubGraph>;
class SubGraph {
public:
explicit SubGraph(FuncGraphPtr belong_anf, std::string graph_name = "", const std::set<CNodePtr> &head_nodes = {});
explicit SubGraph(FuncGraphPtr belong_anf, std::string graph_name = "")
: belong_anf_(std::move(belong_anf)), name_(std::move(graph_name)) {}
void Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes = {});
int Init(const std::set<CNodePtr> &head_nodes = {});
int Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes = {});
bool MergeSubGraph(const SubGraphPtr &subgraph);
@ -48,19 +51,19 @@ class SubGraph {
std::set<CNodePtr> GetInputCNodes() const;
std::set<CNodePtr> GetOutputCNodes() const;
// init subgraph methods
void InitSubGraphNode(const std::set<CNodePtr> &head_nodes);
void InitSubGraphInNode();
void InitSubGraphOutNode();
int InitSubGraphNode(const std::set<CNodePtr> &head_nodes);
int InitSubGraphInNode();
int InitSubGraphOutNode();
// merge subgraph methods
std::set<CNodePtr> FindCommonOutputs(const SubGraphPtr &subgraph) const;
bool IfDependOnSameNode(const SubGraphPtr &subgraph) const;
// apply subgraph methods
SubGraphPtr FindBeforeSubGraphInBelongAnf() const;
SubGraphPtr FindAfterSubGraphInBelongAnf() const;
void CreateParameterForPartialSubGraph(const FuncGraphPtr &sub_graph, std::vector<AnfNodePtr> *partial_inputs,
std::map<AnfNodePtr, AnfNodePtr> *partial_inputs_and_subgraph_input_map);
void CreateCNodeForPartialSubGraph(const FuncGraphPtr &sub_graph,
const std::map<AnfNodePtr, AnfNodePtr> &partial_inputs_and_subgraph_input_map);
int CreateParameterForPartialSubGraph(const FuncGraphPtr &sub_graph, std::vector<AnfNodePtr> *partial_inputs,
std::map<AnfNodePtr, AnfNodePtr> *partial_inputs_and_subgraph_input_map);
int CreateCNodeForPartialSubGraph(const FuncGraphPtr &sub_graph,
const std::map<AnfNodePtr, AnfNodePtr> &partial_inputs_and_subgraph_input_map);
int CreatePartialInBelongAnf();
static int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set<CNodePtr> &outputs);

View File

@ -27,6 +27,7 @@
#include "mindspore/core/ops/switch.h"
#include "mindspore/core/ops/call.h"
#include "mindspore/core/ops/fusion/partial_fusion.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
@ -332,6 +333,7 @@ STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag) {
MS_ASSERT(anf_node != nullptr);
auto cnode = anf_node->cast<CNodePtr>();
MS_ASSERT(cnode != nullptr);
if (train_flag &&
(opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimAdam))) {
return 1;
@ -350,6 +352,7 @@ bool IsPartialFusion(const AnfNodePtr &node) {
}
if (node->isa<mindspore::CNode>()) {
auto cnode = node->cast<CNodePtr>();
MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cast ptr failed");
auto vnode_value = cnode->input(0)->cast<ValueNodePtr>()->value();
return GetValue<NamedPtr>(vnode_value)->name() == "PartialFusion";
}
@ -364,6 +367,7 @@ bool IsCall(const AnfNodePtr &node) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cast ptr failed");
if (cnode->inputs().empty()) {
return false;
}
@ -400,19 +404,25 @@ bool IsMakeTuple(const AnfNodePtr &node) {
ValueNodePtr GetPartialFusionPrim() {
auto partial_prim = std::make_shared<mindspore::ops::PartialFusion>();
MS_CHECK_TRUE_MSG(partial_prim != nullptr, nullptr, "partial_prim is nullptr");
ValueNodePtr partial_anf_prim = NewValueNode(partial_prim);
MS_CHECK_TRUE_MSG(partial_anf_prim != nullptr, nullptr, "partial_anf_prim is nullptr");
return partial_anf_prim;
}
ValueNodePtr GetSwitchAnfPrim() {
auto switch_prim = std::make_shared<mindspore::ops::Switch>();
MS_CHECK_TRUE_MSG(switch_prim != nullptr, nullptr, "switch_prim is nullptr");
ValueNodePtr switch_anf_prim = NewValueNode(switch_prim);
MS_CHECK_TRUE_MSG(switch_prim != nullptr, nullptr, "switch_prim is nullptr");
return switch_anf_prim;
}
ValueNodePtr GetCallAnfPrim() {
auto call_prim = std::make_shared<mindspore::ops::Call>();
MS_CHECK_TRUE_MSG(call_prim != nullptr, nullptr, "call_prim is nullptr");
ValueNodePtr call_anf_prim = NewValueNode(call_prim);
MS_CHECK_TRUE_MSG(call_anf_prim != nullptr, nullptr, "call_anf_prim is nullptr");
return call_anf_prim;
}
} // namespace lite

View File

@ -19,6 +19,7 @@
#include "mindspore/lite/tools/common/string_util.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
namespace mindspore {
namespace lite {
int ReadFileToIfstream(const std::string &file_path, std::ifstream *ifstream) {

View File

@ -18,6 +18,7 @@
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "abstract/utils.h"
#include "nnacl/op_base.h"
namespace mindspore::lite {
std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT> &tensor) {
@ -131,6 +132,7 @@ std::unique_ptr<schema::TensorT> CreateTensorTFromTensorInfo(const tensor::Tenso
return nullptr;
}
auto schema_tensor = std::make_unique<schema::TensorT>();
MS_CHECK_TRUE_MSG(schema_tensor != nullptr, nullptr, "schema_tensor is nullptr");
schema_tensor->name = tensor_name;
auto ret = UpdateTensorTFromTensorInfo(tensor_info, &schema_tensor);
if (ret != RET_OK) {

View File

@ -37,6 +37,7 @@ static bool IsSupportedNode(const BaseRef &n) {
};
if (utils::isa<AnfNodePtr>(n)) {
auto anf_node = utils::cast<AnfNodePtr>(n);
MS_ASSERT(anf_node != nullptr);
return std::any_of(support_list.begin(), support_list.end(),
[&anf_node](const auto &primitive) { return CheckPrimitiveType(anf_node, primitive); });
}
@ -95,6 +96,7 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens
}
auto new_return_node = func_graph->NewCNode({NewValueNode(return_prim_ptr), make_tuple_cnode});
new_return_node->set_fullname_with_scope(return_cnode->fullname_with_scope());
MS_ASSERT(new_return_node != nullptr);
func_graph->set_return(new_return_node);
MS_ASSERT(new_return_node != nullptr);
@ -103,7 +105,9 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens
const BaseRef AddTensorArray::DefinePattern() const {
auto support_detect = std::make_shared<CondVar>(IsSupportedNode);
MS_ASSERT(support_detect != nullptr);
auto inputs_var = std::make_shared<SeqVar>();
MS_ASSERT(inputs_var != nullptr);
return VectorRef({support_detect, inputs_var});
}
@ -137,6 +141,7 @@ const AnfNodePtr AddTensorArray::Process(const FuncGraphPtr &func_graph, const A
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
MS_ASSERT(abstract_tensor != nullptr);
if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) { // input node not complete infershape
MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed";
return nullptr;
@ -168,6 +173,7 @@ const AnfNodePtr AddTensorArray::Process(const FuncGraphPtr &func_graph, const A
NewValueNode(tensor_array),
NewValueNode(kDefaultNumTensors),
});
MS_ASSERT(tensor_array_node != nullptr);
tensor_array_node->set_abstract(abstract->Clone());
tensor_array_node->set_fullname_with_scope(cnode->fullname_with_scope() + "_tensor_array");

View File

@ -168,6 +168,7 @@ int ControlFlowPass::SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow
}
visited_nodes->insert(node);
auto cnode = utils::cast<CNodePtr>(node);
MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cast ptr failed");
for (size_t i = 0; i < cnode->inputs().size(); i++) {
auto input = cnode->input(i);
if (visited_nodes->find(input) == visited_nodes->end()) {
@ -191,6 +192,7 @@ int ControlFlowPass::SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow
int ControlFlowPass::CreateAfterGraph(const FuncGraphPtr &main_fg, const std::vector<AnfNodePtr> &remain_nodes,
const CNodePtr &aim_cnode, FuncGraphPtr *after_fg) {
*after_fg = std::make_shared<FuncGraph>();
MS_CHECK_TRUE_MSG(*after_fg != nullptr, lite::RET_NULL_PTR, "*after_fg is nullptr");
auto manager = main_fg->manager();
manager->AddFuncGraph(*after_fg);
(*after_fg)->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
@ -282,6 +284,7 @@ int ControlFlowPass::CreateWhileCondCallNode(
int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, const CNodePtr &while_cnode,
CNodePtr *body_partial_node) {
auto body_vnode = while_cnode->input(kWhileBodyIndex);
MS_CHECK_TRUE_MSG(body_vnode != nullptr, lite::RET_NULL_PTR, "body_vnode is nullptr");
auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode);
if (body_fg == nullptr) {
MS_LOG(ERROR) << "Get value as func_graph failed.";
@ -469,13 +472,14 @@ int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNo
ValueNodePtr switch_anf_primitive = lite::GetSwitchAnfPrim();
if (switch_anf_primitive == nullptr) {
MS_LOG(ERROR) << "GetSwitchAnfPrim failed.";
return false;
return lite::RET_ERROR;
}
// insert switch node
std::vector<AnfNodePtr> switch_node_inputs = {switch_anf_primitive, cond_fg->output(), body_partial_node,
after_partial_cnode};
auto switch_cnode = cond_fg->NewCNode(switch_node_inputs);
MS_CHECK_TRUE_MSG(switch_cnode != nullptr, RET_ERROR, "NewCnode failed");
switch_cnode->set_fullname_with_scope("while-Switch-" + cond_fg->get_attr("graph_name")->ToString());
// insert call node
@ -590,6 +594,7 @@ int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &i
after_partial_cnode_inputs.push_back(then_fg->output());
} else {
auto then_fg_output = then_fg->output()->cast<CNodePtr>();
MS_CHECK_TRUE_MSG(then_fg_output != nullptr, RET_ERROR, "cast ptr failed");
for (size_t i = kCNodeFirstInputIndex; i < then_fg_output->inputs().size(); ++i) {
after_partial_cnode_inputs.push_back(then_fg_output->input(i));
}

View File

@ -84,6 +84,7 @@ STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNode
return lite::RET_ERROR;
}
auto cur_node_post = cur_node_user.first->cast<CNodePtr>();
MS_CHECK_TRUE_MSG(cur_node_post != nullptr, RET_ERROR, "cast ptr failed");
if (middle_nodes->find(cur_node_post) != middle_nodes->end() ||
out_nodes->find(cur_node_post) != out_nodes->end()) {
continue;
@ -138,6 +139,7 @@ bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<
continue;
}
auto middle_node_prim = GetValueNode<PrimitivePtr>(middle_cnode->input(0));
MS_CHECK_TRUE_MSG(middle_node_prim != nullptr, false, "GetValueNode failed");
if (!transpose_strategy.CanChangeOpAxis(func_graph, middle_cnode)) {
return false;
}
@ -145,30 +147,30 @@ bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<
return true;
}
void ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type,
bool train_flag, FormatTransNodeType trans_type) {
int ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type,
bool train_flag, FormatTransNodeType trans_type) {
MS_ASSERT(cnode != nullptr);
if (utils::isa<CNodePtr>(cnode->input(index))) {
return;
return lite::RET_OK;
}
lite::DataInfo data_info;
int status;
int status = 0;
if (utils::isa<ParameterPtr>(cnode->input(index))) {
auto input_node = cnode->input(index)->cast<ParameterPtr>();
MS_ASSERT(input_node != nullptr);
MS_CHECK_TRUE_MSG(input_node != nullptr, lite::RET_ERROR, "input_node is nullptr");
if (!input_node->has_default()) {
return;
return lite::RET_OK;
}
status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info);
} else {
status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info);
}
if (status != lite::RET_OK) {
return;
return lite::RET_ERROR;
}
if (data_info.shape_.empty() ||
(data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
return;
return lite::RET_OK;
}
ShapeVector expand_shape(data_info.shape_.begin(), data_info.shape_.end());
if (data_info.shape_.size() == 1) {
@ -180,22 +182,24 @@ void ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode
}
auto tensor = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), expand_shape,
data_info.data_.data(), data_info.data_.size());
MS_CHECK_TRUE_MSG(tensor != nullptr, lite::RET_ERROR, "tensor is nullptr");
if (trans_type == kNHWC2NCHW) {
(void)TransFilterFormat(tensor, schema::Format_KHWC, schema::Format_KCHW);
} else {
(void)TransFilterFormat(tensor, schema::Format_KCHW, schema::Format_KHWC);
}
auto param_node = func_graph->add_parameter();
MS_CHECK_TRUE_MSG(param_node != nullptr, , "add_parameter failed");
MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "add_parameter failed");
param_node->set_name(cnode->input(index)->fullname_with_scope());
status = lite::InitParameterFromTensorInfo(param_node, tensor);
if (status != RET_OK) {
MS_LOG(ERROR) << "init parameter from tensor info failed";
return;
return lite::RET_ERROR;
}
auto tr = func_graph->manager()->Transact();
tr.SetEdge(cnode, index, param_node);
tr.Commit();
return lite::RET_OK;
}
} // namespace
@ -273,6 +277,7 @@ STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph,
const std::vector<int> &perm) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto prim_node = cnode->input(0);
MS_CHECK_TRUE_MSG(prim_node != nullptr, lite::RET_ERROR, "prim_node is nullptr");
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "GetValueNode Failed");
auto &specify_nhwc_op_map = GetNHWCOpMap();
@ -346,7 +351,11 @@ STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph,
return lite::RET_ERROR;
}
}
ModifyCNodeFormat(cnode, trans_insert_info->pre_);
status = ModifyCNodeFormat(cnode, trans_insert_info->pre_);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
return lite::RET_ERROR;
}
status = node_infer_shape_.InferShape(cnode);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
MS_LOG(ERROR) << "infer shape failed.";
@ -439,14 +448,22 @@ STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_grap
continue;
}
for (size_t i = 1; i < middle_cnode->size(); ++i) {
ConvertTensorToNCOrNH(func_graph, middle_cnode, i, fmk_type_, train_flag_, trans_info.post_);
status = ConvertTensorToNCOrNH(func_graph, middle_cnode, i, fmk_type_, train_flag_, trans_info.post_);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "ConvertTensorToNCOrNH failed.";
return lite::RET_ERROR;
}
}
status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode, trans_info.post_);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "change op attr failed.";
return lite::RET_ERROR;
}
ModifyCNodeFormat(middle_cnode, trans_info.post_);
status = ModifyCNodeFormat(middle_cnode, trans_info.post_);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
return lite::RET_ERROR;
}
status = node_infer_shape_.InferShape(middle_cnode);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
MS_LOG(ERROR) << "infer shape failed.";
@ -456,7 +473,7 @@ STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_grap
return lite::RET_OK;
}
void DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto sub_inputs = sub_graph->get_inputs();
sub_inputs_map_[sub_graph] = sub_inputs;
@ -475,6 +492,7 @@ void DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGr
MS_ASSERT(out_cnode != nullptr);
MS_ASSERT(trans_cnode != nullptr);
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
MS_CHECK_TRUE_MSG(out_prim != nullptr, lite::RET_ERROR, "GetValueNode failed");
if (out_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(kInferDone))) {
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
}
@ -499,9 +517,10 @@ void DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGr
}
}
}
return lite::RET_OK;
}
void DecreaseTransposeAlgo::ResetSubGraphInput() {
int DecreaseTransposeAlgo::ResetSubGraphInput() {
for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
auto &sub_graph = iter->first;
auto &sub_inputs = iter->second;
@ -509,7 +528,7 @@ void DecreaseTransposeAlgo::ResetSubGraphInput() {
MS_ASSERT(manager != nullptr);
for (auto &sub_input : sub_inputs) {
auto param_node = sub_graph->add_parameter();
MS_CHECK_TRUE_MSG(param_node != nullptr, , "add parameter failed");
MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "add parameter failed");
param_node->set_abstract(sub_input->abstract()->Clone());
param_node->set_name(sub_input->fullname_with_scope());
manager->Replace(sub_input, param_node);
@ -518,9 +537,10 @@ void DecreaseTransposeAlgo::ResetSubGraphInput() {
sub_param_input->set_default_param(nullptr);
}
}
return lite::RET_OK;
}
void DecreaseTransposeAlgo::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
int DecreaseTransposeAlgo::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto return_node = sub_graph->get_return();
MS_ASSERT(return_node != nullptr);
@ -548,11 +568,13 @@ void DecreaseTransposeAlgo::SetSubGraphOutput(const CNodePtr &cnode, const FuncG
trans_cnode->set_fullname_with_scope(trans_input_name);
}
return_node->set_inputs(origin_input);
return lite::RET_OK;
}
void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
int DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto return_node = sub_graph->get_return();
MS_CHECK_TRUE_MSG(return_node != nullptr, lite::RET_ERROR, "return_node is nullptr");
auto origin_inputs = return_node->inputs();
lite::RemoveIfDepend(return_node);
lite::RemoveIfMakeTuple(return_node);
@ -560,7 +582,7 @@ void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const Fun
bool infer_done = true;
for (size_t i = 1; i < return_node->size(); ++i) {
auto abstract_base = GetCNodeInputAbstract(return_node, i);
MS_CHECK_TRUE_MSG(abstract_base != nullptr, , "GetCNodeInputAbstract failed");
MS_CHECK_TRUE_MSG(abstract_base != nullptr, lite::RET_ERROR, "GetCNodeInputAbstract failed");
abstract_list.emplace_back(abstract_base->Clone());
auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
MS_ASSERT(abstract_tensor != nullptr);
@ -572,11 +594,12 @@ void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const Fun
}
if (utils::isa<CNodePtr>(return_node->input(i))) {
auto input_cnode = return_node->input(i)->cast<CNodePtr>();
MS_CHECK_TRUE_MSG(input_cnode != nullptr, lite::RET_ERROR, "input_cnode is nullptr");
if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
input_cnode = input_cnode->input(1)->cast<CNodePtr>();
}
auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
MS_CHECK_TRUE_MSG(input_prim != nullptr, , "GetValueNode failed");
MS_CHECK_TRUE_MSG(input_prim != nullptr, lite::RET_ERROR, "GetValueNode failed");
if (input_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(kInferDone))) {
infer_done = false;
}
@ -592,22 +615,25 @@ void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const Fun
cnode->set_abstract(abstract_list.front());
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_CHECK_TRUE_MSG(prim != nullptr, , "GetValueNode Failed");
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "GetValueNode Failed");
prim->AddAttr(kInferDone, MakeValue<bool>(infer_done));
return lite::RET_OK;
}
void DecreaseTransposeAlgo::ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type) {
int DecreaseTransposeAlgo::ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type) {
MS_ASSERT(cnode != nullptr);
if (pre_trans_type == kNONE) {
return;
return lite::RET_OK;
}
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_CHECK_TRUE_MSG(primitive != nullptr, , "GetValueNode Failed");
MS_CHECK_TRUE_MSG(primitive != nullptr, lite::RET_ERROR, "GetValueNode Failed");
if (pre_trans_type == kNHWC2NCHW) {
primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NCHW));
} else {
primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NHWC));
}
return lite::RET_OK;
}
bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
@ -618,7 +644,7 @@ bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &fun
return false;
}
auto node_list = TopoSort(func_graph->get_return());
int status;
int status = 0;
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
@ -634,18 +660,38 @@ bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &fun
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
auto ret = SetSubGraphInput(cnode, sub_func_graph);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphInput failed";
return false;
}
(void)DecreaseTransposeForSingleOp(sub_func_graph);
SetSubGraphOutput(cnode, sub_func_graph);
ret = SetSubGraphOutput(cnode, sub_func_graph);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphOutput failed";
return false;
}
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
ret = SetSubGraphInput(cnode, sub_func_graph);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphInput failed";
return false;
}
(void)DecreaseTransposeForSingleOp(sub_func_graph);
SetSubGraphOutput(cnode, sub_func_graph);
SetSubGraphAbstract(cnode, sub_func_graph);
ret = SetSubGraphOutput(cnode, sub_func_graph);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphOutput failed";
return false;
}
ret = SetSubGraphAbstract(cnode, sub_func_graph);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphAbstract failed";
return false;
}
continue;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
@ -726,7 +772,12 @@ bool DecreaseTransposeAlgo::Run(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "run local trans insert optimizer failed.";
return false;
}
ResetSubGraphInput();
auto ret = ResetSubGraphInput();
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "ResetSubGraphInput failed.";
return false;
}
// if input format of several ops surrounded only by transpose op all can be NHWC,
// we can delete these transpose ops, and at the same time, transform these middle ops.
if (!DecreaseTransposeForMultiOp(func_graph)) {

View File

@ -51,11 +51,11 @@ class DecreaseTransposeAlgo : public Pass {
STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
std::set<CNodePtr> *visit_transposes);
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void ResetSubGraphInput();
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type);
int SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
int ResetSubGraphInput();
int SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
int SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
int ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type);
FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};
NodeInferShape node_infer_shape_;

View File

@ -111,7 +111,11 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "infer shape failed.";
return false;
}
return ResetSubGraphInput();
if (ResetSubGraphInput() != lite::RET_OK) {
MS_LOG(ERROR) << "ResetSubGraphInput failed.";
return false;
}
return true;
}
bool InferShapePass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) {
@ -187,7 +191,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "subgraph infer shape failed.";
return false;
}
SetSubGraphOutput(cnode, sub_func_graph);
if (SetSubGraphOutput(cnode, sub_func_graph)) {
MS_LOG(ERROR) << "SetSubGraphOutput failed.";
return false;
}
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
@ -202,7 +209,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "subgraph infer shape failed.";
return false;
}
SetSubGraphOutput(cnode, sub_func_graph);
if (SetSubGraphOutput(cnode, sub_func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphOutput failed.";
return false;
}
ret = SetSubGraphAbstract(cnode, sub_func_graph);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetSubGraphAbstract failed: " << ret;
@ -278,7 +288,7 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt
return RET_OK;
}
void InferShapePass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
STATUS InferShapePass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto return_node = sub_graph->get_return();
MS_ASSERT(return_node != nullptr);
@ -307,6 +317,7 @@ void InferShapePass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr
trans_cnode->set_fullname_with_scope(trans_input_name);
}
return_node->set_inputs(origin_input);
return lite::RET_OK;
}
STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
@ -360,7 +371,7 @@ STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGrap
return RET_OK;
}
bool InferShapePass::ResetSubGraphInput() {
int InferShapePass::ResetSubGraphInput() {
for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
auto &sub_graph = iter->first;
auto &sub_inputs = iter->second;
@ -368,7 +379,7 @@ bool InferShapePass::ResetSubGraphInput() {
MS_ASSERT(manager != nullptr);
for (auto &sub_input : sub_inputs) {
auto param_node = sub_graph->add_parameter();
MS_CHECK_TRUE_MSG(param_node != nullptr, false, "Add parameter Failed");
MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "Add parameter Failed");
param_node->set_abstract(sub_input->abstract()->Clone());
param_node->set_name(sub_input->fullname_with_scope());
manager->Replace(sub_input, param_node);
@ -377,7 +388,7 @@ bool InferShapePass::ResetSubGraphInput() {
sub_param_input->set_default_param(nullptr);
}
}
return true;
return lite::RET_OK;
}
} // namespace opt
} // namespace mindspore

View File

@ -36,9 +36,9 @@ class InferShapePass : public Pass {
bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
STATUS InferProcess(const FuncGraphPtr &func_graph);
STATUS SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
STATUS SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
STATUS SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
bool ResetSubGraphInput();
int ResetSubGraphInput();
FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};

View File

@ -90,16 +90,16 @@ std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, Format
return cur_axes;
}
void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
const std::vector<int> &axes, FormatTransNodeType trans_type,
NodeInferShape *node_infer_shape) {
int TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
const std::vector<int> &axes, FormatTransNodeType trans_type,
NodeInferShape *node_infer_shape) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
if (input_index >= cnode->size() || axes.empty()) {
return;
return lite::RET_ERROR;
}
auto origin_input = node_infer_shape->GetIntVecInput(cnode, input_index);
if (origin_input.size() != axes.size()) {
return;
return lite::RET_ERROR;
}
std::vector<int> cur_input;
for (int dim = 0; dim < static_cast<int>(kInputSizeFour); ++dim) {
@ -119,7 +119,9 @@ void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
}
}
auto param_node = BuildIntVecParameterNode(func_graph, cur_input, cnode->input(input_index)->fullname_with_scope());
MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "BuildIntVecParameterNode failed");
func_graph->manager()->Replace(cnode->input(input_index), param_node);
return lite::RET_OK;
}
STATUS ChangeCommonOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
@ -266,6 +268,7 @@ STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, Form
}
int element_num = shape.front();
auto prim = GetValueNode<std::shared_ptr<ops::SliceFusion>>(cnode->input(0));
MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "GetValueNode failed");
std::vector<int> axes;
if (prim->GetAttr(ops::kAxes) == nullptr || prim->get_axes().empty()) {
for (int index = 0; index < element_num; ++index) {
@ -314,6 +317,7 @@ STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode
auto cur_axes = TransformOpAxesAttr(axes, trans_type);
auto param_node =
BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope());
MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "BuildIntVecParameterNode failed");
func_graph->manager()->Replace(cnode->input(kInputIndexFour), param_node);
return lite::RET_OK;
}

View File

@ -72,6 +72,7 @@ bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) {
}
if (CheckPrimitiveType(node, prim::kPrimTranspose)) {
auto transpose_cnode = node->cast<CNodePtr>();
MS_ASSERT(transpose_cnode != nullptr);
if (!CheckPrimitiveType(transpose_cnode->input(kTransposeInput), prim::kPrimConv2DFusion)) {
continue;
}
@ -90,6 +91,7 @@ bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) {
continue;
}
auto transpose_cnode = conv_node->input(kTransposeInput)->cast<CNodePtr>();
MS_ASSERT(transpose_cnode != nullptr);
auto perm = GetTransposePerm(transpose_cnode);
if (perm == kPermNHWC) {
manager->Replace(transpose_cnode, transpose_cnode->input(1));