forked from mindspore-Ecosystem/mindspore
fix tool converter
This commit is contained in:
parent
ad88cfcafb
commit
04d45a701c
|
@ -55,12 +55,13 @@ constexpr int kMaxDepth = 2048;
|
|||
std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
|
||||
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
|
||||
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
|
||||
std::vector<AnfNodePtr> vecs;
|
||||
std::vector<AnfNodePtr> vecs{};
|
||||
if (node == nullptr) {
|
||||
return vecs;
|
||||
}
|
||||
if (node->isa<mindspore::CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto &inputs = cnode->inputs();
|
||||
// Check if free variables used.
|
||||
for (const auto &input : inputs) {
|
||||
|
@ -78,7 +79,7 @@ std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
|
|||
return vecs;
|
||||
};
|
||||
|
||||
std::list<CNodePtr> cnodes;
|
||||
std::list<CNodePtr> cnodes{};
|
||||
auto nodes = TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph);
|
||||
for (const auto &node : nodes) {
|
||||
auto cnode = dyn_cast<mindspore::CNode>(node);
|
||||
|
@ -100,10 +101,7 @@ int AnfExporter::SetPostTrainOutputTensorType(const std::unique_ptr<schema::Meta
|
|||
first_tensor_output->dataType = kNumberTypeInt8;
|
||||
} else {
|
||||
auto primc = primitive->cast<std::shared_ptr<mindspore::ops::QuantDTypeCast>>();
|
||||
if (primc == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(primc != nullptr, RET_ERROR, "cast ptr failed");
|
||||
if (primc->get_dst_t() != kNumberTypeFloat32) {
|
||||
first_tensor_output->dataType = kNumberTypeInt8;
|
||||
}
|
||||
|
@ -220,6 +218,7 @@ int AnfExporter::CreateNewTensorForParameter(const std::unique_ptr<schema::MetaG
|
|||
const AnfNodePtr &input) {
|
||||
lite::DataInfo data_info;
|
||||
auto param_node = input->cast<ParameterPtr>();
|
||||
MS_CHECK_TRUE_MSG(param_node != nullptr, RET_NULL_PTR, "cast ptr failed");
|
||||
if (FetchFromDefaultParam(param_node, converter::FmkType(meta_graphT->fmkType), &data_info) != RET_OK) {
|
||||
MS_LOG(ERROR) << "FetchFromDefaultParam failed.";
|
||||
return RET_ERROR;
|
||||
|
@ -242,7 +241,7 @@ int AnfExporter::CreateNewTensorForParameter(const std::unique_ptr<schema::MetaG
|
|||
int AnfExporter::SetSubGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||
const size_t &subgraph_index) {
|
||||
auto &subgraph = meta_graphT->subGraph.at(subgraph_index);
|
||||
FuncGraphPtr fg;
|
||||
FuncGraphPtr fg = nullptr;
|
||||
std::for_each(fg_subgraph_map_.begin(), fg_subgraph_map_.end(),
|
||||
[&subgraph_index, &fg](const std::pair<const FuncGraphPtr, size_t> &it) {
|
||||
if (it.second == subgraph_index) {
|
||||
|
@ -317,6 +316,7 @@ int AnfExporter::ExportPartialNode(const std::unique_ptr<schema::MetaGraphT> &me
|
|||
const bool ©_primitive, const CNodePtr &partial_cnode,
|
||||
const std::unique_ptr<schema::CNodeT> &schema_cnode) {
|
||||
auto prim = GetValueNode<std::shared_ptr<mindspore::Primitive>>(partial_cnode->input(0));
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, RET_NULL_PTR, "GetValueNode failed");
|
||||
if (prim->name() != mindspore::ops::kNamePartialFusion) {
|
||||
MS_LOG(INFO) << "not is partial";
|
||||
return RET_OK;
|
||||
|
@ -324,13 +324,10 @@ int AnfExporter::ExportPartialNode(const std::unique_ptr<schema::MetaGraphT> &me
|
|||
|
||||
auto partial_fusion_primc = schema_cnode->primitive->value.AsPartialFusion();
|
||||
auto vnode = partial_cnode->input(kFirstDataIndex)->cast<ValueNodePtr>();
|
||||
MS_ASSERT(vnode != nullptr);
|
||||
MS_CHECK_TRUE_MSG(partial_fusion_primc != nullptr, RET_NULL_PTR, "partial_fusion_primc is invalid");
|
||||
MS_CHECK_TRUE_MSG(vnode != nullptr, RET_NULL_PTR, "vnode is invalid");
|
||||
auto fg = vnode->value()->cast<FuncGraphPtr>();
|
||||
if (fg == nullptr) {
|
||||
MS_LOG(ERROR) << "func graph is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_CHECK_TRUE_MSG(fg != nullptr, RET_NULL_PTR, "func graph is nullptr.");
|
||||
if (fg_subgraph_map_.find(fg) != fg_subgraph_map_.end()) {
|
||||
partial_fusion_primc->sub_graph_index = fg_subgraph_map_.at(fg);
|
||||
return RET_OK;
|
||||
|
@ -376,10 +373,6 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
|
|||
for (const auto &cnode : cnodes) {
|
||||
auto prim = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
|
||||
std::unique_ptr<schema::PrimitiveT> primT;
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "prim is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ret = RemoveIfDepend(cnode);
|
||||
if (ret != RET_OK) {
|
||||
|
@ -513,10 +506,13 @@ FuncGraphPtr GetFinalGraph(const FuncGraphPtr &func_graph, int i) {
|
|||
auto cnode = call_cnode->input(kFirstDataIndex)->cast<CNodePtr>();
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
|
||||
auto false_cnode = cnode->input(kSwitchFalseIndex)->cast<CNodePtr>();
|
||||
MS_CHECK_TRUE_MSG(false_cnode != nullptr, nullptr, "cast failed");
|
||||
auto false_fg = GetValueNode<FuncGraphPtr>(false_cnode->input(kFirstDataIndex));
|
||||
MS_CHECK_TRUE_MSG(false_fg != nullptr, nullptr, "GetValueNode failed");
|
||||
return GetFinalGraph(false_fg, i);
|
||||
} else {
|
||||
auto fg = GetValueNode<FuncGraphPtr>(cnode->input(kFirstDataIndex));
|
||||
MS_CHECK_TRUE_MSG(fg != nullptr, nullptr, "GetValueNode failed");
|
||||
return GetFinalGraph(fg, i);
|
||||
}
|
||||
|
||||
|
@ -542,10 +538,7 @@ int AnfExporter::SetMetaGraphOutput(const FuncGraphPtr &func_graph,
|
|||
const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
|
||||
int i = 0;
|
||||
auto final_fg = GetFinalGraph(func_graph, i);
|
||||
if (final_fg == nullptr) {
|
||||
MS_LOG(ERROR) << "GetFinalGraph failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(final_fg != nullptr, RET_ERROR, "GetFinalGraph failed.");
|
||||
auto final_meta_graph_index = fg_subgraph_map_.at(final_fg);
|
||||
auto &final_meta_graph = meta_graphT->subGraph.at(final_meta_graph_index);
|
||||
meta_graphT->outputIndex.assign(final_meta_graph->outputIndices.begin(), final_meta_graph->outputIndices.end());
|
||||
|
@ -605,10 +598,7 @@ int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema
|
|||
}
|
||||
if (utils::isa<abstract::AbstractTuple>(input_anode->abstract())) {
|
||||
auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(input_anode->abstract());
|
||||
if (tuple == nullptr) {
|
||||
MS_LOG(ERROR) << "tuple is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(tuple != nullptr, RET_ERROR, "tuple is nullptr");
|
||||
auto elements = tuple->elements();
|
||||
for (size_t i = 0; i < elements.size(); i++) {
|
||||
auto key = std::make_pair(input_anode, i);
|
||||
|
@ -627,6 +617,7 @@ int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema
|
|||
|
||||
int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode) {
|
||||
auto input_cnode = utils::cast<CNodePtr>(input_anode);
|
||||
MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_ERROR, "cast ptr failed");
|
||||
auto input_value_node = input_cnode->input(kPrimIndex)->cast<ValueNodePtr>();
|
||||
if (input_value_node == nullptr) {
|
||||
if (!IsCall(input_cnode)) {
|
||||
|
@ -635,6 +626,7 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
|
|||
} else {
|
||||
auto call_anf_prim_vnode = GetCallAnfPrim();
|
||||
auto cnode_input = input_cnode->inputs();
|
||||
MS_CHECK_TRUE_MSG(call_anf_prim_vnode != nullptr, RET_ERROR, "GetCallAnfPrim failed");
|
||||
cnode_input.insert(cnode_input.begin(), call_anf_prim_vnode);
|
||||
input_cnode->set_inputs(cnode_input);
|
||||
}
|
||||
|
@ -658,10 +650,7 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto value_node = utils::cast<ValueNodePtr>(index_vnode);
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(ERROR) << "cast to ValueNode failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "cast to ValueNode failed");
|
||||
auto idx = value_node->value()->type()->number_type() == kNumberTypeInt64 ? GetValue<int64_t>(value_node->value())
|
||||
: GetValue<int>(value_node->value());
|
||||
auto key = std::make_pair(get_item_input_cnode, idx);
|
||||
|
@ -794,10 +783,7 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<sc
|
|||
|
||||
if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
|
||||
auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
|
||||
if (tuple == nullptr) {
|
||||
MS_LOG(ERROR) << "tuple is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(tuple != nullptr, RET_ERROR, "tuple is nullptr");
|
||||
auto elements = tuple->elements();
|
||||
for (size_t i = 0; i < lite::GetCNodeOutputsSize(cnode, train_flag_); i++) {
|
||||
auto ms_tensor = new (std::nothrow) schema::TensorT();
|
||||
|
@ -841,6 +827,7 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<sc
|
|||
}
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
|
||||
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
|
||||
MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
|
||||
ms_tensor->dataType = type_ptr->type_id();
|
||||
meta_graphT->allTensors.emplace_back(ms_tensor);
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
|
||||
|
@ -881,8 +868,10 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<sc
|
|||
|
||||
CNodePtr AnfExporter::CreateCallCnode(const FuncGraphPtr &fg, const AnfNodePtr &node) {
|
||||
auto call_anf_prim_vnode = GetCallAnfPrim();
|
||||
MS_CHECK_TRUE_MSG(call_anf_prim_vnode != nullptr, nullptr, "GetCallAnfPrim failed");
|
||||
std::vector<AnfNodePtr> inputs{call_anf_prim_vnode, node};
|
||||
auto cnode = fg->NewCNodeInOrder(inputs);
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "NewCNode failed");
|
||||
cnode->set_func_graph(fg);
|
||||
return cnode;
|
||||
}
|
||||
|
@ -890,19 +879,23 @@ CNodePtr AnfExporter::CreateCallCnode(const FuncGraphPtr &fg, const AnfNodePtr &
|
|||
CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, const AnfNodePtr &node) {
|
||||
if (utils::isa<CNodePtr>(node)) {
|
||||
auto cnode = utils::cast<CNodePtr>(node);
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "cast ptr failed");
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(kPrimIndex));
|
||||
if (primitive_c != nullptr) {
|
||||
return cnode;
|
||||
}
|
||||
auto partial_anf_prim_vnode = GetPartialFusionPrim();
|
||||
auto cnode_input = cnode->inputs();
|
||||
MS_CHECK_TRUE_MSG(partial_anf_prim_vnode != nullptr, nullptr, "GetPartialFusionPrim failed");
|
||||
cnode_input.insert(cnode_input.begin(), partial_anf_prim_vnode);
|
||||
cnode->set_inputs(cnode_input);
|
||||
return cnode;
|
||||
} else if (utils::isa<ValueNodePtr>(node)) {
|
||||
auto partial_anf_prim_vnode = GetPartialFusionPrim();
|
||||
MS_CHECK_TRUE_MSG(partial_anf_prim_vnode != nullptr, nullptr, "GetPartialFusionPrim failed");
|
||||
std::vector<AnfNodePtr> inputs{partial_anf_prim_vnode, node};
|
||||
auto cnode = fg->NewCNode(inputs);
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "New cnode failed");
|
||||
return cnode;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "failed to create partial cnode.";
|
||||
|
|
|
@ -91,6 +91,7 @@ STATUS GetDataTypeAndShape(const ParameterPtr ¶m_node, TypeId *data_type, Sh
|
|||
}
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
||||
auto typePtr = abstract_tensor->element()->GetTypeTrack();
|
||||
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed");
|
||||
MS_CHECK_TRUE_MSG(typePtr != nullptr, RET_ERROR, "typePtr is nullptr");
|
||||
*data_type = typePtr->type_id();
|
||||
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
|
||||
|
@ -105,12 +106,14 @@ int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &pri
|
|||
bool train_flag, DataInfo *data_info) {
|
||||
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
|
||||
auto valueAbstract = value_node->abstract();
|
||||
MS_CHECK_TRUE_MSG(valueAbstract != nullptr, RET_ERROR, "valueAbstract is nullptr");
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
|
||||
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
|
||||
MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto typePtr = abstract_tensor->element()->GetTypeTrack();
|
||||
MS_CHECK_TRUE_MSG(typePtr != nullptr, RET_ERROR, "typePtr is nullptr");
|
||||
data_info->data_type_ = typePtr->type_id();
|
||||
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
|
||||
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
|
||||
|
@ -121,6 +124,7 @@ int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &pri
|
|||
auto value = value_node->value();
|
||||
MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
|
||||
auto data = value->cast<tensor::TensorPtr>();
|
||||
MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is invalid");
|
||||
data_info->data_.resize(data->Size());
|
||||
if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
|
||||
MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
|
||||
|
@ -159,6 +163,7 @@ int FetchFromBoolImmValue(const ValueNodePtr &value_node, const PrimitivePtr &pr
|
|||
auto value = value_node->value();
|
||||
MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
|
||||
auto data = value->cast<mindspore::BoolImmPtr>();
|
||||
MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is nullptr");
|
||||
auto data_value = data->value();
|
||||
if (memcpy_s(data_info->data_.data(), sizeof(bool), &data_value, sizeof(bool)) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s failed";
|
||||
|
@ -173,6 +178,7 @@ int FetchFromNumberValue(const ValueNodePtr &value_node, const PrimitivePtr &pri
|
|||
data_info->shape_ = {1};
|
||||
data_info->data_.resize(sizeof(int));
|
||||
auto data = value_node->value()->cast<NumberPtr>();
|
||||
MS_CHECK_TRUE_MSG(data != nullptr, RET_NULL_PTR, "cast NumberPtr failed");
|
||||
int number_type = data->number_type();
|
||||
if (TypeToTypeMap.find(number_type) != TypeToTypeMap.end()) {
|
||||
number_type = TypeToTypeMap.at(number_type);
|
||||
|
@ -256,16 +262,13 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
|
|||
DataInfo *data_info) {
|
||||
MS_ASSERT(cnode != nullptr && data_info != nullptr);
|
||||
auto param_node = cnode->input(index)->cast<ParameterPtr>();
|
||||
if (param_node == nullptr) {
|
||||
MS_LOG(ERROR) << "input node is not parameter node.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "input node is not parameter node.");
|
||||
if (FetchFromDefaultParam(param_node, fmk_type, data_info) != RET_OK) {
|
||||
MS_LOG(ERROR) << "fetch information from default param failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "GetValueNode failed");
|
||||
if (prim->GetAttr(ops::kFormat) == nullptr && !param_node->has_default()) {
|
||||
data_info->format_ = mindspore::NHWC;
|
||||
}
|
||||
|
@ -289,10 +292,8 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
|
|||
DataInfo *data_info) {
|
||||
MS_ASSERT(cnode != nullptr && data_info != nullptr);
|
||||
auto value_node = cnode->input(index)->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(ERROR) << "input node is not value node.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "input node is not value node.");
|
||||
|
||||
auto value = value_node->value();
|
||||
int ret = RET_OK;
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
|
@ -328,6 +329,7 @@ int SetFormatForCnode(const CNodePtr &cnode, size_t index, converter::FmkType fm
|
|||
DataInfo *data_info) {
|
||||
data_info->format_ = mindspore::NHWC;
|
||||
auto input_node_prim = GetValueNode<PrimitivePtr>((cnode->input(index)->cast<CNodePtr>()->input(0)));
|
||||
MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "GetValueNode failed");
|
||||
if (input_node_prim->GetAttr(ops::kFormat) != nullptr) {
|
||||
auto value = input_node_prim->GetAttr(ops::kFormat);
|
||||
if (value->isa<mindspore::Int64Imm>()) {
|
||||
|
@ -367,6 +369,7 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType f
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
|
||||
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed");
|
||||
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
|
||||
MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
|
||||
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
|
||||
|
@ -390,6 +393,7 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType f
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto tensor_value = tensor_info->cast<tensor::TensorPtr>();
|
||||
MS_CHECK_TRUE_MSG(tensor_value != nullptr, RET_ERROR, "cast ptr failed");
|
||||
if (tensor_value->Size() >= kTensorListMinSize) {
|
||||
data_info->data_.resize(tensor_value->Size());
|
||||
if (memcpy_s(data_info->data_.data(), tensor_value->Size(), tensor_value->data_c(), tensor_value->Size()) !=
|
||||
|
@ -410,21 +414,21 @@ int RemoveIfDepend(const CNodePtr &cnode) {
|
|||
inputs.emplace_back(cnode->input(0));
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
AnfNodePtr inputNode = cnode->input(i);
|
||||
MS_CHECK_TRUE_MSG(inputNode != nullptr, RET_NULL_PTR, "inputNode is nullptr");
|
||||
if (!inputNode->isa<CNode>()) {
|
||||
inputs.emplace_back(cnode->input(i));
|
||||
continue;
|
||||
}
|
||||
auto depend_node = utils::cast<CNodePtr>(inputNode);
|
||||
MS_CHECK_TRUE_MSG(depend_node != nullptr, RET_NULL_PTR, "depend_node is nullptr");
|
||||
auto value_node = depend_node->input(0)->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(ERROR) << "value node is invalid.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(value_node != nullptr, RET_NULL_PTR, "value node is invalid.");
|
||||
if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) {
|
||||
has_depend = true;
|
||||
bool mask_out = (depend_node->inputs().size() == opt::kInputSizeThree);
|
||||
for (size_t j = 1; j < depend_node->inputs().size(); ++j) {
|
||||
AnfNodePtr depend_input_node = depend_node->input(j);
|
||||
MS_CHECK_TRUE_MSG(depend_input_node != nullptr, RET_NULL_PTR, "depend_input_node is nullptr");
|
||||
if (depend_input_node->isa<CNode>()) {
|
||||
inputs.emplace_back(depend_input_node);
|
||||
if (mask_out) {
|
||||
|
@ -450,16 +454,15 @@ int RemoveIfMakeTuple(const CNodePtr &cnode) {
|
|||
inputs.emplace_back(cnode->input(0));
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
AnfNodePtr input_node = cnode->input(i);
|
||||
MS_CHECK_TRUE_MSG(input_node != nullptr, RET_NULL_PTR, "input_node is nullptr");
|
||||
if (!input_node->isa<CNode>()) {
|
||||
inputs.emplace_back(cnode->input(i));
|
||||
continue;
|
||||
}
|
||||
auto make_tuple_node = utils::cast<CNodePtr>(input_node);
|
||||
MS_CHECK_TRUE_MSG(make_tuple_node != nullptr, RET_NULL_PTR, "make_tuple_node is nullptr");
|
||||
auto value_node = make_tuple_node->input(0)->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(ERROR) << "value node is invalid.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(value_node != nullptr, RET_NULL_PTR, "value node is invalid.");
|
||||
if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) ||
|
||||
opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) {
|
||||
has_make_tuple = true;
|
||||
|
|
|
@ -30,7 +30,7 @@ Option<std::string> FlagParser::ParseFlags(int argc, const char *const *argv, bo
|
|||
}
|
||||
binName = GetFileName(argv[0]);
|
||||
|
||||
std::multimap<std::string, Option<std::string>> keyValues;
|
||||
std::multimap<std::string, Option<std::string>> keyValues{};
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string tmp = argv[i];
|
||||
Trim(&tmp);
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include <vector>
|
||||
#include <map>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
|
@ -29,18 +28,43 @@
|
|||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
SubGraph::SubGraph(FuncGraphPtr belong_anf, std::string graph_name, const std::set<CNodePtr> &head_nodes)
|
||||
: belong_anf_(std::move(belong_anf)), name_(std::move(graph_name)) {
|
||||
InitSubGraphNode(head_nodes);
|
||||
InitSubGraphInNode();
|
||||
InitSubGraphOutNode();
|
||||
int SubGraph::Init(const std::set<CNodePtr> &head_nodes) {
|
||||
auto ret = InitSubGraphNode(head_nodes);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InitSubGraphNode failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
ret = InitSubGraphInNode();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InitSubGraphInNode failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
ret = InitSubGraphOutNode();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InitSubGraphOutNode failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void SubGraph::Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes) {
|
||||
int SubGraph::Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes) {
|
||||
this->nodes_ = nodes;
|
||||
InitSubGraphNode(head_nodes);
|
||||
InitSubGraphInNode();
|
||||
InitSubGraphOutNode();
|
||||
auto ret = InitSubGraphNode(head_nodes);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InitSubGraphNode failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
ret = InitSubGraphInNode();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InitSubGraphInNode failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
ret = InitSubGraphOutNode();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InitSubGraphOutNode failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
std::set<CNodePtr> SubGraph::GetNodes() const { return this->nodes_; }
|
||||
|
@ -125,11 +149,11 @@ std::set<CNodePtr> SubGraph::GetOutputCNodes() const {
|
|||
return outputs;
|
||||
}
|
||||
|
||||
void SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
|
||||
int SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
|
||||
MS_ASSERT(belong_anf_ != nullptr);
|
||||
MS_ASSERT(belong_anf_->manager() != nullptr);
|
||||
auto node_users = belong_anf_->manager()->node_users();
|
||||
std::queue<CNodePtr> q;
|
||||
std::queue<CNodePtr> q{};
|
||||
for (const auto &head_node : head_nodes) {
|
||||
if (head_node == nullptr) {
|
||||
continue;
|
||||
|
@ -138,7 +162,7 @@ void SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
|
|||
}
|
||||
while (!q.empty()) {
|
||||
auto cur_node = q.front();
|
||||
MS_CHECK_PTR_IF_NULL(cur_node);
|
||||
MS_CHECK_TRUE_MSG(cur_node != nullptr, RET_NULL_PTR, "cur_node is nullptr");
|
||||
q.pop();
|
||||
this->nodes_.insert(cur_node);
|
||||
// check output-cnode of cur-node only depend on cur-node
|
||||
|
@ -153,11 +177,12 @@ void SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
|
|||
continue;
|
||||
}
|
||||
auto post_cnode = utils::cast<CNodePtr>(post_node);
|
||||
MS_CHECK_TRUE_MSG(post_cnode != nullptr, RET_NULL_PTR, "cast failed");
|
||||
// return-node should not be include into subgraph absolutely // ut
|
||||
if (opt::CheckPrimitiveType(post_cnode, prim::kPrimReturn)) {
|
||||
continue;
|
||||
}
|
||||
MS_CHECK_PTR_IF_NULL(post_cnode);
|
||||
MS_CHECK_TRUE_MSG(post_cnode != nullptr, RET_NULL_PTR, "post_cnode is nullptr");
|
||||
bool non_depend = true;
|
||||
// check all inputs of output-cnode
|
||||
for (const auto &input : post_cnode->inputs()) {
|
||||
|
@ -167,6 +192,7 @@ void SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
|
|||
// input cnode is not contained in subgraph
|
||||
if (utils::isa<CNodePtr>(input)) {
|
||||
auto input_cnode = utils::cast<CNodePtr>(input);
|
||||
MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_NULL_PTR, "cast ptr failed");
|
||||
if (this->nodes_.count(input_cnode) == 0) {
|
||||
non_depend = false;
|
||||
break;
|
||||
|
@ -175,6 +201,7 @@ void SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
|
|||
// input parameter is a graph input
|
||||
if (utils::isa<ParameterPtr>(input)) {
|
||||
auto input_parameter = utils::cast<ParameterPtr>(input);
|
||||
MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_NULL_PTR, "cast failed");
|
||||
if (!input_parameter->has_default()) {
|
||||
non_depend = false;
|
||||
break;
|
||||
|
@ -186,9 +213,10 @@ void SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
|
|||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void SubGraph::InitSubGraphInNode() {
|
||||
int SubGraph::InitSubGraphInNode() {
|
||||
MS_ASSERT(belong_anf_ != nullptr);
|
||||
MS_ASSERT(belong_anf_->manager() != nullptr);
|
||||
auto node_users = belong_anf_->manager()->node_users();
|
||||
|
@ -203,6 +231,7 @@ void SubGraph::InitSubGraphInNode() {
|
|||
}
|
||||
if (utils::isa<CNodePtr>(input)) {
|
||||
auto input_cnode = utils::cast<CNodePtr>(input);
|
||||
MS_CHECK_TRUE_MSG(input_cnode != nullptr, false, "cast failed");
|
||||
if (this->nodes_.count(input_cnode) == 0) {
|
||||
return true;
|
||||
}
|
||||
|
@ -210,6 +239,7 @@ void SubGraph::InitSubGraphInNode() {
|
|||
// graph input or shared weight input // ut
|
||||
if (utils::isa<ParameterPtr>(input)) {
|
||||
auto input_parameter = utils::cast<ParameterPtr>(input);
|
||||
MS_CHECK_TRUE_MSG(input_parameter != nullptr, false, "cast failed");
|
||||
if (!input_parameter->has_default()) {
|
||||
return true;
|
||||
}
|
||||
|
@ -223,9 +253,10 @@ void SubGraph::InitSubGraphInNode() {
|
|||
in_nodes_.insert(node);
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void SubGraph::InitSubGraphOutNode() {
|
||||
int SubGraph::InitSubGraphOutNode() {
|
||||
MS_ASSERT(belong_anf_ != nullptr);
|
||||
MS_ASSERT(belong_anf_->manager() != nullptr);
|
||||
auto node_users = belong_anf_->manager()->node_users();
|
||||
|
@ -250,6 +281,7 @@ void SubGraph::InitSubGraphOutNode() {
|
|||
return true;
|
||||
}
|
||||
auto output_cnode = utils::cast<CNodePtr>(output_node);
|
||||
MS_CHECK_TRUE_MSG(output_cnode != nullptr, false, "cast failed");
|
||||
if (this->nodes_.count(output_cnode) == 0) {
|
||||
return true;
|
||||
}
|
||||
|
@ -258,6 +290,7 @@ void SubGraph::InitSubGraphOutNode() {
|
|||
continue;
|
||||
out_nodes_.insert(node);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) {
|
||||
|
@ -272,7 +305,10 @@ bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) {
|
|||
auto new_nodes2 = subgraph->GetNodes();
|
||||
new_nodes.insert(new_nodes2.begin(), new_nodes2.end());
|
||||
new_nodes.insert(common_outputs.begin(), common_outputs.end());
|
||||
this->Reset(new_nodes, common_outputs);
|
||||
if (this->Reset(new_nodes, common_outputs) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Reset failed";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -280,7 +316,10 @@ bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) {
|
|||
auto new_nodes = this->GetNodes();
|
||||
auto new_nodes2 = subgraph->GetNodes();
|
||||
new_nodes.insert(new_nodes2.begin(), new_nodes2.end());
|
||||
this->Reset(new_nodes);
|
||||
if (this->Reset(new_nodes) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Reset failed";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
@ -290,8 +329,8 @@ bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) {
|
|||
SubGraphPtr SubGraph::FindBeforeSubGraphInBelongAnf() const {
|
||||
MS_ASSERT(belong_anf_ == nullptr);
|
||||
// find before subgraph's nodes
|
||||
std::queue<CNodePtr> q;
|
||||
std::set<CNodePtr> before_nodes;
|
||||
std::queue<CNodePtr> q{};
|
||||
std::set<CNodePtr> before_nodes{};
|
||||
for (const auto &node : this->GetInCNodes()) {
|
||||
for (const auto &in_cnode : lite::GetInputCNode(node)) {
|
||||
if (in_cnode == nullptr) {
|
||||
|
@ -311,7 +350,11 @@ SubGraphPtr SubGraph::FindBeforeSubGraphInBelongAnf() const {
|
|||
}
|
||||
// construct before subgraph
|
||||
auto before_subgraph = std::make_shared<SubGraph>(belong_anf_, this->name_ + "/before_subgraph");
|
||||
before_subgraph->Reset(before_nodes);
|
||||
MS_CHECK_TRUE_MSG(before_subgraph != nullptr, nullptr, "before_subgraph is nullptr");
|
||||
if (before_subgraph->Reset(before_nodes) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Reset failed";
|
||||
return nullptr;
|
||||
}
|
||||
return before_subgraph;
|
||||
}
|
||||
|
||||
|
@ -325,8 +368,8 @@ SubGraphPtr SubGraph::FindAfterSubGraphInBelongAnf() const {
|
|||
return nullptr;
|
||||
}
|
||||
// find after subgraph's nodes
|
||||
std::queue<CNodePtr> q;
|
||||
std::set<CNodePtr> after_nodes;
|
||||
std::queue<CNodePtr> q{};
|
||||
std::set<CNodePtr> after_nodes{};
|
||||
auto output_node = belong_anf_->output();
|
||||
if (!utils::isa<CNodePtr>(output_node)) {
|
||||
MS_LOG(ERROR) << "Output node of anf should be a cnode: " << output_node->fullname_with_scope();
|
||||
|
@ -348,7 +391,12 @@ SubGraphPtr SubGraph::FindAfterSubGraphInBelongAnf() const {
|
|||
}
|
||||
// construct before subgraph
|
||||
auto after_subgraph = std::make_shared<SubGraph>(belong_anf_, this->name_ + "/after_subgraph");
|
||||
after_subgraph->Reset(after_nodes);
|
||||
MS_CHECK_TRUE_MSG(after_subgraph != nullptr, nullptr, "after_subgraph is nullptr");
|
||||
if (after_subgraph->Reset(after_nodes) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Reset failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return after_subgraph;
|
||||
}
|
||||
|
||||
|
@ -366,6 +414,7 @@ int SubGraph::CreatePartialInBelongAnf() {
|
|||
}
|
||||
// create func_graph of partial
|
||||
FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
|
||||
MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_NULL_PTR, "func_graph is nullptr");
|
||||
auto manager = belong_anf_->manager();
|
||||
manager->AddFuncGraph(func_graph);
|
||||
func_graph->set_attr("graph_name", MakeValue(graph_name));
|
||||
|
@ -373,12 +422,20 @@ int SubGraph::CreatePartialInBelongAnf() {
|
|||
// create cnode and parameter for func_graph of partial
|
||||
std::vector<AnfNodePtr> partial_inputs;
|
||||
std::map<AnfNodePtr, AnfNodePtr> partial_inputs_and_subgraph_input_map;
|
||||
CreateParameterForPartialSubGraph(func_graph, &partial_inputs, &partial_inputs_and_subgraph_input_map);
|
||||
CreateCNodeForPartialSubGraph(func_graph, partial_inputs_and_subgraph_input_map);
|
||||
auto ret = CreateParameterForPartialSubGraph(func_graph, &partial_inputs, &partial_inputs_and_subgraph_input_map);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(DEBUG) << "CreateParameterForPartialSubGraph failed";
|
||||
return ret;
|
||||
}
|
||||
ret = CreateCNodeForPartialSubGraph(func_graph, partial_inputs_and_subgraph_input_map);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(DEBUG) << "CreateCNodeForPartialSubGraph failed";
|
||||
return ret;
|
||||
}
|
||||
// add return for func_graph of partial
|
||||
auto sub_graph_outputs = this->GetOutCNodes();
|
||||
MS_ASSERT(!sub_graph_outputs.empty());
|
||||
auto ret = SetFuncGraphOutput(func_graph, sub_graph_outputs);
|
||||
ret = SetFuncGraphOutput(func_graph, sub_graph_outputs);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(DEBUG) << "Set subgraph output failed";
|
||||
return ret;
|
||||
|
@ -386,8 +443,11 @@ int SubGraph::CreatePartialInBelongAnf() {
|
|||
// create partial cnode
|
||||
auto partial_prim = std::make_shared<mindspore::ops::PartialFusion>();
|
||||
auto graph_value_node = NewValueNode(func_graph);
|
||||
MS_CHECK_TRUE_MSG(partial_prim != nullptr, RET_NULL_PTR, "partial_prim is nullptr");
|
||||
MS_CHECK_TRUE_MSG(graph_value_node != nullptr, RET_NULL_PTR, "graph_value_node is nullptr");
|
||||
partial_inputs.insert(partial_inputs.begin(), graph_value_node);
|
||||
auto partial_cnode = belong_anf_->NewCNode(partial_prim, partial_inputs);
|
||||
MS_CHECK_TRUE_MSG(partial_cnode != nullptr, RET_NULL_PTR, "partial_cnode is nullptr");
|
||||
partial_cnode->set_fullname_with_scope(graph_name + "/partial");
|
||||
for (size_t i = 0; i < partial_inputs.size(); ++i) {
|
||||
const auto &input = partial_inputs.at(i);
|
||||
|
@ -396,6 +456,7 @@ int SubGraph::CreatePartialInBelongAnf() {
|
|||
// create call cnode
|
||||
std::vector<AnfNodePtr> call_node_inputs{partial_cnode};
|
||||
auto call_cnode = belong_anf_->NewCNode(call_node_inputs);
|
||||
MS_CHECK_TRUE_MSG(call_cnode != nullptr, RET_NULL_PTR, "call_cnode is nullptr");
|
||||
call_cnode->set_fullname_with_scope(graph_name + "/call");
|
||||
// replace belong-graph's output
|
||||
auto return_node = belong_anf_->get_return();
|
||||
|
@ -412,7 +473,7 @@ int SubGraph::SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set<CNode
|
|||
return lite::SetFuncGraphOutput(graph, output_nodes);
|
||||
}
|
||||
|
||||
void SubGraph::CreateParameterForPartialSubGraph(
|
||||
int SubGraph::CreateParameterForPartialSubGraph(
|
||||
const FuncGraphPtr &sub_graph, std::vector<AnfNodePtr> *partial_inputs,
|
||||
std::map<AnfNodePtr, AnfNodePtr> *partial_inputs_and_subgraph_input_map) {
|
||||
MS_ASSERT(sub_graph != nullptr);
|
||||
|
@ -436,6 +497,7 @@ void SubGraph::CreateParameterForPartialSubGraph(
|
|||
// create subgraph input parameter from cnode and record partial inputs
|
||||
if (utils::isa<CNodePtr>(input)) {
|
||||
auto input_cnode = utils::cast<CNodePtr>(input);
|
||||
MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_NULL_PTR, "cast ptr failed");
|
||||
if (this->GetNodes().count(input_cnode) > 0) {
|
||||
continue;
|
||||
}
|
||||
|
@ -450,6 +512,7 @@ void SubGraph::CreateParameterForPartialSubGraph(
|
|||
auto node_users = this->belong_anf_->manager()->node_users();
|
||||
if (utils::isa<ParameterPtr>(input)) {
|
||||
auto parameter = utils::cast<ParameterPtr>(input);
|
||||
MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "cast ptr failed");
|
||||
// graph input: create a parameter
|
||||
if (!parameter->has_default()) {
|
||||
auto new_parameter = sub_graph->add_parameter();
|
||||
|
@ -475,9 +538,10 @@ void SubGraph::CreateParameterForPartialSubGraph(
|
|||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void SubGraph::CreateCNodeForPartialSubGraph(
|
||||
int SubGraph::CreateCNodeForPartialSubGraph(
|
||||
const FuncGraphPtr &sub_graph, const std::map<AnfNodePtr, AnfNodePtr> &partial_inputs_and_subgraph_input_map) {
|
||||
MS_ASSERT(sub_graph != nullptr);
|
||||
// move cnode from belong_graph to subgraph
|
||||
|
@ -500,6 +564,7 @@ void SubGraph::CreateCNodeForPartialSubGraph(
|
|||
}
|
||||
this->belong_anf_->DropNode(node);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SubGraph::ApplySubGraph() {
|
||||
|
@ -546,7 +611,10 @@ int SubGraph::ApplySubGraph() {
|
|||
MS_ASSERT(after_partial_cnode != nullptr);
|
||||
subgraph_nodes.insert(after_partial_cnode);
|
||||
subgraph_nodes.insert(call_cnode);
|
||||
this->Reset(subgraph_nodes);
|
||||
if (this->Reset(subgraph_nodes) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Reset failed";
|
||||
return false;
|
||||
}
|
||||
// create subgraph partial // add partial to main subgraph
|
||||
ret = this->CreatePartialInBelongAnf();
|
||||
if (ret != RET_OK) {
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <vector>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "ir/anf.h"
|
||||
|
@ -32,9 +33,11 @@ class SubGraph;
|
|||
using SubGraphPtr = std::shared_ptr<SubGraph>;
|
||||
class SubGraph {
|
||||
public:
|
||||
explicit SubGraph(FuncGraphPtr belong_anf, std::string graph_name = "", const std::set<CNodePtr> &head_nodes = {});
|
||||
explicit SubGraph(FuncGraphPtr belong_anf, std::string graph_name = "")
|
||||
: belong_anf_(std::move(belong_anf)), name_(std::move(graph_name)) {}
|
||||
|
||||
void Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes = {});
|
||||
int Init(const std::set<CNodePtr> &head_nodes = {});
|
||||
int Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes = {});
|
||||
|
||||
bool MergeSubGraph(const SubGraphPtr &subgraph);
|
||||
|
||||
|
@ -48,19 +51,19 @@ class SubGraph {
|
|||
std::set<CNodePtr> GetInputCNodes() const;
|
||||
std::set<CNodePtr> GetOutputCNodes() const;
|
||||
// init subgraph methods
|
||||
void InitSubGraphNode(const std::set<CNodePtr> &head_nodes);
|
||||
void InitSubGraphInNode();
|
||||
void InitSubGraphOutNode();
|
||||
int InitSubGraphNode(const std::set<CNodePtr> &head_nodes);
|
||||
int InitSubGraphInNode();
|
||||
int InitSubGraphOutNode();
|
||||
// merge subgraph methods
|
||||
std::set<CNodePtr> FindCommonOutputs(const SubGraphPtr &subgraph) const;
|
||||
bool IfDependOnSameNode(const SubGraphPtr &subgraph) const;
|
||||
// apply subgraph methods
|
||||
SubGraphPtr FindBeforeSubGraphInBelongAnf() const;
|
||||
SubGraphPtr FindAfterSubGraphInBelongAnf() const;
|
||||
void CreateParameterForPartialSubGraph(const FuncGraphPtr &sub_graph, std::vector<AnfNodePtr> *partial_inputs,
|
||||
std::map<AnfNodePtr, AnfNodePtr> *partial_inputs_and_subgraph_input_map);
|
||||
void CreateCNodeForPartialSubGraph(const FuncGraphPtr &sub_graph,
|
||||
const std::map<AnfNodePtr, AnfNodePtr> &partial_inputs_and_subgraph_input_map);
|
||||
int CreateParameterForPartialSubGraph(const FuncGraphPtr &sub_graph, std::vector<AnfNodePtr> *partial_inputs,
|
||||
std::map<AnfNodePtr, AnfNodePtr> *partial_inputs_and_subgraph_input_map);
|
||||
int CreateCNodeForPartialSubGraph(const FuncGraphPtr &sub_graph,
|
||||
const std::map<AnfNodePtr, AnfNodePtr> &partial_inputs_and_subgraph_input_map);
|
||||
int CreatePartialInBelongAnf();
|
||||
static int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set<CNodePtr> &outputs);
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "mindspore/core/ops/switch.h"
|
||||
#include "mindspore/core/ops/call.h"
|
||||
#include "mindspore/core/ops/fusion/partial_fusion.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -332,6 +333,7 @@ STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
|
|||
size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag) {
|
||||
MS_ASSERT(anf_node != nullptr);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (train_flag &&
|
||||
(opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimAdam))) {
|
||||
return 1;
|
||||
|
@ -350,6 +352,7 @@ bool IsPartialFusion(const AnfNodePtr &node) {
|
|||
}
|
||||
if (node->isa<mindspore::CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cast ptr failed");
|
||||
auto vnode_value = cnode->input(0)->cast<ValueNodePtr>()->value();
|
||||
return GetValue<NamedPtr>(vnode_value)->name() == "PartialFusion";
|
||||
}
|
||||
|
@ -364,6 +367,7 @@ bool IsCall(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cast ptr failed");
|
||||
if (cnode->inputs().empty()) {
|
||||
return false;
|
||||
}
|
||||
|
@ -400,19 +404,25 @@ bool IsMakeTuple(const AnfNodePtr &node) {
|
|||
|
||||
ValueNodePtr GetPartialFusionPrim() {
|
||||
auto partial_prim = std::make_shared<mindspore::ops::PartialFusion>();
|
||||
MS_CHECK_TRUE_MSG(partial_prim != nullptr, nullptr, "partial_prim is nullptr");
|
||||
ValueNodePtr partial_anf_prim = NewValueNode(partial_prim);
|
||||
MS_CHECK_TRUE_MSG(partial_anf_prim != nullptr, nullptr, "partial_anf_prim is nullptr");
|
||||
return partial_anf_prim;
|
||||
}
|
||||
|
||||
ValueNodePtr GetSwitchAnfPrim() {
|
||||
auto switch_prim = std::make_shared<mindspore::ops::Switch>();
|
||||
MS_CHECK_TRUE_MSG(switch_prim != nullptr, nullptr, "switch_prim is nullptr");
|
||||
ValueNodePtr switch_anf_prim = NewValueNode(switch_prim);
|
||||
MS_CHECK_TRUE_MSG(switch_prim != nullptr, nullptr, "switch_prim is nullptr");
|
||||
return switch_anf_prim;
|
||||
}
|
||||
|
||||
ValueNodePtr GetCallAnfPrim() {
|
||||
auto call_prim = std::make_shared<mindspore::ops::Call>();
|
||||
MS_CHECK_TRUE_MSG(call_prim != nullptr, nullptr, "call_prim is nullptr");
|
||||
ValueNodePtr call_anf_prim = NewValueNode(call_prim);
|
||||
MS_CHECK_TRUE_MSG(call_anf_prim != nullptr, nullptr, "call_anf_prim is nullptr");
|
||||
return call_anf_prim;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "mindspore/lite/tools/common/string_util.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
int ReadFileToIfstream(const std::string &file_path, std::ifstream *ifstream) {
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT> &tensor) {
|
||||
|
@ -131,6 +132,7 @@ std::unique_ptr<schema::TensorT> CreateTensorTFromTensorInfo(const tensor::Tenso
|
|||
return nullptr;
|
||||
}
|
||||
auto schema_tensor = std::make_unique<schema::TensorT>();
|
||||
MS_CHECK_TRUE_MSG(schema_tensor != nullptr, nullptr, "schema_tensor is nullptr");
|
||||
schema_tensor->name = tensor_name;
|
||||
auto ret = UpdateTensorTFromTensorInfo(tensor_info, &schema_tensor);
|
||||
if (ret != RET_OK) {
|
||||
|
|
|
@ -37,6 +37,7 @@ static bool IsSupportedNode(const BaseRef &n) {
|
|||
};
|
||||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
auto anf_node = utils::cast<AnfNodePtr>(n);
|
||||
MS_ASSERT(anf_node != nullptr);
|
||||
return std::any_of(support_list.begin(), support_list.end(),
|
||||
[&anf_node](const auto &primitive) { return CheckPrimitiveType(anf_node, primitive); });
|
||||
}
|
||||
|
@ -95,6 +96,7 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens
|
|||
}
|
||||
auto new_return_node = func_graph->NewCNode({NewValueNode(return_prim_ptr), make_tuple_cnode});
|
||||
new_return_node->set_fullname_with_scope(return_cnode->fullname_with_scope());
|
||||
MS_ASSERT(new_return_node != nullptr);
|
||||
func_graph->set_return(new_return_node);
|
||||
MS_ASSERT(new_return_node != nullptr);
|
||||
|
||||
|
@ -103,7 +105,9 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens
|
|||
|
||||
const BaseRef AddTensorArray::DefinePattern() const {
|
||||
auto support_detect = std::make_shared<CondVar>(IsSupportedNode);
|
||||
MS_ASSERT(support_detect != nullptr);
|
||||
auto inputs_var = std::make_shared<SeqVar>();
|
||||
MS_ASSERT(inputs_var != nullptr);
|
||||
return VectorRef({support_detect, inputs_var});
|
||||
}
|
||||
|
||||
|
@ -137,6 +141,7 @@ const AnfNodePtr AddTensorArray::Process(const FuncGraphPtr &func_graph, const A
|
|||
}
|
||||
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
|
||||
MS_ASSERT(abstract_tensor != nullptr);
|
||||
if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) { // input node not complete infershape
|
||||
MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed";
|
||||
return nullptr;
|
||||
|
@ -168,6 +173,7 @@ const AnfNodePtr AddTensorArray::Process(const FuncGraphPtr &func_graph, const A
|
|||
NewValueNode(tensor_array),
|
||||
NewValueNode(kDefaultNumTensors),
|
||||
});
|
||||
MS_ASSERT(tensor_array_node != nullptr);
|
||||
tensor_array_node->set_abstract(abstract->Clone());
|
||||
tensor_array_node->set_fullname_with_scope(cnode->fullname_with_scope() + "_tensor_array");
|
||||
|
||||
|
|
|
@ -168,6 +168,7 @@ int ControlFlowPass::SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow
|
|||
}
|
||||
visited_nodes->insert(node);
|
||||
auto cnode = utils::cast<CNodePtr>(node);
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cast ptr failed");
|
||||
for (size_t i = 0; i < cnode->inputs().size(); i++) {
|
||||
auto input = cnode->input(i);
|
||||
if (visited_nodes->find(input) == visited_nodes->end()) {
|
||||
|
@ -191,6 +192,7 @@ int ControlFlowPass::SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow
|
|||
int ControlFlowPass::CreateAfterGraph(const FuncGraphPtr &main_fg, const std::vector<AnfNodePtr> &remain_nodes,
|
||||
const CNodePtr &aim_cnode, FuncGraphPtr *after_fg) {
|
||||
*after_fg = std::make_shared<FuncGraph>();
|
||||
MS_CHECK_TRUE_MSG(*after_fg != nullptr, lite::RET_NULL_PTR, "*after_fg is nullptr");
|
||||
auto manager = main_fg->manager();
|
||||
manager->AddFuncGraph(*after_fg);
|
||||
(*after_fg)->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
|
||||
|
@ -282,6 +284,7 @@ int ControlFlowPass::CreateWhileCondCallNode(
|
|||
int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, const CNodePtr &while_cnode,
|
||||
CNodePtr *body_partial_node) {
|
||||
auto body_vnode = while_cnode->input(kWhileBodyIndex);
|
||||
MS_CHECK_TRUE_MSG(body_vnode != nullptr, lite::RET_NULL_PTR, "body_vnode is nullptr");
|
||||
auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode);
|
||||
if (body_fg == nullptr) {
|
||||
MS_LOG(ERROR) << "Get value as func_graph failed.";
|
||||
|
@ -469,13 +472,14 @@ int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNo
|
|||
ValueNodePtr switch_anf_primitive = lite::GetSwitchAnfPrim();
|
||||
if (switch_anf_primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "GetSwitchAnfPrim failed.";
|
||||
return false;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
// insert switch node
|
||||
std::vector<AnfNodePtr> switch_node_inputs = {switch_anf_primitive, cond_fg->output(), body_partial_node,
|
||||
after_partial_cnode};
|
||||
auto switch_cnode = cond_fg->NewCNode(switch_node_inputs);
|
||||
MS_CHECK_TRUE_MSG(switch_cnode != nullptr, RET_ERROR, "NewCnode failed");
|
||||
switch_cnode->set_fullname_with_scope("while-Switch-" + cond_fg->get_attr("graph_name")->ToString());
|
||||
|
||||
// insert call node
|
||||
|
@ -590,6 +594,7 @@ int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &i
|
|||
after_partial_cnode_inputs.push_back(then_fg->output());
|
||||
} else {
|
||||
auto then_fg_output = then_fg->output()->cast<CNodePtr>();
|
||||
MS_CHECK_TRUE_MSG(then_fg_output != nullptr, RET_ERROR, "cast ptr failed");
|
||||
for (size_t i = kCNodeFirstInputIndex; i < then_fg_output->inputs().size(); ++i) {
|
||||
after_partial_cnode_inputs.push_back(then_fg_output->input(i));
|
||||
}
|
||||
|
|
|
@ -84,6 +84,7 @@ STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNode
|
|||
return lite::RET_ERROR;
|
||||
}
|
||||
auto cur_node_post = cur_node_user.first->cast<CNodePtr>();
|
||||
MS_CHECK_TRUE_MSG(cur_node_post != nullptr, RET_ERROR, "cast ptr failed");
|
||||
if (middle_nodes->find(cur_node_post) != middle_nodes->end() ||
|
||||
out_nodes->find(cur_node_post) != out_nodes->end()) {
|
||||
continue;
|
||||
|
@ -138,6 +139,7 @@ bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<
|
|||
continue;
|
||||
}
|
||||
auto middle_node_prim = GetValueNode<PrimitivePtr>(middle_cnode->input(0));
|
||||
MS_CHECK_TRUE_MSG(middle_node_prim != nullptr, false, "GetValueNode failed");
|
||||
if (!transpose_strategy.CanChangeOpAxis(func_graph, middle_cnode)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -145,30 +147,30 @@ bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<
|
|||
return true;
|
||||
}
|
||||
|
||||
void ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type,
|
||||
bool train_flag, FormatTransNodeType trans_type) {
|
||||
int ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type,
|
||||
bool train_flag, FormatTransNodeType trans_type) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (utils::isa<CNodePtr>(cnode->input(index))) {
|
||||
return;
|
||||
return lite::RET_OK;
|
||||
}
|
||||
lite::DataInfo data_info;
|
||||
int status;
|
||||
int status = 0;
|
||||
if (utils::isa<ParameterPtr>(cnode->input(index))) {
|
||||
auto input_node = cnode->input(index)->cast<ParameterPtr>();
|
||||
MS_ASSERT(input_node != nullptr);
|
||||
MS_CHECK_TRUE_MSG(input_node != nullptr, lite::RET_ERROR, "input_node is nullptr");
|
||||
if (!input_node->has_default()) {
|
||||
return;
|
||||
return lite::RET_OK;
|
||||
}
|
||||
status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info);
|
||||
} else {
|
||||
status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info);
|
||||
}
|
||||
if (status != lite::RET_OK) {
|
||||
return;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (data_info.shape_.empty() ||
|
||||
(data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
|
||||
return;
|
||||
return lite::RET_OK;
|
||||
}
|
||||
ShapeVector expand_shape(data_info.shape_.begin(), data_info.shape_.end());
|
||||
if (data_info.shape_.size() == 1) {
|
||||
|
@ -180,22 +182,24 @@ void ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode
|
|||
}
|
||||
auto tensor = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), expand_shape,
|
||||
data_info.data_.data(), data_info.data_.size());
|
||||
MS_CHECK_TRUE_MSG(tensor != nullptr, lite::RET_ERROR, "tensor is nullptr");
|
||||
if (trans_type == kNHWC2NCHW) {
|
||||
(void)TransFilterFormat(tensor, schema::Format_KHWC, schema::Format_KCHW);
|
||||
} else {
|
||||
(void)TransFilterFormat(tensor, schema::Format_KCHW, schema::Format_KHWC);
|
||||
}
|
||||
auto param_node = func_graph->add_parameter();
|
||||
MS_CHECK_TRUE_MSG(param_node != nullptr, , "add_parameter failed");
|
||||
MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "add_parameter failed");
|
||||
param_node->set_name(cnode->input(index)->fullname_with_scope());
|
||||
status = lite::InitParameterFromTensorInfo(param_node, tensor);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "init parameter from tensor info failed";
|
||||
return;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto tr = func_graph->manager()->Transact();
|
||||
tr.SetEdge(cnode, index, param_node);
|
||||
tr.Commit();
|
||||
return lite::RET_OK;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -273,6 +277,7 @@ STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph,
|
|||
const std::vector<int> &perm) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
auto prim_node = cnode->input(0);
|
||||
MS_CHECK_TRUE_MSG(prim_node != nullptr, lite::RET_ERROR, "prim_node is nullptr");
|
||||
auto prim = GetValueNode<PrimitivePtr>(prim_node);
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "GetValueNode Failed");
|
||||
auto &specify_nhwc_op_map = GetNHWCOpMap();
|
||||
|
@ -346,7 +351,11 @@ STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph,
|
|||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
ModifyCNodeFormat(cnode, trans_insert_info->pre_);
|
||||
status = ModifyCNodeFormat(cnode, trans_insert_info->pre_);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
status = node_infer_shape_.InferShape(cnode);
|
||||
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "infer shape failed.";
|
||||
|
@ -439,14 +448,22 @@ STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_grap
|
|||
continue;
|
||||
}
|
||||
for (size_t i = 1; i < middle_cnode->size(); ++i) {
|
||||
ConvertTensorToNCOrNH(func_graph, middle_cnode, i, fmk_type_, train_flag_, trans_info.post_);
|
||||
status = ConvertTensorToNCOrNH(func_graph, middle_cnode, i, fmk_type_, train_flag_, trans_info.post_);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvertTensorToNCOrNH failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode, trans_info.post_);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "change op attr failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ModifyCNodeFormat(middle_cnode, trans_info.post_);
|
||||
status = ModifyCNodeFormat(middle_cnode, trans_info.post_);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
status = node_infer_shape_.InferShape(middle_cnode);
|
||||
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "infer shape failed.";
|
||||
|
@ -456,7 +473,7 @@ STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_grap
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||
auto sub_inputs = sub_graph->get_inputs();
|
||||
sub_inputs_map_[sub_graph] = sub_inputs;
|
||||
|
@ -475,6 +492,7 @@ void DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGr
|
|||
MS_ASSERT(out_cnode != nullptr);
|
||||
MS_ASSERT(trans_cnode != nullptr);
|
||||
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
|
||||
MS_CHECK_TRUE_MSG(out_prim != nullptr, lite::RET_ERROR, "GetValueNode failed");
|
||||
if (out_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(kInferDone))) {
|
||||
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
|
||||
}
|
||||
|
@ -499,9 +517,10 @@ void DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGr
|
|||
}
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void DecreaseTransposeAlgo::ResetSubGraphInput() {
|
||||
int DecreaseTransposeAlgo::ResetSubGraphInput() {
|
||||
for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
|
||||
auto &sub_graph = iter->first;
|
||||
auto &sub_inputs = iter->second;
|
||||
|
@ -509,7 +528,7 @@ void DecreaseTransposeAlgo::ResetSubGraphInput() {
|
|||
MS_ASSERT(manager != nullptr);
|
||||
for (auto &sub_input : sub_inputs) {
|
||||
auto param_node = sub_graph->add_parameter();
|
||||
MS_CHECK_TRUE_MSG(param_node != nullptr, , "add parameter failed");
|
||||
MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "add parameter failed");
|
||||
param_node->set_abstract(sub_input->abstract()->Clone());
|
||||
param_node->set_name(sub_input->fullname_with_scope());
|
||||
manager->Replace(sub_input, param_node);
|
||||
|
@ -518,9 +537,10 @@ void DecreaseTransposeAlgo::ResetSubGraphInput() {
|
|||
sub_param_input->set_default_param(nullptr);
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void DecreaseTransposeAlgo::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
int DecreaseTransposeAlgo::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||
auto return_node = sub_graph->get_return();
|
||||
MS_ASSERT(return_node != nullptr);
|
||||
|
@ -548,11 +568,13 @@ void DecreaseTransposeAlgo::SetSubGraphOutput(const CNodePtr &cnode, const FuncG
|
|||
trans_cnode->set_fullname_with_scope(trans_input_name);
|
||||
}
|
||||
return_node->set_inputs(origin_input);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
int DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||
auto return_node = sub_graph->get_return();
|
||||
MS_CHECK_TRUE_MSG(return_node != nullptr, lite::RET_ERROR, "return_node is nullptr");
|
||||
auto origin_inputs = return_node->inputs();
|
||||
lite::RemoveIfDepend(return_node);
|
||||
lite::RemoveIfMakeTuple(return_node);
|
||||
|
@ -560,7 +582,7 @@ void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const Fun
|
|||
bool infer_done = true;
|
||||
for (size_t i = 1; i < return_node->size(); ++i) {
|
||||
auto abstract_base = GetCNodeInputAbstract(return_node, i);
|
||||
MS_CHECK_TRUE_MSG(abstract_base != nullptr, , "GetCNodeInputAbstract failed");
|
||||
MS_CHECK_TRUE_MSG(abstract_base != nullptr, lite::RET_ERROR, "GetCNodeInputAbstract failed");
|
||||
abstract_list.emplace_back(abstract_base->Clone());
|
||||
auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
|
||||
MS_ASSERT(abstract_tensor != nullptr);
|
||||
|
@ -572,11 +594,12 @@ void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const Fun
|
|||
}
|
||||
if (utils::isa<CNodePtr>(return_node->input(i))) {
|
||||
auto input_cnode = return_node->input(i)->cast<CNodePtr>();
|
||||
MS_CHECK_TRUE_MSG(input_cnode != nullptr, lite::RET_ERROR, "input_cnode is nullptr");
|
||||
if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
|
||||
input_cnode = input_cnode->input(1)->cast<CNodePtr>();
|
||||
}
|
||||
auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
|
||||
MS_CHECK_TRUE_MSG(input_prim != nullptr, , "GetValueNode failed");
|
||||
MS_CHECK_TRUE_MSG(input_prim != nullptr, lite::RET_ERROR, "GetValueNode failed");
|
||||
if (input_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(kInferDone))) {
|
||||
infer_done = false;
|
||||
}
|
||||
|
@ -592,22 +615,25 @@ void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const Fun
|
|||
cnode->set_abstract(abstract_list.front());
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, , "GetValueNode Failed");
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "GetValueNode Failed");
|
||||
prim->AddAttr(kInferDone, MakeValue<bool>(infer_done));
|
||||
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void DecreaseTransposeAlgo::ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type) {
|
||||
int DecreaseTransposeAlgo::ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (pre_trans_type == kNONE) {
|
||||
return;
|
||||
return lite::RET_OK;
|
||||
}
|
||||
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_CHECK_TRUE_MSG(primitive != nullptr, , "GetValueNode Failed");
|
||||
MS_CHECK_TRUE_MSG(primitive != nullptr, lite::RET_ERROR, "GetValueNode Failed");
|
||||
if (pre_trans_type == kNHWC2NCHW) {
|
||||
primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NCHW));
|
||||
} else {
|
||||
primitive->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NHWC));
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
|
||||
|
@ -618,7 +644,7 @@ bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &fun
|
|||
return false;
|
||||
}
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
int status;
|
||||
int status = 0;
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
|
@ -634,18 +660,38 @@ bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &fun
|
|||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
SetSubGraphInput(cnode, sub_func_graph);
|
||||
auto ret = SetSubGraphInput(cnode, sub_func_graph);
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "SetSubGraphInput failed";
|
||||
return false;
|
||||
}
|
||||
(void)DecreaseTransposeForSingleOp(sub_func_graph);
|
||||
SetSubGraphOutput(cnode, sub_func_graph);
|
||||
ret = SetSubGraphOutput(cnode, sub_func_graph);
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "SetSubGraphOutput failed";
|
||||
return false;
|
||||
}
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
SetSubGraphInput(cnode, sub_func_graph);
|
||||
ret = SetSubGraphInput(cnode, sub_func_graph);
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "SetSubGraphInput failed";
|
||||
return false;
|
||||
}
|
||||
(void)DecreaseTransposeForSingleOp(sub_func_graph);
|
||||
SetSubGraphOutput(cnode, sub_func_graph);
|
||||
SetSubGraphAbstract(cnode, sub_func_graph);
|
||||
ret = SetSubGraphOutput(cnode, sub_func_graph);
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "SetSubGraphOutput failed";
|
||||
return false;
|
||||
}
|
||||
ret = SetSubGraphAbstract(cnode, sub_func_graph);
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "SetSubGraphAbstract failed";
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
|
@ -726,7 +772,12 @@ bool DecreaseTransposeAlgo::Run(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(ERROR) << "run local trans insert optimizer failed.";
|
||||
return false;
|
||||
}
|
||||
ResetSubGraphInput();
|
||||
|
||||
auto ret = ResetSubGraphInput();
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "ResetSubGraphInput failed.";
|
||||
return false;
|
||||
}
|
||||
// if input format of several ops surrounded only by transpose op all can be NHWC,
|
||||
// we can delete these transpose ops, and at the same time, transform these middle ops.
|
||||
if (!DecreaseTransposeForMultiOp(func_graph)) {
|
||||
|
|
|
@ -51,11 +51,11 @@ class DecreaseTransposeAlgo : public Pass {
|
|||
STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
std::set<CNodePtr> *visit_transposes);
|
||||
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
|
||||
void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void ResetSubGraphInput();
|
||||
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type);
|
||||
int SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
int ResetSubGraphInput();
|
||||
int SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
int SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
int ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type);
|
||||
FmkType fmk_type_{converter::kFmkTypeMs};
|
||||
bool train_flag_{false};
|
||||
NodeInferShape node_infer_shape_;
|
||||
|
|
|
@ -111,7 +111,11 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(ERROR) << "infer shape failed.";
|
||||
return false;
|
||||
}
|
||||
return ResetSubGraphInput();
|
||||
if (ResetSubGraphInput() != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "ResetSubGraphInput failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool InferShapePass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) {
|
||||
|
@ -187,7 +191,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(ERROR) << "subgraph infer shape failed.";
|
||||
return false;
|
||||
}
|
||||
SetSubGraphOutput(cnode, sub_func_graph);
|
||||
if (SetSubGraphOutput(cnode, sub_func_graph)) {
|
||||
MS_LOG(ERROR) << "SetSubGraphOutput failed.";
|
||||
return false;
|
||||
}
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
|
@ -202,7 +209,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(ERROR) << "subgraph infer shape failed.";
|
||||
return false;
|
||||
}
|
||||
SetSubGraphOutput(cnode, sub_func_graph);
|
||||
if (SetSubGraphOutput(cnode, sub_func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "SetSubGraphOutput failed.";
|
||||
return false;
|
||||
}
|
||||
ret = SetSubGraphAbstract(cnode, sub_func_graph);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetSubGraphAbstract failed: " << ret;
|
||||
|
@ -278,7 +288,7 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void InferShapePass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
STATUS InferShapePass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||
auto return_node = sub_graph->get_return();
|
||||
MS_ASSERT(return_node != nullptr);
|
||||
|
@ -307,6 +317,7 @@ void InferShapePass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr
|
|||
trans_cnode->set_fullname_with_scope(trans_input_name);
|
||||
}
|
||||
return_node->set_inputs(origin_input);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
|
@ -360,7 +371,7 @@ STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGrap
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
bool InferShapePass::ResetSubGraphInput() {
|
||||
int InferShapePass::ResetSubGraphInput() {
|
||||
for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
|
||||
auto &sub_graph = iter->first;
|
||||
auto &sub_inputs = iter->second;
|
||||
|
@ -368,7 +379,7 @@ bool InferShapePass::ResetSubGraphInput() {
|
|||
MS_ASSERT(manager != nullptr);
|
||||
for (auto &sub_input : sub_inputs) {
|
||||
auto param_node = sub_graph->add_parameter();
|
||||
MS_CHECK_TRUE_MSG(param_node != nullptr, false, "Add parameter Failed");
|
||||
MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "Add parameter Failed");
|
||||
param_node->set_abstract(sub_input->abstract()->Clone());
|
||||
param_node->set_name(sub_input->fullname_with_scope());
|
||||
manager->Replace(sub_input, param_node);
|
||||
|
@ -377,7 +388,7 @@ bool InferShapePass::ResetSubGraphInput() {
|
|||
sub_param_input->set_default_param(nullptr);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return lite::RET_OK;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,9 +36,9 @@ class InferShapePass : public Pass {
|
|||
bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
|
||||
STATUS InferProcess(const FuncGraphPtr &func_graph);
|
||||
STATUS SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
STATUS SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
STATUS SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
bool ResetSubGraphInput();
|
||||
int ResetSubGraphInput();
|
||||
|
||||
FmkType fmk_type_{converter::kFmkTypeMs};
|
||||
bool train_flag_{false};
|
||||
|
|
|
@ -90,16 +90,16 @@ std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, Format
|
|||
return cur_axes;
|
||||
}
|
||||
|
||||
void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
|
||||
const std::vector<int> &axes, FormatTransNodeType trans_type,
|
||||
NodeInferShape *node_infer_shape) {
|
||||
int TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
|
||||
const std::vector<int> &axes, FormatTransNodeType trans_type,
|
||||
NodeInferShape *node_infer_shape) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
|
||||
if (input_index >= cnode->size() || axes.empty()) {
|
||||
return;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto origin_input = node_infer_shape->GetIntVecInput(cnode, input_index);
|
||||
if (origin_input.size() != axes.size()) {
|
||||
return;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
std::vector<int> cur_input;
|
||||
for (int dim = 0; dim < static_cast<int>(kInputSizeFour); ++dim) {
|
||||
|
@ -119,7 +119,9 @@ void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
|||
}
|
||||
}
|
||||
auto param_node = BuildIntVecParameterNode(func_graph, cur_input, cnode->input(input_index)->fullname_with_scope());
|
||||
MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "BuildIntVecParameterNode failed");
|
||||
func_graph->manager()->Replace(cnode->input(input_index), param_node);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ChangeCommonOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
|
||||
|
@ -266,6 +268,7 @@ STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, Form
|
|||
}
|
||||
int element_num = shape.front();
|
||||
auto prim = GetValueNode<std::shared_ptr<ops::SliceFusion>>(cnode->input(0));
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "GetValueNode failed");
|
||||
std::vector<int> axes;
|
||||
if (prim->GetAttr(ops::kAxes) == nullptr || prim->get_axes().empty()) {
|
||||
for (int index = 0; index < element_num; ++index) {
|
||||
|
@ -314,6 +317,7 @@ STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode
|
|||
auto cur_axes = TransformOpAxesAttr(axes, trans_type);
|
||||
auto param_node =
|
||||
BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope());
|
||||
MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "BuildIntVecParameterNode failed");
|
||||
func_graph->manager()->Replace(cnode->input(kInputIndexFour), param_node);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
|
|
@ -72,6 +72,7 @@ bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
if (CheckPrimitiveType(node, prim::kPrimTranspose)) {
|
||||
auto transpose_cnode = node->cast<CNodePtr>();
|
||||
MS_ASSERT(transpose_cnode != nullptr);
|
||||
if (!CheckPrimitiveType(transpose_cnode->input(kTransposeInput), prim::kPrimConv2DFusion)) {
|
||||
continue;
|
||||
}
|
||||
|
@ -90,6 +91,7 @@ bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) {
|
|||
continue;
|
||||
}
|
||||
auto transpose_cnode = conv_node->input(kTransposeInput)->cast<CNodePtr>();
|
||||
MS_ASSERT(transpose_cnode != nullptr);
|
||||
auto perm = GetTransposePerm(transpose_cnode);
|
||||
if (perm == kPermNHWC) {
|
||||
manager->Replace(transpose_cnode, transpose_cnode->input(1));
|
||||
|
|
Loading…
Reference in New Issue