!13350 fix dynamic input when the input type is tuple but not MakeTuple

From: @lianliguang
Reviewed-by: @kisnwang, @chujinjin
Signed-off-by: @chujinjin
mindspore-ci-bot authored 2021-03-17 09:03:40 +08:00, committed by Gitee
commit ccdccc9b2f
2 changed files with 33 additions and 19 deletions


@@ -27,6 +27,34 @@
 namespace mindspore {
 namespace opt {
 namespace {
+int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
+                         std::vector<AnfNodePtr> *plant_inputs) {
+  if (!AnfAlgo::IsTupleOutput(tuple_input)) {
+    auto abs = tuple_input->abstract();
+    MS_LOG(WARNING) << "This function only splits inputs whose output type is tuple, but got " << abs->ToString();
+    return -1;
+  }
+  MS_EXCEPTION_IF_NULL(plant_inputs);
+  auto input_size = AnfAlgo::GetOutputTensorNum(tuple_input);
+  if (tuple_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
+    auto make_tuple = tuple_input->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(make_tuple);
+    size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple);
+    for (size_t j = 0; j < tuple_input_num; ++j) {
+      // used by graph kernel
+      auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
+      MS_EXCEPTION_IF_NULL(dyn_input_node);
+      plant_inputs->emplace_back(dyn_input_node);
+    }
+    return input_size;
+  }
+  for (size_t index = 0; index < input_size; ++index) {
+    auto dyn_input_node = CreatTupleGetItemNode(graph, tuple_input, index);
+    plant_inputs->emplace_back(dyn_input_node);
+  }
+  return input_size;
+}
+
 void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
   MS_EXCEPTION_IF_NULL(cnode_ptr);
   MS_EXCEPTION_IF_NULL(graph);
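
The new helper centralizes the flattening: a tuple-typed input produced by MakeTuple contributes its operands directly, while any other tuple-typed producer (for example a multi-output operator) is expanded by creating one TupleGetItem node per element via CreatTupleGetItemNode. A minimal, self-contained sketch of that decision, using plain standard-library stand-ins instead of the real AnfNode/CNode types (ToyNode, SplitTupleInputsToy, and the string encoding below are illustrative only, not MindSpore APIs):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for an IR node: either a MakeTuple whose operands are visible,
// or an opaque node that merely has a tuple type with num_outputs elements.
struct ToyNode {
  std::string name;
  bool is_make_tuple;
  std::vector<std::string> operands;  // meaningful only for MakeTuple
  size_t num_outputs;
};

// Mirrors the shape of SplitTupleInputs: reuse MakeTuple operands when they are
// visible, otherwise synthesize one "TupleGetItem(node, i)" per tuple element.
int64_t SplitTupleInputsToy(const ToyNode &input, std::vector<std::string> *plant_inputs) {
  if (input.is_make_tuple) {
    for (const auto &operand : input.operands) {
      plant_inputs->push_back(operand);
    }
    return static_cast<int64_t>(input.operands.size());
  }
  for (size_t i = 0; i < input.num_outputs; ++i) {
    plant_inputs->push_back("TupleGetItem(" + input.name + ", " + std::to_string(i) + ")");
  }
  return static_cast<int64_t>(input.num_outputs);
}

int main() {
  ToyNode make_tuple{"MakeTuple", true, {"a", "b"}, 2};  // tuple built by MakeTuple
  ToyNode split_out{"split_out", false, {}, 3};          // tuple from a multi-output op (the newly handled case)
  std::vector<std::string> plant_inputs;
  std::cout << SplitTupleInputsToy(make_tuple, &plant_inputs) << "\n";  // 2
  std::cout << SplitTupleInputsToy(split_out, &plant_inputs) << "\n";   // 3
  for (const auto &input : plant_inputs) {
    std::cout << input << "\n";  // a, b, TupleGetItem(split_out, 0..2)
  }
  return 0;
}
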
@@ -41,25 +69,8 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePt
   for (size_t i = 0; i < input_num; ++i) {
     auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i);
     MS_EXCEPTION_IF_NULL(input_node);
-    if (input_node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
-      auto input_size = AnfAlgo::GetOutputTensorNum(input_node);
-      dyn_input_sizes.push_back(input_size);
-      auto make_tuple = input_node->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(make_tuple);
-      size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple);
-      for (size_t j = 0; j < tuple_input_num; ++j) {
-        auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
-        MS_EXCEPTION_IF_NULL(dyn_input_node);
-        if (IsValueNode<tensor::Tensor>(dyn_input_node)) {
-          auto kernel_graph = graph->cast<KernelGraphPtr>();
-          MS_EXCEPTION_IF_NULL(kernel_graph);
-          auto success = kernel_graph->NewValueNode(dyn_input_node->cast<ValueNodePtr>());
-          if (!success) {
-            MS_LOG(WARNING) << "Make value node failed, " << dyn_input_node->DebugString();
-          }
-        }
-        plant_inputs.push_back(dyn_input_node);
-      }
+    if (AnfAlgo::IsTupleOutput(input_node)) {
+      dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs));
     } else {
       dyn_input_sizes.push_back(-1);
       plant_inputs.push_back(input_node);

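With the helper in place, ConvertMakeTupleInputToPlantInputs no longer keys on the MakeTuple primitive: any input whose inferred type is a tuple is flattened through SplitTupleInputs, and every other input records -1 in dyn_input_sizes. As a worked example (node names illustrative): for a CNode with inputs [x (a tensor), MakeTuple(a, b), y (the 3-element tuple output of a multi-output op)], the pass would produce plant_inputs = [x, a, b, TupleGetItem(y, 0), TupleGetItem(y, 1), TupleGetItem(y, 2)] and dyn_input_sizes = [-1, 2, 3]. Before this commit, the y case fell into the else branch and was passed through unflattened, which is the dynamic-input bug the title refers to.
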

@@ -1156,6 +1156,9 @@ uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
 bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
   MS_EXCEPTION_IF_NULL(anf);
   TypePtr type = anf->Type();
+  if (type == nullptr) {
+    return false;
+  }
   MS_EXCEPTION_IF_NULL(type);
   return type->isa<Tuple>();
 }
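
The second file's change guards IsTupleOutput against nodes whose Type() is still nullptr (presumably nodes whose type has not been inferred yet), so such nodes now report "not a tuple" instead of tripping MS_EXCEPTION_IF_NULL; the MS_EXCEPTION_IF_NULL(type) that follows is kept but can no longer fire for the null case. A minimal stand-alone illustration of the null-safe pattern, with a std::shared_ptr standing in for TypePtr (ToyType and the "Tuple" tag are illustrative, not MindSpore types):

#include <iostream>
#include <memory>
#include <string>

// Toy stand-in for a type object; a real TypePtr would be checked with isa<Tuple>().
struct ToyType {
  std::string tag;
};

// Null-safe query: an un-inferred (null) type is reported as "not a tuple"
// instead of raising an exception.
bool IsTupleOutputToy(const std::shared_ptr<ToyType> &type) {
  if (type == nullptr) {
    return false;
  }
  return type->tag == "Tuple";
}

int main() {
  std::shared_ptr<ToyType> not_inferred;  // models anf->Type() returning nullptr
  auto tuple_type = std::make_shared<ToyType>(ToyType{"Tuple"});
  std::cout << std::boolalpha << IsTupleOutputToy(not_inferred) << "\n";  // false
  std::cout << std::boolalpha << IsTupleOutputToy(tuple_type) << "\n";    // true
  return 0;
}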