forked from OSSInnovation/mindspore
spilt valuenode & parameter's tuple output to maketuple
This commit is contained in:
parent
e984f3ecce
commit
d10d1a17f0
|
@ -47,8 +47,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
|||
common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>());
|
||||
common_pm->AddPass(std::make_shared<ConstToAttrStridedSliceGradPass>());
|
||||
common_pm->AddPass(std::make_shared<ConvertConstInputToTensorInput>());
|
||||
common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
|
||||
common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>());
|
||||
common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
|
||||
optimizer->AddPassManager(common_pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
|
|
|
@ -27,86 +27,33 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
bool MakeValueNode(const AnfNodePtr &node) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// create kernel_info fo new value node
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
value_node->set_kernel_info(kernel_info);
|
||||
// create kernel_build_info for new value node
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
// set the format of value_node to DEFAULT_FORMAT
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
// set value node initial device data type = infer data type
|
||||
TypeId infer_data_type;
|
||||
if (AnfAlgo::GetOutputTensorNum(value_node) == 0) {
|
||||
infer_data_type = kTypeUnknown;
|
||||
} else {
|
||||
infer_data_type = AnfAlgo::GetOutputInferDataType(value_node, 0);
|
||||
}
|
||||
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{infer_data_type});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get());
|
||||
return true;
|
||||
}
|
||||
|
||||
void ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node,
|
||||
std::vector<AnfNodePtr> *plant_inputs, std::vector<int> *dyn_input_sizes) {
|
||||
MS_EXCEPTION_IF_NULL(plant_inputs);
|
||||
MS_EXCEPTION_IF_NULL(dyn_input_sizes);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto output_size = AnfAlgo::GetOutputTensorNum(input_node);
|
||||
dyn_input_sizes->push_back(output_size);
|
||||
std::vector<AnfNodePtr> convert_inputs;
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
if (input_node->isa<ValueNode>()) {
|
||||
auto value_node = input_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
convert_inputs = kernel_graph->SplitTupleValueNodeToNodeList(value_node);
|
||||
} else {
|
||||
for (size_t index = 0; index < output_size; ++index) {
|
||||
auto tuple_get_item = CreatTupleGetItemNode(graph, input_node, index);
|
||||
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, index)},
|
||||
{AnfAlgo::GetOutputInferShape(input_node, index)}, tuple_get_item.get());
|
||||
convert_inputs.emplace_back(tuple_get_item);
|
||||
}
|
||||
}
|
||||
(void)std::copy(convert_inputs.begin(), convert_inputs.end(), std::back_inserter(*plant_inputs));
|
||||
}
|
||||
|
||||
void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto &ori_args = cnode_ptr->inputs();
|
||||
if (ori_args.size() < 1) {
|
||||
return;
|
||||
}
|
||||
std::vector<AnfNodePtr> plant_inputs;
|
||||
std::vector<int> dyn_input_sizes;
|
||||
plant_inputs.push_back(ori_args[kAnfPrimitiveIndex]);
|
||||
for (size_t i = 1; i < ori_args.size(); ++i) {
|
||||
auto input_node = ori_args[i];
|
||||
if (IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) {
|
||||
plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode_ptr); ++i) {
|
||||
auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (input_node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
|
||||
auto input_size = AnfAlgo::GetOutputTensorNum(input_node);
|
||||
dyn_input_sizes.push_back(input_size);
|
||||
auto cnode = input_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto inputs = cnode->inputs();
|
||||
for (size_t j = 1; j < inputs.size(); ++j) {
|
||||
MS_EXCEPTION_IF_NULL(inputs[j]);
|
||||
if (IsValueNode<tensor::Tensor>(inputs[j])) {
|
||||
auto success = MakeValueNode(inputs[j]);
|
||||
auto make_tuple = input_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
for (size_t j = 0; j < AnfAlgo::GetInputTensorNum(make_tuple); ++j) {
|
||||
auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
|
||||
MS_EXCEPTION_IF_NULL(dyn_input_node);
|
||||
if (IsValueNode<tensor::Tensor>(dyn_input_node)) {
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto success = kernel_graph->NewValueNode(dyn_input_node->cast<ValueNodePtr>());
|
||||
if (!success) {
|
||||
MS_LOG(WARNING) << "Make value node failed, " << inputs[j]->DebugString();
|
||||
MS_LOG(WARNING) << "Make value node failed, " << dyn_input_node->DebugString();
|
||||
}
|
||||
}
|
||||
plant_inputs.push_back(inputs[j]);
|
||||
plant_inputs.push_back(dyn_input_node);
|
||||
}
|
||||
} else if (input_node->Type() != nullptr && AnfAlgo::IsTupleOutput(input_node)) {
|
||||
ConvertTupleOuputToPlantInputs(graph, input_node, &plant_inputs, &dyn_input_sizes);
|
||||
} else {
|
||||
dyn_input_sizes.push_back(-1);
|
||||
plant_inputs.push_back(input_node);
|
||||
|
@ -139,9 +86,8 @@ const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &fu
|
|||
for (auto &t : todos) {
|
||||
ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast<CNodePtr>());
|
||||
}
|
||||
} else {
|
||||
ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());
|
||||
}
|
||||
ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());
|
||||
return node;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -25,6 +25,38 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
CNodePtr ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (!AnfAlgo::IsTupleOutput(input_node)) {
|
||||
MS_LOG(EXCEPTION) << "Cannot using the function to convert a not tuple output node to maketuple!";
|
||||
}
|
||||
if (input_node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << input_node->DebugString();
|
||||
}
|
||||
std::vector<AnfNodePtr> convert_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto splited_node_list = kernel_graph->SplitTupleOutputNodeToNodeList(input_node);
|
||||
for (const auto &node : splited_node_list) {
|
||||
if (AnfAlgo::IsTupleOutput(node)) {
|
||||
convert_inputs.emplace_back(ConvertTupleOuputToPlantInputs(graph, node));
|
||||
continue;
|
||||
}
|
||||
convert_inputs.emplace_back(node);
|
||||
}
|
||||
|
||||
auto make_tuple = graph->NewCNode(convert_inputs);
|
||||
std::vector<abstract::AbstractBasePtr> abstract_list;
|
||||
auto make_tuple_input_size = AnfAlgo::GetInputTensorNum(make_tuple);
|
||||
for (size_t index = 0; index < make_tuple_input_size; ++index) {
|
||||
auto make_tuple_input = AnfAlgo::GetInputNode(make_tuple, index);
|
||||
MS_EXCEPTION_IF_NULL(make_tuple_input);
|
||||
abstract_list.emplace_back(make_tuple_input->abstract());
|
||||
}
|
||||
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
|
||||
return make_tuple;
|
||||
}
|
||||
|
||||
CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
@ -35,6 +67,7 @@ CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr
|
|||
std::vector<TypeId> types;
|
||||
std::vector<std::vector<size_t>> shapes;
|
||||
std::vector<AnfNodePtr> make_tuple_inputs_list = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
if (input_node->isa<CNode>()) {
|
||||
for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) {
|
||||
make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index));
|
||||
types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index));
|
||||
|
@ -43,11 +76,16 @@ CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr
|
|||
auto make_tuple = graph->NewCNode(make_tuple_inputs_list);
|
||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get());
|
||||
convert_inputs.emplace_back(make_tuple);
|
||||
continue;
|
||||
}
|
||||
convert_inputs.emplace_back(ConvertTupleOuputToPlantInputs(graph, input_node));
|
||||
} else {
|
||||
convert_inputs.push_back(input_node);
|
||||
}
|
||||
}
|
||||
return graph->NewCNode(convert_inputs);
|
||||
auto new_node = graph->NewCNode(convert_inputs);
|
||||
new_node->set_abstract(cnode_ptr->abstract());
|
||||
return new_node;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -79,31 +79,6 @@ std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
|
|||
return real_inputs;
|
||||
}
|
||||
|
||||
AnfNodePtr MakeValueNode(const AnfNodePtr &node) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
|
||||
new_value_node->set_abstract(value_node->abstract());
|
||||
// create kernel_info fo new value node
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
new_value_node->set_kernel_info(kernel_info);
|
||||
// create kernel_build_info for new value node
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
// set the format of value_node to DEFAULT_FORMAT
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
// set value node initial device data type = infer data type
|
||||
std::vector<TypeId> types;
|
||||
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
|
||||
types.push_back(kTypeUnknown);
|
||||
}
|
||||
kernel_build_info_builder->SetOutputsDeviceType(types);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
|
||||
return new_value_node;
|
||||
}
|
||||
|
||||
bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
|
||||
if (left == right) {
|
||||
return true;
|
||||
|
@ -121,6 +96,18 @@ bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
|
|||
return false;
|
||||
}
|
||||
} // namespace
|
||||
AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
|
||||
new_value_node->set_abstract(value_node->abstract());
|
||||
this->SetKernelInfoForNode(new_value_node);
|
||||
return new_value_node;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> KernelGraph::outputs() const {
|
||||
auto graph_output = output();
|
||||
if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
|
||||
|
@ -290,28 +277,10 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
|
||||
CreateKernelInfoFromNewParameter(cnode);
|
||||
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
std::vector<size_t> feature_map_input_indexs;
|
||||
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
|
||||
// then the node's output is a feature map output
|
||||
for (size_t index = 1; index < inputs.size(); ++index) {
|
||||
auto node = inputs[index];
|
||||
if (AnfAlgo::IsFeatureMapOutput(node)) {
|
||||
feature_map_input_indexs.push_back(index);
|
||||
}
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
|
||||
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
|
||||
}
|
||||
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
|
||||
kernel_info->SetFeatureMapFlag(true);
|
||||
}
|
||||
if (AnfAlgo::IsRealKernel(cnode)) {
|
||||
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
|
||||
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
|
||||
}
|
||||
cnode->set_kernel_info(kernel_info);
|
||||
SetKernelInfoForNode(cnode);
|
||||
AnfAlgo::SetGraphId(graph_id_, cnode.get());
|
||||
return cnode;
|
||||
}
|
||||
|
@ -351,6 +320,50 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
|
|||
}
|
||||
}
|
||||
|
||||
void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
node->set_kernel_info(kernel_info);
|
||||
if (node->isa<CNode>()) {
|
||||
std::vector<size_t> feature_map_input_indexs;
|
||||
kernel_info->SetFeatureMapFlag(false);
|
||||
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) {
|
||||
if (AnfAlgo::IsFeatureMapInput(node, index)) {
|
||||
kernel_info->SetFeatureMapFlag(true);
|
||||
feature_map_input_indexs.push_back(index);
|
||||
}
|
||||
}
|
||||
if (AnfAlgo::GetInputTensorNum(node) == 0) {
|
||||
kernel_info->SetFeatureMapFlag(true);
|
||||
}
|
||||
if (AnfAlgo::IsRealKernel(node)) {
|
||||
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
|
||||
// then the node's output is a feature map output
|
||||
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), node);
|
||||
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), node);
|
||||
}
|
||||
return;
|
||||
}
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
// set the format of value_node to DEFAULT_FORMAT
|
||||
std::vector<TypeId> types;
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
if (node->isa<ValueNode>()) {
|
||||
kernel_info->SetFeatureMapFlag(false);
|
||||
types.emplace_back(kTypeUnknown);
|
||||
}
|
||||
if (node->isa<Parameter>()) {
|
||||
auto parameter = node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
bool is_weight = AnfAlgo ::IsParameterWeight(parameter);
|
||||
kernel_info->SetFeatureMapFlag(!is_weight);
|
||||
types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0));
|
||||
}
|
||||
// set parameter initaial device data type
|
||||
kernel_build_info_builder->SetOutputsDeviceType(types);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get());
|
||||
}
|
||||
|
||||
CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto new_cnode = std::make_shared<CNode>(*cnode);
|
||||
|
@ -366,75 +379,97 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
|
|||
}
|
||||
|
||||
ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) {
|
||||
ParameterPtr new_parameter = add_parameter();
|
||||
auto abstract = parameter == nullptr ? std::make_shared<abstract::AbstractNone>() : parameter->abstract();
|
||||
auto new_parameter = NewParameter(abstract);
|
||||
MS_EXCEPTION_IF_NULL(new_parameter);
|
||||
// create kernel_info form new parameter
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
size_t output_tensor_num = 1;
|
||||
// if use default parameter = nullptr,it remarks create a new parameter from no parameter
|
||||
if (parameter == nullptr) {
|
||||
new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>());
|
||||
kernel_info->SetFeatureMapFlag(true);
|
||||
} else {
|
||||
|
||||
// if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter
|
||||
new_parameter->set_abstract(parameter->abstract());
|
||||
if (parameter != nullptr) {
|
||||
new_parameter->set_name(parameter->name());
|
||||
if (AnfAlgo::IsParameterWeight(parameter)) {
|
||||
new_parameter->set_default_param(parameter->default_param());
|
||||
kernel_info->SetFeatureMapFlag(false);
|
||||
} else {
|
||||
kernel_info->SetFeatureMapFlag(true);
|
||||
}
|
||||
}
|
||||
new_parameter->set_kernel_info(kernel_info);
|
||||
// create kernel_build_info for new parameter
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
// create init data type,
|
||||
std::vector<TypeId> init_data_type = {};
|
||||
|
||||
TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, 0);
|
||||
init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type);
|
||||
|
||||
// set the format of parameter to DEFAULT_FORMAT
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT));
|
||||
// set parameter initaial device data type
|
||||
kernel_build_info_builder->SetOutputsDeviceType(init_data_type);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_parameter.get());
|
||||
// create kernel_info form new parameter
|
||||
SetKernelInfoForNode(new_parameter);
|
||||
AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
|
||||
return new_parameter;
|
||||
}
|
||||
|
||||
ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract) {
|
||||
ParameterPtr new_parameter = add_parameter();
|
||||
new_parameter->set_abstract(abstract);
|
||||
MS_EXCEPTION_IF_NULL(new_parameter);
|
||||
// create kernel_info form new parameter
|
||||
SetKernelInfoForNode(new_parameter);
|
||||
AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
|
||||
return new_parameter;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> KernelGraph::SplitTupleParameterToNodeList(const ParameterPtr ¶meter) {
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
std::vector<AnfNodePtr> convert_nodes_list;
|
||||
auto abstract = parameter->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
if (!abstract->isa<abstract::AbstractTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "Multiple output Parameter's output must be a tuple abstract but got " << abstract->ToString();
|
||||
}
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
for (size_t index = 0; index < tuple_abstract->size(); ++index) {
|
||||
auto new_parameter = this->NewParameter((*tuple_abstract)[index]);
|
||||
SetKernelInfoForNode(new_parameter);
|
||||
convert_nodes_list.emplace_back(new_parameter);
|
||||
}
|
||||
auto new_inputs = std::make_shared<std::vector<AnfNodePtr>>();
|
||||
auto old_inputs = inputs();
|
||||
for (const auto &input_node : old_inputs) {
|
||||
if (input_node != parameter) {
|
||||
new_inputs->emplace_back(input_node);
|
||||
continue;
|
||||
}
|
||||
std::copy(convert_nodes_list.begin(), convert_nodes_list.end(), std::back_inserter(*new_inputs));
|
||||
}
|
||||
inputs_ = new_inputs;
|
||||
return convert_nodes_list;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> KernelGraph::SplitTupleOutputNodeToNodeList(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << node->DebugString();
|
||||
}
|
||||
if (node->isa<Parameter>()) {
|
||||
return SplitTupleParameterToNodeList(node->cast<ParameterPtr>());
|
||||
}
|
||||
return SplitTupleValueNodeToNodeList(node->cast<ValueNodePtr>());
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) {
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto node_value = value_node->value();
|
||||
auto output_size = AnfAlgo::GetOutputTensorNum(value_node);
|
||||
std::vector<AnfNodePtr> convert_inputs;
|
||||
if (!node_value->isa<ValueTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString();
|
||||
}
|
||||
auto value_tuple = node_value->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_tuple);
|
||||
if (value_tuple->size() != output_size) {
|
||||
MS_LOG(EXCEPTION) << "Value tuple size" << value_tuple->size()
|
||||
<< " is not mathced with the value node's output size" << output_size;
|
||||
auto abstract = value_node->abstract();
|
||||
if (!abstract->isa<abstract::AbstractTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "Spilted node's output abstract is not type tuple";
|
||||
}
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
if (tuple_abstract->size() != value_tuple->size()) {
|
||||
MS_LOG(EXCEPTION) << "The node output index [" << value_tuple->size() << "]is outof range "
|
||||
<< tuple_abstract->size();
|
||||
}
|
||||
for (size_t index = 0; index < value_tuple->value().size(); ++index) {
|
||||
auto new_value_node = std::make_shared<ValueNode>(value_tuple->value()[index]);
|
||||
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(value_node, index)},
|
||||
{AnfAlgo::GetOutputInferShape(value_node, index)}, new_value_node.get());
|
||||
new_value_node->set_abstract((*tuple_abstract)[index]);
|
||||
AddValueNodeToGraph(new_value_node);
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
new_value_node->set_kernel_info(kernel_info);
|
||||
kernel_info->SetFeatureMapFlag(false);
|
||||
// create kernel_build_info for new value node
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
// set the format of value_node to DEFAULT_FORMAT
|
||||
kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
// set value node initial device data type = infer data type
|
||||
kernel_build_info_builder->SetOutputsDeviceType({kTypeUnknown});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
|
||||
SetKernelInfoForNode(new_value_node);
|
||||
AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
|
||||
AddValueNodeToGraph(new_value_node);
|
||||
convert_inputs.emplace_back(new_value_node);
|
||||
}
|
||||
if (!RemoveValueNodeFromGraph(value_node)) {
|
||||
|
|
|
@ -54,8 +54,10 @@ class KernelGraph : public FuncGraph {
|
|||
void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
|
||||
CNodePtr NewCNode(const CNodePtr &cnode);
|
||||
ParameterPtr NewParameter(const ParameterPtr ¶meter = nullptr);
|
||||
ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract);
|
||||
ValueNodePtr NewValueNode(const ValuePtr &value);
|
||||
ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr);
|
||||
std::vector<AnfNodePtr> SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node);
|
||||
std::vector<AnfNodePtr> SplitTupleOutputNodeToNodeList(const AnfNodePtr &node);
|
||||
void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; }
|
||||
const std::vector<CNodePtr> &execution_order() const { return execution_order_; }
|
||||
void SetExecOrderByDefault();
|
||||
|
@ -166,6 +168,10 @@ class KernelGraph : public FuncGraph {
|
|||
private:
|
||||
// remove value node form graph
|
||||
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
|
||||
void SetKernelInfoForNode(const AnfNodePtr &node) const;
|
||||
std::vector<AnfNodePtr> SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node);
|
||||
std::vector<AnfNodePtr> SplitTupleParameterToNodeList(const ParameterPtr ¶meter);
|
||||
AnfNodePtr MakeValueNode(const AnfNodePtr &node);
|
||||
void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes);
|
||||
// update node edge list
|
||||
|
|
|
@ -60,7 +60,7 @@ TEST_F(KernelGraphTest, NewParameter) {
|
|||
auto anf_graph = std::make_shared<FuncGraph>();
|
||||
auto kernel_graph = std::make_shared<KernelGraph>();
|
||||
// test nullptr as input
|
||||
auto new_paramter = kernel_graph->NewParameter(nullptr);
|
||||
auto new_paramter = kernel_graph->NewParameter();
|
||||
EXPECT_NE(new_paramter, nullptr);
|
||||
EXPECT_TRUE(new_paramter->isa<Parameter>());
|
||||
EXPECT_EQ(AnfAlgo::GetOutputFormat(new_paramter, 0), kOpFormat_DEFAULT);
|
||||
|
|
Loading…
Reference in New Issue