forked from OSSInnovation/mindspore
!3568 support bprop for const in pynative and develop stridedslice and isinstance
Merge pull request !3568 from zhangbuxue/support_bprop_for_const_in_pynative_mode_and_develop_stridedslice
This commit is contained in:
commit
e8aa46af55
|
@ -133,7 +133,9 @@ def while_cond(x):
|
||||||
@constexpr
|
@constexpr
|
||||||
def check_type_same(x_type, base_type):
|
def check_type_same(x_type, base_type):
|
||||||
"""Check x_type is same as base_type."""
|
"""Check x_type is same as base_type."""
|
||||||
return mstype.issubclass_(x_type, base_type)
|
if mstype.issubclass_(x_type, base_type):
|
||||||
|
return True
|
||||||
|
raise TypeError(f"The arg 'x' should be a {base_type}, but got {x_type}.")
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
|
|
|
@ -964,8 +964,8 @@ void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &o
|
||||||
cnode->set_inputs(args);
|
cnode->set_inputs(args);
|
||||||
set_obj_node_map(curr_g_, out_id, cnode);
|
set_obj_node_map(curr_g_, out_id, cnode);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "Graph has no this out: " << out_id;
|
MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id;
|
||||||
return;
|
MakeValueNode(out, out_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EndGraphByOutId(out_id, cell, out, args);
|
EndGraphByOutId(out_id, cell, out, args);
|
||||||
|
|
|
@ -1959,7 +1959,7 @@ def _compute_slicing_length(begin, end, stride, x_shape, i):
|
||||||
if begin >= x_dim:
|
if begin >= x_dim:
|
||||||
# When slicing backward, if begin >= x_dim, set begin = -1, which means start from the last element.
|
# When slicing backward, if begin >= x_dim, set begin = -1, which means start from the last element.
|
||||||
begin = -1
|
begin = -1
|
||||||
if 0 < end < x_dim:
|
if 0 <= end < x_dim:
|
||||||
end += -x_dim
|
end += -x_dim
|
||||||
if end < -x_dim - 1:
|
if end < -x_dim - 1:
|
||||||
# When slicing backward, if end < -x_dim - 1, set end = -x_dim - 1, which means
|
# When slicing backward, if end < -x_dim - 1, set end = -x_dim - 1, which means
|
||||||
|
|
Loading…
Reference in New Issue