spilt valuenode & parameter's tuple output to maketuple

This commit is contained in:
WilliamLian 2020-07-18 14:34:38 +08:00
parent e984f3ecce
commit d10d1a17f0
6 changed files with 196 additions and 171 deletions

View File

@ -47,8 +47,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>()); common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>());
common_pm->AddPass(std::make_shared<ConstToAttrStridedSliceGradPass>()); common_pm->AddPass(std::make_shared<ConstToAttrStridedSliceGradPass>());
common_pm->AddPass(std::make_shared<ConvertConstInputToTensorInput>()); common_pm->AddPass(std::make_shared<ConvertConstInputToTensorInput>());
common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>()); common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>());
common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
optimizer->AddPassManager(common_pm); optimizer->AddPassManager(common_pm);
(void)optimizer->Optimize(kernel_graph); (void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault(); kernel_graph->SetExecOrderByDefault();

View File

@ -27,86 +27,33 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
bool MakeValueNode(const AnfNodePtr &node) {
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return false;
}
// create kernel_info fo new value node
auto kernel_info = std::make_shared<device::KernelInfo>();
value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
// set value node initial device data type = infer data type
TypeId infer_data_type;
if (AnfAlgo::GetOutputTensorNum(value_node) == 0) {
infer_data_type = kTypeUnknown;
} else {
infer_data_type = AnfAlgo::GetOutputInferDataType(value_node, 0);
}
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{infer_data_type});
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get());
return true;
}
void ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node,
std::vector<AnfNodePtr> *plant_inputs, std::vector<int> *dyn_input_sizes) {
MS_EXCEPTION_IF_NULL(plant_inputs);
MS_EXCEPTION_IF_NULL(dyn_input_sizes);
MS_EXCEPTION_IF_NULL(graph);
auto output_size = AnfAlgo::GetOutputTensorNum(input_node);
dyn_input_sizes->push_back(output_size);
std::vector<AnfNodePtr> convert_inputs;
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
if (input_node->isa<ValueNode>()) {
auto value_node = input_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
convert_inputs = kernel_graph->SplitTupleValueNodeToNodeList(value_node);
} else {
for (size_t index = 0; index < output_size; ++index) {
auto tuple_get_item = CreatTupleGetItemNode(graph, input_node, index);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, index)},
{AnfAlgo::GetOutputInferShape(input_node, index)}, tuple_get_item.get());
convert_inputs.emplace_back(tuple_get_item);
}
}
(void)std::copy(convert_inputs.begin(), convert_inputs.end(), std::back_inserter(*plant_inputs));
}
void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
MS_EXCEPTION_IF_NULL(cnode_ptr); MS_EXCEPTION_IF_NULL(cnode_ptr);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
auto &ori_args = cnode_ptr->inputs();
if (ori_args.size() < 1) {
return;
}
std::vector<AnfNodePtr> plant_inputs; std::vector<AnfNodePtr> plant_inputs;
std::vector<int> dyn_input_sizes; std::vector<int> dyn_input_sizes;
plant_inputs.push_back(ori_args[kAnfPrimitiveIndex]); plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
for (size_t i = 1; i < ori_args.size(); ++i) { for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode_ptr); ++i) {
auto input_node = ori_args[i]; auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i);
if (IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) { MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
auto input_size = AnfAlgo::GetOutputTensorNum(input_node); auto input_size = AnfAlgo::GetOutputTensorNum(input_node);
dyn_input_sizes.push_back(input_size); dyn_input_sizes.push_back(input_size);
auto cnode = input_node->cast<CNodePtr>(); auto make_tuple = input_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(make_tuple);
auto inputs = cnode->inputs(); for (size_t j = 0; j < AnfAlgo::GetInputTensorNum(make_tuple); ++j) {
for (size_t j = 1; j < inputs.size(); ++j) { auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
MS_EXCEPTION_IF_NULL(inputs[j]); MS_EXCEPTION_IF_NULL(dyn_input_node);
if (IsValueNode<tensor::Tensor>(inputs[j])) { if (IsValueNode<tensor::Tensor>(dyn_input_node)) {
auto success = MakeValueNode(inputs[j]); auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto success = kernel_graph->NewValueNode(dyn_input_node->cast<ValueNodePtr>());
if (!success) { if (!success) {
MS_LOG(WARNING) << "Make value node failed, " << inputs[j]->DebugString(); MS_LOG(WARNING) << "Make value node failed, " << dyn_input_node->DebugString();
} }
} }
plant_inputs.push_back(inputs[j]); plant_inputs.push_back(dyn_input_node);
} }
} else if (input_node->Type() != nullptr && AnfAlgo::IsTupleOutput(input_node)) {
ConvertTupleOuputToPlantInputs(graph, input_node, &plant_inputs, &dyn_input_sizes);
} else { } else {
dyn_input_sizes.push_back(-1); dyn_input_sizes.push_back(-1);
plant_inputs.push_back(input_node); plant_inputs.push_back(input_node);
@ -139,9 +86,8 @@ const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &fu
for (auto &t : todos) { for (auto &t : todos) {
ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast<CNodePtr>()); ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast<CNodePtr>());
} }
} else {
ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());
} }
ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());
return node; return node;
} }
} // namespace opt } // namespace opt

View File

@ -25,6 +25,38 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
CNodePtr ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node) {
MS_EXCEPTION_IF_NULL(graph);
if (!AnfAlgo::IsTupleOutput(input_node)) {
MS_LOG(EXCEPTION) << "Cannot using the function to convert a not tuple output node to maketuple!";
}
if (input_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << input_node->DebugString();
}
std::vector<AnfNodePtr> convert_inputs = {NewValueNode(prim::kPrimMakeTuple)};
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto splited_node_list = kernel_graph->SplitTupleOutputNodeToNodeList(input_node);
for (const auto &node : splited_node_list) {
if (AnfAlgo::IsTupleOutput(node)) {
convert_inputs.emplace_back(ConvertTupleOuputToPlantInputs(graph, node));
continue;
}
convert_inputs.emplace_back(node);
}
auto make_tuple = graph->NewCNode(convert_inputs);
std::vector<abstract::AbstractBasePtr> abstract_list;
auto make_tuple_input_size = AnfAlgo::GetInputTensorNum(make_tuple);
for (size_t index = 0; index < make_tuple_input_size; ++index) {
auto make_tuple_input = AnfAlgo::GetInputNode(make_tuple, index);
MS_EXCEPTION_IF_NULL(make_tuple_input);
abstract_list.emplace_back(make_tuple_input->abstract());
}
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
return make_tuple;
}
CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
MS_EXCEPTION_IF_NULL(cnode_ptr); MS_EXCEPTION_IF_NULL(cnode_ptr);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
@ -35,6 +67,7 @@ CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr
std::vector<TypeId> types; std::vector<TypeId> types;
std::vector<std::vector<size_t>> shapes; std::vector<std::vector<size_t>> shapes;
std::vector<AnfNodePtr> make_tuple_inputs_list = {NewValueNode(prim::kPrimMakeTuple)}; std::vector<AnfNodePtr> make_tuple_inputs_list = {NewValueNode(prim::kPrimMakeTuple)};
if (input_node->isa<CNode>()) {
for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) { for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) {
make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index)); make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index));
types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index)); types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index));
@ -43,11 +76,16 @@ CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr
auto make_tuple = graph->NewCNode(make_tuple_inputs_list); auto make_tuple = graph->NewCNode(make_tuple_inputs_list);
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get()); AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get());
convert_inputs.emplace_back(make_tuple); convert_inputs.emplace_back(make_tuple);
continue;
}
convert_inputs.emplace_back(ConvertTupleOuputToPlantInputs(graph, input_node));
} else { } else {
convert_inputs.push_back(input_node); convert_inputs.push_back(input_node);
} }
} }
return graph->NewCNode(convert_inputs); auto new_node = graph->NewCNode(convert_inputs);
new_node->set_abstract(cnode_ptr->abstract());
return new_node;
} }
} // namespace } // namespace

View File

@ -79,31 +79,6 @@ std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
return real_inputs; return real_inputs;
} }
AnfNodePtr MakeValueNode(const AnfNodePtr &node) {
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return nullptr;
}
ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
new_value_node->set_abstract(value_node->abstract());
// create kernel_info fo new value node
auto kernel_info = std::make_shared<device::KernelInfo>();
new_value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
// set value node initial device data type = infer data type
std::vector<TypeId> types;
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
types.push_back(kTypeUnknown);
}
kernel_build_info_builder->SetOutputsDeviceType(types);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
return new_value_node;
}
bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) { bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
if (left == right) { if (left == right) {
return true; return true;
@ -121,6 +96,18 @@ bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
return false; return false;
} }
} // namespace } // namespace
AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) {
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return nullptr;
}
ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
new_value_node->set_abstract(value_node->abstract());
this->SetKernelInfoForNode(new_value_node);
return new_value_node;
}
std::vector<AnfNodePtr> KernelGraph::outputs() const { std::vector<AnfNodePtr> KernelGraph::outputs() const {
auto graph_output = output(); auto graph_output = output();
if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) { if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
@ -290,28 +277,10 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
cnode->set_abstract(std::make_shared<abstract::AbstractNone>()); cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
CreateKernelInfoFromNewParameter(cnode); CreateKernelInfoFromNewParameter(cnode);
auto kernel_info = std::make_shared<device::KernelInfo>();
std::vector<size_t> feature_map_input_indexs;
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
for (size_t index = 1; index < inputs.size(); ++index) {
auto node = inputs[index];
if (AnfAlgo::IsFeatureMapOutput(node)) {
feature_map_input_indexs.push_back(index);
}
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) { if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
} }
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { SetKernelInfoForNode(cnode);
kernel_info->SetFeatureMapFlag(true);
}
if (AnfAlgo::IsRealKernel(cnode)) {
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
}
cnode->set_kernel_info(kernel_info);
AnfAlgo::SetGraphId(graph_id_, cnode.get()); AnfAlgo::SetGraphId(graph_id_, cnode.get());
return cnode; return cnode;
} }
@ -351,6 +320,50 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
} }
} }
void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = std::make_shared<device::KernelInfo>();
node->set_kernel_info(kernel_info);
if (node->isa<CNode>()) {
std::vector<size_t> feature_map_input_indexs;
kernel_info->SetFeatureMapFlag(false);
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) {
if (AnfAlgo::IsFeatureMapInput(node, index)) {
kernel_info->SetFeatureMapFlag(true);
feature_map_input_indexs.push_back(index);
}
}
if (AnfAlgo::GetInputTensorNum(node) == 0) {
kernel_info->SetFeatureMapFlag(true);
}
if (AnfAlgo::IsRealKernel(node)) {
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), node);
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), node);
}
return;
}
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set the format of value_node to DEFAULT_FORMAT
std::vector<TypeId> types;
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
if (node->isa<ValueNode>()) {
kernel_info->SetFeatureMapFlag(false);
types.emplace_back(kTypeUnknown);
}
if (node->isa<Parameter>()) {
auto parameter = node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
bool is_weight = AnfAlgo ::IsParameterWeight(parameter);
kernel_info->SetFeatureMapFlag(!is_weight);
types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0));
}
// set parameter initaial device data type
kernel_build_info_builder->SetOutputsDeviceType(types);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get());
}
CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
auto new_cnode = std::make_shared<CNode>(*cnode); auto new_cnode = std::make_shared<CNode>(*cnode);
@ -366,75 +379,97 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
} }
ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) { ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
ParameterPtr new_parameter = add_parameter(); auto abstract = parameter == nullptr ? std::make_shared<abstract::AbstractNone>() : parameter->abstract();
auto new_parameter = NewParameter(abstract);
MS_EXCEPTION_IF_NULL(new_parameter); MS_EXCEPTION_IF_NULL(new_parameter);
// create kernel_info form new parameter
auto kernel_info = std::make_shared<device::KernelInfo>();
size_t output_tensor_num = 1;
// if use default parameter = nullptr,it remarks create a new parameter from no parameter
if (parameter == nullptr) {
new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>());
kernel_info->SetFeatureMapFlag(true);
} else {
// if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter
new_parameter->set_abstract(parameter->abstract()); if (parameter != nullptr) {
new_parameter->set_name(parameter->name()); new_parameter->set_name(parameter->name());
if (AnfAlgo::IsParameterWeight(parameter)) { if (AnfAlgo::IsParameterWeight(parameter)) {
new_parameter->set_default_param(parameter->default_param()); new_parameter->set_default_param(parameter->default_param());
kernel_info->SetFeatureMapFlag(false);
} else {
kernel_info->SetFeatureMapFlag(true);
} }
} }
new_parameter->set_kernel_info(kernel_info); // create kernel_info form new parameter
// create kernel_build_info for new parameter SetKernelInfoForNode(new_parameter);
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// create init data type,
std::vector<TypeId> init_data_type = {};
TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, 0);
init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type);
// set the format of parameter to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT));
// set parameter initaial device data type
kernel_build_info_builder->SetOutputsDeviceType(init_data_type);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_parameter.get());
AnfAlgo::SetGraphId(graph_id_, new_parameter.get()); AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
return new_parameter; return new_parameter;
} }
ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract) {
ParameterPtr new_parameter = add_parameter();
new_parameter->set_abstract(abstract);
MS_EXCEPTION_IF_NULL(new_parameter);
// create kernel_info form new parameter
SetKernelInfoForNode(new_parameter);
AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
return new_parameter;
}
std::vector<AnfNodePtr> KernelGraph::SplitTupleParameterToNodeList(const ParameterPtr &parameter) {
MS_EXCEPTION_IF_NULL(parameter);
std::vector<AnfNodePtr> convert_nodes_list;
auto abstract = parameter->abstract();
MS_EXCEPTION_IF_NULL(abstract);
if (!abstract->isa<abstract::AbstractTuple>()) {
MS_LOG(EXCEPTION) << "Multiple output Parameter's output must be a tuple abstract but got " << abstract->ToString();
}
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
for (size_t index = 0; index < tuple_abstract->size(); ++index) {
auto new_parameter = this->NewParameter((*tuple_abstract)[index]);
SetKernelInfoForNode(new_parameter);
convert_nodes_list.emplace_back(new_parameter);
}
auto new_inputs = std::make_shared<std::vector<AnfNodePtr>>();
auto old_inputs = inputs();
for (const auto &input_node : old_inputs) {
if (input_node != parameter) {
new_inputs->emplace_back(input_node);
continue;
}
std::copy(convert_nodes_list.begin(), convert_nodes_list.end(), std::back_inserter(*new_inputs));
}
inputs_ = new_inputs;
return convert_nodes_list;
}
std::vector<AnfNodePtr> KernelGraph::SplitTupleOutputNodeToNodeList(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << node->DebugString();
}
if (node->isa<Parameter>()) {
return SplitTupleParameterToNodeList(node->cast<ParameterPtr>());
}
return SplitTupleValueNodeToNodeList(node->cast<ValueNodePtr>());
}
std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) { std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) {
MS_EXCEPTION_IF_NULL(value_node); MS_EXCEPTION_IF_NULL(value_node);
auto node_value = value_node->value(); auto node_value = value_node->value();
auto output_size = AnfAlgo::GetOutputTensorNum(value_node);
std::vector<AnfNodePtr> convert_inputs; std::vector<AnfNodePtr> convert_inputs;
if (!node_value->isa<ValueTuple>()) { if (!node_value->isa<ValueTuple>()) {
MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString(); MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString();
} }
auto value_tuple = node_value->cast<ValueTuplePtr>(); auto value_tuple = node_value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple); MS_EXCEPTION_IF_NULL(value_tuple);
if (value_tuple->size() != output_size) { auto abstract = value_node->abstract();
MS_LOG(EXCEPTION) << "Value tuple size" << value_tuple->size() if (!abstract->isa<abstract::AbstractTuple>()) {
<< " is not mathced with the value node's output size" << output_size; MS_LOG(EXCEPTION) << "Spilted node's output abstract is not type tuple";
}
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
if (tuple_abstract->size() != value_tuple->size()) {
MS_LOG(EXCEPTION) << "The node output index [" << value_tuple->size() << "]is outof range "
<< tuple_abstract->size();
} }
for (size_t index = 0; index < value_tuple->value().size(); ++index) { for (size_t index = 0; index < value_tuple->value().size(); ++index) {
auto new_value_node = std::make_shared<ValueNode>(value_tuple->value()[index]); auto new_value_node = std::make_shared<ValueNode>(value_tuple->value()[index]);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(value_node, index)}, new_value_node->set_abstract((*tuple_abstract)[index]);
{AnfAlgo::GetOutputInferShape(value_node, index)}, new_value_node.get());
AddValueNodeToGraph(new_value_node); AddValueNodeToGraph(new_value_node);
auto kernel_info = std::make_shared<device::KernelInfo>(); SetKernelInfoForNode(new_value_node);
new_value_node->set_kernel_info(kernel_info);
kernel_info->SetFeatureMapFlag(false);
// create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
// set value node initial device data type = infer data type
kernel_build_info_builder->SetOutputsDeviceType({kTypeUnknown});
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
AddValueNodeToGraph(new_value_node);
convert_inputs.emplace_back(new_value_node); convert_inputs.emplace_back(new_value_node);
} }
if (!RemoveValueNodeFromGraph(value_node)) { if (!RemoveValueNodeFromGraph(value_node)) {

View File

@ -54,8 +54,10 @@ class KernelGraph : public FuncGraph {
void CreateKernelInfoFromNewParameter(const CNodePtr &cnode); void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
CNodePtr NewCNode(const CNodePtr &cnode); CNodePtr NewCNode(const CNodePtr &cnode);
ParameterPtr NewParameter(const ParameterPtr &parameter = nullptr); ParameterPtr NewParameter(const ParameterPtr &parameter = nullptr);
ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract);
ValueNodePtr NewValueNode(const ValuePtr &value);
ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr); ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr);
std::vector<AnfNodePtr> SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node); std::vector<AnfNodePtr> SplitTupleOutputNodeToNodeList(const AnfNodePtr &node);
void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; } void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; }
const std::vector<CNodePtr> &execution_order() const { return execution_order_; } const std::vector<CNodePtr> &execution_order() const { return execution_order_; }
void SetExecOrderByDefault(); void SetExecOrderByDefault();
@ -166,6 +168,10 @@ class KernelGraph : public FuncGraph {
private: private:
// remove value node form graph // remove value node form graph
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
void SetKernelInfoForNode(const AnfNodePtr &node) const;
std::vector<AnfNodePtr> SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node);
std::vector<AnfNodePtr> SplitTupleParameterToNodeList(const ParameterPtr &parameter);
AnfNodePtr MakeValueNode(const AnfNodePtr &node);
void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
std::unordered_set<AnfNodePtr> *visited_nodes); std::unordered_set<AnfNodePtr> *visited_nodes);
// update node edge list // update node edge list

View File

@ -60,7 +60,7 @@ TEST_F(KernelGraphTest, NewParameter) {
auto anf_graph = std::make_shared<FuncGraph>(); auto anf_graph = std::make_shared<FuncGraph>();
auto kernel_graph = std::make_shared<KernelGraph>(); auto kernel_graph = std::make_shared<KernelGraph>();
// test nullptr as input // test nullptr as input
auto new_paramter = kernel_graph->NewParameter(nullptr); auto new_paramter = kernel_graph->NewParameter();
EXPECT_NE(new_paramter, nullptr); EXPECT_NE(new_paramter, nullptr);
EXPECT_TRUE(new_paramter->isa<Parameter>()); EXPECT_TRUE(new_paramter->isa<Parameter>());
EXPECT_EQ(AnfAlgo::GetOutputFormat(new_paramter, 0), kOpFormat_DEFAULT); EXPECT_EQ(AnfAlgo::GetOutputFormat(new_paramter, 0), kOpFormat_DEFAULT);