forked from mindspore-Ecosystem/mindspore
!7022 [AutoParallel]Fix find loss and root reshape bug
Merge pull request !7022 from lichen/fix_auto_parallel_find_loss_bug
This commit is contained in:
commit
b9df01b60e
|
@ -1852,14 +1852,14 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
|
||||||
if (pre_cnode == nullptr) {
|
if (pre_cnode == nullptr) {
|
||||||
return loss_node_info;
|
return loss_node_info;
|
||||||
}
|
}
|
||||||
pre_cnode = HandleDependLoss(pre_cnode);
|
auto prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
||||||
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
|
||||||
// return -> cast
|
// return -> cast
|
||||||
if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
|
if (prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
|
||||||
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
|
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(pre_cnode);
|
MS_EXCEPTION_IF_NULL(pre_cnode);
|
||||||
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
|
||||||
}
|
}
|
||||||
|
pre_cnode = HandleDependLoss(pre_cnode);
|
||||||
|
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
||||||
|
|
||||||
// notice: the GetNext op has not input
|
// notice: the GetNext op has not input
|
||||||
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
|
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
|
||||||
|
@ -2416,6 +2416,12 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG
|
||||||
// shape op doesn't have params and attrs.
|
// shape op doesn't have params and attrs.
|
||||||
OperatorParams params;
|
OperatorParams params;
|
||||||
OperatorAttrs attrs;
|
OperatorAttrs attrs;
|
||||||
|
auto shape_value = GetValueNode(node->input(2))->cast<ValueSequeuePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(shape_value);
|
||||||
|
auto shape = shape_value->value();
|
||||||
|
if (shape.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
OperatorArgs args = std::make_pair(attrs, params);
|
OperatorArgs args = std::make_pair(attrs, params);
|
||||||
Operator op = std::make_pair(SHAPE_OP, args);
|
Operator op = std::make_pair(SHAPE_OP, args);
|
||||||
InsertNode(op, node, 2, pre_node, root, "shape");
|
InsertNode(op, node, 2, pre_node, root, "shape");
|
||||||
|
|
Loading…
Reference in New Issue