Do not get shape for monad type

This commit is contained in:
Margaret_wangrui 2021-02-19 15:08:11 +08:00
parent 9f05cc1351
commit 0aaa31764e
2 changed files with 13 additions and 8 deletions

View File

@ -78,12 +78,13 @@ std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
}
for (size_t i = 1; i < node_inputs.size(); ++i) {
auto input = GetRealInput(node_inputs[i]);
if (HasAbstractMonad(input)) {
continue;
}
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
is_parameter.push_back(ParameterRequireGrad(input_parameter));
} else if ((input->isa<CNode>() && !HasAbstractMonad(input)) || IsValueNode<tensor::Tensor>(input) ||
IsValueNode<RefKey>(input)) {
} else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
is_parameter.push_back(false);
}
}
@ -174,6 +175,9 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
// extract input element length
for (auto &input : node_inputs) {
if (HasAbstractMonad(input)) {
continue;
}
if (IsValueNode<RefKey>(input)) {
auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
@ -182,8 +186,7 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
}
inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
} else if ((input->isa<CNode>() && !HasAbstractMonad(input)) || input->isa<Parameter>() ||
IsValueNode<tensor::Tensor>(input)) {
} else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
// extract input shape from parameter and apply node
inputs_type_len.push_back(GetInputsTypeLen(input));
}

View File

@ -1108,7 +1108,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
for (auto input : node->inputs()) {
if (input->isa<CNode>() && HasAbstractMonad(input)) {
if (HasAbstractMonad(input)) {
node_size--;
}
}
@ -1414,6 +1414,9 @@ std::vector<Shapes> ExtractShape(const CNodePtr &node) {
for (size_t i = 1; i < inputs_size; ++i) {
Shapes input_shapes;
AnfNodePtr input = all_inputs[i];
if (HasAbstractMonad(input)) {
continue;
}
if (IsValueNode<RefKey>(input)) {
auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
@ -1424,8 +1427,7 @@ std::vector<Shapes> ExtractShape(const CNodePtr &node) {
std::pair<AnfNodePtr, int64_t> node_pair = std::make_pair(node, SizeToLong(i));
g_RefMap[parameters[0]] = node_pair;
input_shapes = GetRefKeyNodeShape(input, func_graph);
} else if ((input->isa<CNode>() && !HasAbstractMonad(input)) || IsValueNode<Tensor>(input) ||
input->isa<Parameter>() ||
} else if (input->isa<CNode>() || IsValueNode<Tensor>(input) || input->isa<Parameter>() ||
((IsValueNode<ValueList>(input) || IsValueNode<ValueTuple>(input)) && (inputs_size == 2))) {
input_shapes = GetNodeShape(input);
} else {