forked from mindspore-Ecosystem/mindspore
!13350 fix dynamic input when the input type is tuple but not maketuple
From: @lianliguang Reviewed-by: @kisnwang,@chujinjin Signed-off-by: @chujinjin
This commit is contained in:
commit
ccdccc9b2f
|
@ -27,6 +27,34 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
|
||||
std::vector<AnfNodePtr> *plant_inputs) {
|
||||
if (!AnfAlgo::IsTupleOutput(tuple_input)) {
|
||||
auto abs = tuple_input->abstract();
|
||||
MS_LOG(WARNING) << "The Function only split the output type is tuple type but got" << abs->ToString();
|
||||
return -1;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(plant_inputs);
|
||||
auto input_size = AnfAlgo::GetOutputTensorNum(tuple_input);
|
||||
if (tuple_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
|
||||
auto make_tuple = tuple_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple);
|
||||
for (size_t j = 0; j < tuple_input_num; ++j) {
|
||||
// using for graph kernel
|
||||
auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
|
||||
MS_EXCEPTION_IF_NULL(dyn_input_node);
|
||||
plant_inputs->emplace_back(dyn_input_node);
|
||||
}
|
||||
return input_size;
|
||||
}
|
||||
for (size_t index = 0; index < input_size; ++index) {
|
||||
auto dyn_input_node = CreatTupleGetItemNode(graph, tuple_input, index);
|
||||
plant_inputs->emplace_back(dyn_input_node);
|
||||
}
|
||||
return input_size;
|
||||
}
|
||||
|
||||
void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
@ -41,25 +69,8 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePt
|
|||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (input_node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
|
||||
auto input_size = AnfAlgo::GetOutputTensorNum(input_node);
|
||||
dyn_input_sizes.push_back(input_size);
|
||||
auto make_tuple = input_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple);
|
||||
for (size_t j = 0; j < tuple_input_num; ++j) {
|
||||
auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
|
||||
MS_EXCEPTION_IF_NULL(dyn_input_node);
|
||||
if (IsValueNode<tensor::Tensor>(dyn_input_node)) {
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto success = kernel_graph->NewValueNode(dyn_input_node->cast<ValueNodePtr>());
|
||||
if (!success) {
|
||||
MS_LOG(WARNING) << "Make value node failed, " << dyn_input_node->DebugString();
|
||||
}
|
||||
}
|
||||
plant_inputs.push_back(dyn_input_node);
|
||||
}
|
||||
if (AnfAlgo::IsTupleOutput(input_node)) {
|
||||
dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs));
|
||||
} else {
|
||||
dyn_input_sizes.push_back(-1);
|
||||
plant_inputs.push_back(input_node);
|
||||
|
|
|
@ -1156,6 +1156,9 @@ uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
|
|||
bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
TypePtr type = anf->Type();
|
||||
if (type == nullptr) {
|
||||
return false;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
return type->isa<Tuple>();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue