forked from mindspore-Ecosystem/mindspore
Do not get shape for monad type
This commit is contained in:
parent
9f05cc1351
commit
0aaa31764e
|
@ -78,12 +78,13 @@ std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
|
|||
}
|
||||
for (size_t i = 1; i < node_inputs.size(); ++i) {
|
||||
auto input = GetRealInput(node_inputs[i]);
|
||||
|
||||
if (HasAbstractMonad(input)) {
|
||||
continue;
|
||||
}
|
||||
if (input->isa<Parameter>()) {
|
||||
auto input_parameter = input->cast<ParameterPtr>();
|
||||
is_parameter.push_back(ParameterRequireGrad(input_parameter));
|
||||
} else if ((input->isa<CNode>() && !HasAbstractMonad(input)) || IsValueNode<tensor::Tensor>(input) ||
|
||||
IsValueNode<RefKey>(input)) {
|
||||
} else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
|
||||
is_parameter.push_back(false);
|
||||
}
|
||||
}
|
||||
|
@ -174,6 +175,9 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
|
|||
|
||||
// extract input element length
|
||||
for (auto &input : node_inputs) {
|
||||
if (HasAbstractMonad(input)) {
|
||||
continue;
|
||||
}
|
||||
if (IsValueNode<RefKey>(input)) {
|
||||
auto func_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -182,8 +186,7 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
|
|||
MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
|
||||
}
|
||||
inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
|
||||
} else if ((input->isa<CNode>() && !HasAbstractMonad(input)) || input->isa<Parameter>() ||
|
||||
IsValueNode<tensor::Tensor>(input)) {
|
||||
} else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
|
||||
// extract input shape from parameter and apply node
|
||||
inputs_type_len.push_back(GetInputsTypeLen(input));
|
||||
}
|
||||
|
|
|
@ -1108,7 +1108,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
|
|||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
for (auto input : node->inputs()) {
|
||||
if (input->isa<CNode>() && HasAbstractMonad(input)) {
|
||||
if (HasAbstractMonad(input)) {
|
||||
node_size--;
|
||||
}
|
||||
}
|
||||
|
@ -1414,6 +1414,9 @@ std::vector<Shapes> ExtractShape(const CNodePtr &node) {
|
|||
for (size_t i = 1; i < inputs_size; ++i) {
|
||||
Shapes input_shapes;
|
||||
AnfNodePtr input = all_inputs[i];
|
||||
if (HasAbstractMonad(input)) {
|
||||
continue;
|
||||
}
|
||||
if (IsValueNode<RefKey>(input)) {
|
||||
auto func_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -1424,8 +1427,7 @@ std::vector<Shapes> ExtractShape(const CNodePtr &node) {
|
|||
std::pair<AnfNodePtr, int64_t> node_pair = std::make_pair(node, SizeToLong(i));
|
||||
g_RefMap[parameters[0]] = node_pair;
|
||||
input_shapes = GetRefKeyNodeShape(input, func_graph);
|
||||
} else if ((input->isa<CNode>() && !HasAbstractMonad(input)) || IsValueNode<Tensor>(input) ||
|
||||
input->isa<Parameter>() ||
|
||||
} else if (input->isa<CNode>() || IsValueNode<Tensor>(input) || input->isa<Parameter>() ||
|
||||
((IsValueNode<ValueList>(input) || IsValueNode<ValueTuple>(input)) && (inputs_size == 2))) {
|
||||
input_shapes = GetNodeShape(input);
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue