forked from mindspore-Ecosystem/mindspore
!5641 check user define bprop when there is parameter in nested network
Merge pull request !5641 from zhangbuxue/check_user_define_bprop_when_there_is_parameter_in_nested_network
This commit is contained in:
commit
120ab80b23
|
@ -1182,6 +1182,11 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
|
||||||
resource_->manager()->AddFuncGraph(curr_g_);
|
resource_->manager()->AddFuncGraph(curr_g_);
|
||||||
// custom bprop debug
|
// custom bprop debug
|
||||||
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
||||||
|
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
|
||||||
|
if (par_number > 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number
|
||||||
|
<< " parameters that is not supported in the net.";
|
||||||
|
}
|
||||||
MS_LOG(DEBUG) << "Use cell custom bprop function.";
|
MS_LOG(DEBUG) << "Use cell custom bprop function.";
|
||||||
FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
|
FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
|
||||||
if (bprop_graph != nullptr) {
|
if (bprop_graph != nullptr) {
|
||||||
|
|
|
@ -93,7 +93,7 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
|
||||||
for (size_t i = 0; i < grads.size(); i++) {
|
for (size_t i = 0; i < grads.size(); i++) {
|
||||||
if (py::isinstance<tensor::Tensor>(py_args[i])) {
|
if (py::isinstance<tensor::Tensor>(py_args[i])) {
|
||||||
if (!py::isinstance<tensor::Tensor>(grads[i])) {
|
if (!py::isinstance<tensor::Tensor>(grads[i])) {
|
||||||
MS_EXCEPTION(ValueError) << "For user define net bprop, the gradient of the " << i
|
MS_EXCEPTION(ValueError) << "When user defines the net bprop,, the gradient of the " << i
|
||||||
<< "th arg should be Tensor, but got "
|
<< "th arg should be Tensor, but got "
|
||||||
<< py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
|
<< py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
|
||||||
<< ", and the value is " << py::cast<py::str>(grads[i]) << ".";
|
<< ", and the value is " << py::cast<py::str>(grads[i]) << ".";
|
||||||
|
|
|
@ -472,7 +472,7 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
|
||||||
|
|
||||||
void FinalVM::SyncData(const py::object &arg) {
|
void FinalVM::SyncData(const py::object &arg) {
|
||||||
if (py::isinstance<py::tuple>(arg)) {
|
if (py::isinstance<py::tuple>(arg)) {
|
||||||
py::tuple arg_list = py::cast<py::tuple>(arg);
|
auto arg_list = py::cast<py::tuple>(arg);
|
||||||
for (size_t i = 0; i < arg_list.size(); i++) {
|
for (size_t i = 0; i < arg_list.size(); i++) {
|
||||||
SyncData(arg_list[i]);
|
SyncData(arg_list[i]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -2464,12 +2464,7 @@ raise_set = [
|
||||||
('StridedSlice_1', {
|
('StridedSlice_1', {
|
||||||
'block': (P.StridedSlice(), {'exception': ValueError}),
|
'block': (P.StridedSlice(), {'exception': ValueError}),
|
||||||
'desc_const': [(1, 2, 3), (3, 4, 5), (1, 1)],
|
'desc_const': [(1, 2, 3), (3, 4, 5), (1, 1)],
|
||||||
'desc_inputs': [[4, 5, 6, 7]]}),
|
'desc_inputs': [[4, 5, 6, 7]]})
|
||||||
('StridedSlice_2', {
|
|
||||||
'block': (P.StridedSlice(), {'exception': ValueError}),
|
|
||||||
'desc_const': [(1, 2, 3), (3, 4, 5), (1, 1, 0)],
|
|
||||||
'desc_inputs': [[4, 5, 6, 7]]}),
|
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -177,7 +177,7 @@ def test_user_define_bprop_check_parameter():
|
||||||
grad_net = GradNet(net)
|
grad_net = GradNet(net)
|
||||||
with pytest.raises(RuntimeError) as ex:
|
with pytest.raises(RuntimeError) as ex:
|
||||||
ret = grad_net(x, sens)
|
ret = grad_net(x, sens)
|
||||||
assert "in scope Default does not support Parameter data type." in str(ex.value)
|
assert "When user defines the net bprop, there are 1 parameters that is not supported in the net." in str(ex.value)
|
||||||
|
|
||||||
|
|
||||||
def test_user_define_bprop_check_number():
|
def test_user_define_bprop_check_number():
|
||||||
|
|
Loading…
Reference in New Issue