forked from mindspore-Ecosystem/mindspore
Split tuple parameters into individual parameters
Add a function that transforms tuple nodes into MakeTuple nodes
This commit is contained in:
parent
49053e7f83
commit
1b6f85dec8
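Overview of the change: the old KernelGraph::SplitTupleParameterToNodeList / SplitTupleValueNodeToNodeList helpers, which flattened a tuple parameter or value node into a plain list of nodes, are removed and replaced by a recursive KernelGraph::TransTupleToMakeTuple that rewrites any tuple-typed Parameter, ValueNode, or CNode into a MakeTuple whose leaves are non-tuple nodes; the ConvertTupleOutputToMaketuple pass then rewrites tuple-typed inputs of a CNode in place. The sketch below is a minimal standalone model of that recursion (toy types only, not the MindSpore API; all names here are illustrative):

// Toy model, not MindSpore code: mirrors the idea behind
// KernelGraph::TransParameterTuple / TransTupleToMakeTuple, where a
// tuple-typed abstract is rewritten into a MakeTuple tree of non-tuple leaves.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Abstract {                    // stand-in for abstract::AbstractBase
  std::vector<Abstract> elements;    // non-empty => this abstract is a tuple
  bool IsTuple() const { return !elements.empty(); }
};

struct Node {                        // stand-in for an AnfNode
  std::string kind;                  // "Parameter" or "MakeTuple"
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

// Create one fresh parameter per scalar leaf and wrap every tuple level
// in a MakeTuple node, one level per recursive call.
NodePtr TransParameterTuple(const Abstract &abstract) {
  if (!abstract.IsTuple()) {
    return std::make_shared<Node>(Node{"Parameter", {}});
  }
  auto make_tuple = std::make_shared<Node>(Node{"MakeTuple", {}});
  for (const auto &element : abstract.elements) {
    make_tuple->inputs.push_back(TransParameterTuple(element));
  }
  return make_tuple;
}

void Dump(const NodePtr &node, int depth = 0) {
  std::cout << std::string(depth * 2, ' ') << node->kind << "\n";
  for (const auto &input : node->inputs) {
    Dump(input, depth + 1);
  }
}

int main() {
  // ((scalar, scalar), scalar) -> MakeTuple(MakeTuple(Parameter, Parameter), Parameter)
  Abstract nested{{Abstract{{Abstract{}, Abstract{}}}, Abstract{}}};
  Dump(TransParameterTuple(nested));
  return 0;
}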
@@ -46,8 +46,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
   auto common_pm = std::make_shared<PassManager>("common_pm");
   common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>());
   common_pm->AddPass(std::make_shared<ConstToAttrStridedSliceGradPass>());
-  common_pm->AddPass(std::make_shared<ConvertConstInputToTensorInput>());
   common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>());
+  common_pm->AddPass(std::make_shared<ConvertConstInputToTensorInput>());
   common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
   optimizer->AddPassManager(common_pm);
   (void)optimizer->Optimize(kernel_graph);
@@ -139,7 +139,10 @@ AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) {

 const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                          const EquivPtr &) const {
-  if (node == nullptr || func_graph == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
+  if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
+    return nullptr;
+  }
+  if (!node->isa<CNode>()) {
     return nullptr;
   }
   if (AnfAlgo::IsGraphKernel(node)) {
@@ -17,6 +17,7 @@

 #include <algorithm>
 #include <memory>
+#include <unordered_map>

 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/optimizer/common/helper.h"
@@ -25,68 +26,26 @@
 namespace mindspore {
 namespace opt {
 namespace {
-CNodePtr ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node) {
+AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf,
+                                        std::unordered_map<AnfNodePtr, AnfNodePtr> *transed_nodes) {
+  MS_EXCEPTION_IF_NULL(tuple_anf);
   MS_EXCEPTION_IF_NULL(graph);
-  if (!AnfAlgo::IsTupleOutput(input_node)) {
-    MS_LOG(EXCEPTION) << "Cannot using the function to convert a not tuple output node to maketuple!";
+  MS_EXCEPTION_IF_NULL(transed_nodes);
+  if (!AnfAlgo::IsTupleOutput(tuple_anf)) {
+    return tuple_anf;
   }
-  if (input_node->isa<CNode>()) {
-    MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << input_node->DebugString();
+  auto transed_node_it = transed_nodes->find(tuple_anf);
+  if (transed_node_it != transed_nodes->end()) {
+    return transed_node_it->second;
   }
-  std::vector<AnfNodePtr> convert_inputs = {NewValueNode(prim::kPrimMakeTuple)};
   auto kernel_graph = graph->cast<KernelGraphPtr>();
-  MS_EXCEPTION_IF_NULL(kernel_graph);
-  auto splited_node_list = kernel_graph->SplitTupleOutputNodeToNodeList(input_node);
-  for (const auto &node : splited_node_list) {
-    if (AnfAlgo::IsTupleOutput(node)) {
-      convert_inputs.emplace_back(ConvertTupleOuputToPlantInputs(graph, node));
-      continue;
-    }
-    convert_inputs.emplace_back(node);
-  }
-
-  auto make_tuple = graph->NewCNode(convert_inputs);
-  std::vector<abstract::AbstractBasePtr> abstract_list;
-  auto make_tuple_input_size = AnfAlgo::GetInputTensorNum(make_tuple);
-  for (size_t index = 0; index < make_tuple_input_size; ++index) {
-    auto make_tuple_input = AnfAlgo::GetInputNode(make_tuple, index);
-    MS_EXCEPTION_IF_NULL(make_tuple_input);
-    abstract_list.emplace_back(make_tuple_input->abstract());
-  }
-  make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
+  auto make_tuple = kernel_graph->TransTupleToMakeTuple(tuple_anf);
+  (*transed_nodes)[tuple_anf] = make_tuple;
+  // replace graph inputs if input is a parameter
+  kernel_graph->ReplaceGraphInput(tuple_anf, make_tuple);
   return make_tuple;
 }

-CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
-  MS_EXCEPTION_IF_NULL(cnode_ptr);
-  MS_EXCEPTION_IF_NULL(graph);
-  std::vector<AnfNodePtr> convert_inputs = {cnode_ptr->input(0)};
-  for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode_ptr); ++index) {
-    auto input_node = AnfAlgo::GetInputNode(cnode_ptr, index);
-    if (AnfAlgo::IsTupleOutput(input_node)) {
-      std::vector<TypeId> types;
-      std::vector<std::vector<size_t>> shapes;
-      std::vector<AnfNodePtr> make_tuple_inputs_list = {NewValueNode(prim::kPrimMakeTuple)};
-      if (input_node->isa<CNode>()) {
-        for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) {
-          make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index));
-          types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index));
-          shapes.emplace_back(AnfAlgo::GetOutputInferShape(input_node, tuple_out_index));
-        }
-        auto make_tuple = graph->NewCNode(make_tuple_inputs_list);
-        AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get());
-        convert_inputs.emplace_back(make_tuple);
-        continue;
-      }
-      convert_inputs.emplace_back(ConvertTupleOuputToPlantInputs(graph, input_node));
-    } else {
-      convert_inputs.push_back(input_node);
-    }
-  }
-  auto new_node = graph->NewCNode(convert_inputs);
-  new_node->set_abstract(cnode_ptr->abstract());
-  return new_node;
-}
 } // namespace

 const BaseRef ConvertTupleOutputToMaketuple::DefinePattern() const {
@@ -102,15 +61,22 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
   }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
+  std::unordered_map<AnfNodePtr, AnfNodePtr> transed_nodes;
   if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
     return nullptr;
   }
-  if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) {
-        return node->Type() != nullptr && AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node);
-      })) {
-    return ConvertTupleInputToMakeTuple(func_graph, cnode);
-  }
-  return nullptr;
+  bool cnode_input_changed = false;
+  for (size_t i = 0; i < cnode->inputs().size(); ++i) {
+    const auto &input = cnode->inputs()[i];
+    if (input->Type() != nullptr && AnfAlgo::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) &&
+        !AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) {
+      cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input, &transed_nodes));
+      cnode_input_changed = true;
+    }
+  }
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  return cnode_input_changed ? kernel_graph->NewCNode(cnode) : nullptr;
 }
 } // namespace opt
 } // namespace mindspore
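The rewritten pass above threads a transed_nodes map through ConvertTupleInputToMakeTuple so that a tuple node feeding several inputs of the same CNode is converted only once and the replacement is reused. A minimal standalone sketch of that caching pattern (toy types only, not the MindSpore API):

// Toy model, not MindSpore code: build the replacement on first use and
// return the memoized node on every later request for the same input.
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Node {
  std::string kind;
  bool is_tuple;
};
using NodePtr = std::shared_ptr<Node>;

NodePtr GetOrCreateMakeTuple(const NodePtr &tuple_node,
                             std::unordered_map<NodePtr, NodePtr> *cache) {
  auto it = cache->find(tuple_node);
  if (it != cache->end()) {
    return it->second;  // already converted: reuse the same MakeTuple
  }
  auto make_tuple = std::make_shared<Node>(Node{"MakeTuple", false});
  (*cache)[tuple_node] = make_tuple;
  return make_tuple;
}

int main() {
  auto tuple_out = std::make_shared<Node>(Node{"TupleKernel", true});
  std::vector<NodePtr> inputs = {tuple_out, tuple_out};  // same tuple used twice
  std::unordered_map<NodePtr, NodePtr> cache;
  for (auto &input : inputs) {
    if (input->is_tuple) {
      input = GetOrCreateMakeTuple(input, &cache);
    }
  }
  assert(inputs[0] == inputs[1]);  // both slots share one replacement node
  return 0;
}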
@@ -1810,7 +1810,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
       // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output
       // auto multi_output_param = graph->NewParameter();
       auto origin_inputs = graph->inputs();
-      auto output_param = CreateNewParameterFromCNode(node, true, graph.get().get());
+      auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
       MS_EXCEPTION_IF_NULL(graph->MutableInputs());
       graph->MutableInputs()->operator=(origin_inputs);
       graph->AddChildGraphResult(output_param);
@@ -1828,9 +1828,8 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
         if (child_graph->get_output_null()) {
           continue;
         }
-        auto graph_output = child_graph->output();
-        AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr, NOT_NULL(graph_output),
-                                                         NOT_NULL(output_param));
+        AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr,
+                                                         NOT_NULL(child_graph->output()), NOT_NULL(output_param));
       }
     }
   }
@@ -406,78 +406,6 @@ ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract
   return new_parameter;
 }

-std::vector<AnfNodePtr> KernelGraph::SplitTupleParameterToNodeList(const ParameterPtr &parameter) {
-  MS_EXCEPTION_IF_NULL(parameter);
-  std::vector<AnfNodePtr> convert_nodes_list;
-  auto abstract = parameter->abstract();
-  MS_EXCEPTION_IF_NULL(abstract);
-  if (!abstract->isa<abstract::AbstractTuple>()) {
-    MS_LOG(EXCEPTION) << "Multiple output Parameter's output must be a tuple abstract but got " << abstract->ToString();
-  }
-  auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
-  MS_EXCEPTION_IF_NULL(tuple_abstract);
-  for (size_t index = 0; index < tuple_abstract->size(); ++index) {
-    auto new_parameter = this->NewParameter((*tuple_abstract)[index]);
-    SetKernelInfoForNode(new_parameter);
-    convert_nodes_list.emplace_back(new_parameter);
-  }
-  auto new_inputs = std::make_shared<std::vector<AnfNodePtr>>();
-  auto old_inputs = inputs();
-  for (const auto &input_node : old_inputs) {
-    if (input_node != parameter) {
-      new_inputs->emplace_back(input_node);
-      continue;
-    }
-    std::copy(convert_nodes_list.begin(), convert_nodes_list.end(), std::back_inserter(*new_inputs));
-  }
-  inputs_ = new_inputs;
-  return convert_nodes_list;
-}
-
-std::vector<AnfNodePtr> KernelGraph::SplitTupleOutputNodeToNodeList(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (node->isa<CNode>()) {
-    MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << node->DebugString();
-  }
-  if (node->isa<Parameter>()) {
-    return SplitTupleParameterToNodeList(node->cast<ParameterPtr>());
-  }
-  return SplitTupleValueNodeToNodeList(node->cast<ValueNodePtr>());
-}
-
-std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) {
-  MS_EXCEPTION_IF_NULL(value_node);
-  auto node_value = value_node->value();
-  std::vector<AnfNodePtr> convert_inputs;
-  if (!node_value->isa<ValueTuple>()) {
-    MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString();
-  }
-  auto value_tuple = node_value->cast<ValueTuplePtr>();
-  MS_EXCEPTION_IF_NULL(value_tuple);
-  auto abstract = value_node->abstract();
-  if (!abstract->isa<abstract::AbstractTuple>()) {
-    MS_LOG(EXCEPTION) << "Spilted node's output abstract is not type tuple";
-  }
-  auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
-  MS_EXCEPTION_IF_NULL(tuple_abstract);
-  if (tuple_abstract->size() != value_tuple->size()) {
-    MS_LOG(EXCEPTION) << "The node output index [" << value_tuple->size() << "]is outof range "
-                      << tuple_abstract->size();
-  }
-  for (size_t index = 0; index < value_tuple->value().size(); ++index) {
-    auto new_value_node = std::make_shared<ValueNode>(value_tuple->value()[index]);
-    new_value_node->set_abstract((*tuple_abstract)[index]);
-    AddValueNodeToGraph(new_value_node);
-    SetKernelInfoForNode(new_value_node);
-    AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
-    convert_inputs.emplace_back(new_value_node);
-  }
-  if (!RemoveValueNodeFromGraph(value_node)) {
-    MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString();
-  }
-  return convert_inputs;
-}
-
 ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
   MS_EXCEPTION_IF_NULL(value_node);
   auto new_value_node = MakeValueNode(value_node)->cast<ValueNodePtr>();
@@ -485,6 +413,110 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
   return new_value_node;
 }

+ValueNodePtr KernelGraph::NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value) {
+  MS_EXCEPTION_IF_NULL(abstract);
+  MS_EXCEPTION_IF_NULL(value);
+  ValueNodePtr new_value_node = std::make_shared<ValueNode>(value);
+  new_value_node->set_abstract(abstract);
+  SetKernelInfoForNode(new_value_node);
+  AnfAlgo::SetGraphId(graph_id(), new_value_node.get());
+  return new_value_node;
+}
+
+AnfNodePtr KernelGraph::TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value) {
+  MS_EXCEPTION_IF_NULL(abstract);
+  MS_EXCEPTION_IF_NULL(value);
+  if (!abstract->isa<abstract::AbstractTuple>()) {
+    auto new_value_node = NewValueNode(abstract, value);
+    AddValueNodeToGraph(new_value_node);
+    return new_value_node;
+  }
+  auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
+  auto value_tuple = value->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(tuple_abstract);
+  MS_EXCEPTION_IF_NULL(value_tuple);
+  if (tuple_abstract->size() != value_tuple->size()) {
+    MS_LOG(EXCEPTION) << "Abstract size:" << tuple_abstract->size()
+                      << " is not equal to value size:" << value_tuple->size();
+  }
+  std::vector<AnfNodePtr> make_tuple_inputs = {
+    mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
+  for (size_t index = 0; index < tuple_abstract->size(); ++index) {
+    make_tuple_inputs.push_back(TransValueNodeTuple((*tuple_abstract)[index], (*value_tuple)[index]));
+  }
+  auto make_tuple = NewCNode(make_tuple_inputs);
+  make_tuple->set_abstract(tuple_abstract);
+  return make_tuple;
+}
+
+AnfNodePtr KernelGraph::TransParameterTuple(const AbstractBasePtr &abstract) {
+  MS_EXCEPTION_IF_NULL(abstract);
+  if (!abstract->isa<abstract::AbstractTuple>()) {
+    return NewParameter(abstract);
+  }
+  auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
+  MS_EXCEPTION_IF_NULL(tuple_abstract);
+  std::vector<AnfNodePtr> make_tuple_inputs = {
+    mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
+  for (size_t index = 0; index < tuple_abstract->size(); ++index) {
+    make_tuple_inputs.push_back(TransParameterTuple((*tuple_abstract)[index]));
+  }
+  auto make_tuple = NewCNode(make_tuple_inputs);
+  make_tuple->set_abstract(tuple_abstract);
+  return make_tuple;
+}
+
+AnfNodePtr KernelGraph::CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx) {
+  auto idx = mindspore::NewValueNode(SizeToInt(output_idx));
+  MS_EXCEPTION_IF_NULL(idx);
+  auto imm = std::make_shared<Int32Imm>(SizeToInt(output_idx));
+  auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
+  idx->set_abstract(abstract_scalar);
+  AnfNodePtr tuple_getitem = NewCNode({mindspore::NewValueNode(prim::kPrimTupleGetItem), node, idx});
+  MS_EXCEPTION_IF_NULL(tuple_getitem);
+  tuple_getitem->set_scope(node->scope());
+  std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
+  TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx);
+  AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
+  return tuple_getitem;
+}
+
+AnfNodePtr KernelGraph::TransCNodeTuple(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  std::vector<TypeId> types;
+  std::vector<std::vector<size_t>> shapes;
+  std::vector<AnfNodePtr> make_tuple_inputs_list = {mindspore::NewValueNode(prim::kPrimMakeTuple)};
+  for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(node); ++tuple_out_index) {
+    make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(node, tuple_out_index));
+    types.push_back(AnfAlgo::GetOutputInferDataType(node, tuple_out_index));
+    shapes.emplace_back(AnfAlgo::GetOutputInferShape(node, tuple_out_index));
+  }
+  auto make_tuple = NewCNode(make_tuple_inputs_list);
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get());
+  return make_tuple;
+}
+
+AnfNodePtr KernelGraph::TransTupleToMakeTuple(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!AnfAlgo::IsTupleOutput(node)) {
+    return node;
+  }
+  if (node->isa<Parameter>()) {
+    return TransParameterTuple(node->abstract());
+  } else if (node->isa<ValueNode>()) {
+    auto value_node = node->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value_node);
+    auto make_tuple = TransValueNodeTuple(value_node->abstract(), value_node->value());
+    if (RemoveValueNodeFromGraph(value_node)) {
+      MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString();
+    }
+    return make_tuple;
+  } else if (node->isa<CNode>()) {
+    return TransCNodeTuple(node->cast<CNodePtr>());
+  }
+  MS_LOG(EXCEPTION) << "Unexpected node:" << node->DebugString();
+}
+
 const std::vector<AnfNodePtr> &KernelGraph::inputs() const {
   MS_EXCEPTION_IF_NULL(inputs_);
   return *inputs_;
@@ -782,6 +814,23 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
   return false;
 }

+void KernelGraph::ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter) {
+  // update graph inputs
+  MS_EXCEPTION_IF_NULL(old_parameter);
+  MS_EXCEPTION_IF_NULL(new_parameter);
+  if (old_parameter == new_parameter) {
+    return;
+  }
+  for (size_t i = 0; i < inputs_->size(); i++) {
+    if ((*inputs_)[i] == old_parameter) {
+      MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_parameter->DebugString()
+                   << ",new graph input:" << new_parameter->DebugString();
+      (*inputs_)[i] = new_parameter;
+      break;
+    }
+  }
+}
+
 void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node) {
   MS_EXCEPTION_IF_NULL(inputs_);
   {
@@ -805,15 +854,7 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP
         output_cnode->set_input(i, new_anf_node);
       }
     }
-    // update graph inputs
-    for (size_t i = 0; i < inputs_->size(); i++) {
-      if ((*inputs_)[i] == old_anf_node.get()) {
-        MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString()
-                     << ",new graph input:" << new_anf_node->DebugString();
-        (*inputs_)[i] = new_anf_node.get();
-        break;
-      }
-    }
+    ReplaceGraphInput(old_anf_node, new_anf_node);
   }
   // update front to backend map
   FrontBackendlMapUpdate(old_anf_node, new_anf_node);
@@ -49,15 +49,17 @@ class KernelGraph : public FuncGraph {

   const std::vector<AnfNodePtr> &inputs() const;
   std::vector<AnfNodePtr> *MutableInputs() const { return inputs_.get(); }
+  void ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter);
   std::vector<AnfNodePtr> outputs() const;
   CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override;
   void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
   CNodePtr NewCNode(const CNodePtr &cnode);
   ParameterPtr NewParameter(const ParameterPtr &parameter = nullptr);
   ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract);
-  ValueNodePtr NewValueNode(const ValuePtr &value);
+  ValueNodePtr NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value);
   ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr);
-  std::vector<AnfNodePtr> SplitTupleOutputNodeToNodeList(const AnfNodePtr &node);
+  // trans tuple output to maketuple + no_tuple out
+  AnfNodePtr TransTupleToMakeTuple(const AnfNodePtr &node);
   void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; }
   const std::vector<CNodePtr> &execution_order() const { return execution_order_; }
   void SetExecOrderByDefault();
@@ -167,8 +169,6 @@ class KernelGraph : public FuncGraph {
   // remove value node form graph
   bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
   void SetKernelInfoForNode(const AnfNodePtr &node) const;
-  std::vector<AnfNodePtr> SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node);
-  std::vector<AnfNodePtr> SplitTupleParameterToNodeList(const ParameterPtr &parameter);
   AnfNodePtr MakeValueNode(const AnfNodePtr &node);
   void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
                             std::unordered_set<AnfNodePtr> *visited_nodes);
@@ -181,6 +181,10 @@ class KernelGraph : public FuncGraph {
   bool HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
                                std::unordered_set<AnfNodePtr> *visited_nodes);
   void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends);
+  AnfNodePtr TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value);
+  AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract);
+  AnfNodePtr TransCNodeTuple(const CNodePtr &node);
+  AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx);

   std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
   std::vector<AnfNodePtr> child_graph_result_;
@@ -99,13 +99,18 @@ TEST_F(TestHWConstInputToTensorInput, test_value_tuple_tensor_input) {
   EXPECT_NE(ret->input(1)->cast<CNodePtr>(), nullptr);
   auto cnode = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
   EXPECT_EQ(AnfAlgo::GetCNodeName(cnode), prim::kPrimDropoutGenMask->name());
-  auto input1 = cnode->input(1);
-  ASSERT_TRUE(input1 != nullptr);
-  EXPECT_TRUE(IsValueNode<tensor::Tensor>(input1));
-  auto tensor = input1->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>();
-  ASSERT_TRUE(tensor != nullptr);
-  auto data = tensor->data_c();
-  EXPECT_EQ(std::vector<int>((int *)data, (int *)data + 4), std::vector<int>({2, 4, 2, 2}));
+  std::vector<int> out;
+  for (size_t i = 1; i <= 4; i++) {
+    auto input = cnode->input(i);
+    ASSERT_TRUE(input != nullptr);
+    EXPECT_TRUE(IsValueNode<tensor::Tensor>(input));
+    auto tensor = input->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>();
+    ASSERT_TRUE(tensor != nullptr);
+    int *data = (int *)(tensor->data_c());
+    ASSERT_TRUE(data != nullptr);
+    out.push_back(*data);
+  }
+  EXPECT_EQ(out, std::vector<int>({2, 4, 2, 2}));
 }
 } // namespace opt
 } // namespace mindspore