forked from mindspore-Ecosystem/mindspore
fix_auto_parallel_find_loss_bug
This commit is contained in:
parent
77978d0921
commit
6dd2c75948
|
@ -1852,14 +1852,14 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
|
|||
if (pre_cnode == nullptr) {
|
||||
return loss_node_info;
|
||||
}
|
||||
pre_cnode = HandleDependLoss(pre_cnode);
|
||||
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
||||
auto prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
||||
// return -> cast
|
||||
if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
|
||||
if (prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
|
||||
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(pre_cnode);
|
||||
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
||||
}
|
||||
pre_cnode = HandleDependLoss(pre_cnode);
|
||||
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
||||
|
||||
// notice: the GetNext op has not input
|
||||
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
|
||||
|
@ -2416,6 +2416,12 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG
|
|||
// shape op doesn't have params and attrs.
|
||||
OperatorParams params;
|
||||
OperatorAttrs attrs;
|
||||
auto shape_value = GetValueNode(node->input(2))->cast<ValueSequeuePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_value);
|
||||
auto shape = shape_value->value();
|
||||
if (shape.empty()) {
|
||||
return;
|
||||
}
|
||||
OperatorArgs args = std::make_pair(attrs, params);
|
||||
Operator op = std::make_pair(SHAPE_OP, args);
|
||||
InsertNode(op, node, 2, pre_node, root, "shape");
|
||||
|
|
Loading…
Reference in New Issue