!5641 check user define bprop when there is parameter in nested network

Merge pull request !5641 from zhangbuxue/check_user_define_bprop_when_there_is_parameter_in_nested_network
This commit is contained in:
mindspore-ci-bot 2020-09-02 09:04:05 +08:00 committed by Gitee
commit 120ab80b23
5 changed files with 9 additions and 9 deletions

View File

@ -1182,6 +1182,11 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
resource_->manager()->AddFuncGraph(curr_g_);
// custom bprop debug
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
if (par_number > 0) {
MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number
<< " parameters that is not supported in the net.";
}
MS_LOG(DEBUG) << "Use cell custom bprop function.";
FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
if (bprop_graph != nullptr) {

View File

@ -93,7 +93,7 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
for (size_t i = 0; i < grads.size(); i++) {
if (py::isinstance<tensor::Tensor>(py_args[i])) {
if (!py::isinstance<tensor::Tensor>(grads[i])) {
MS_EXCEPTION(ValueError) << "For user define net bprop, the gradient of the " << i
MS_EXCEPTION(ValueError) << "When user defines the net bprop,, the gradient of the " << i
<< "th arg should be Tensor, but got "
<< py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
<< ", and the value is " << py::cast<py::str>(grads[i]) << ".";

View File

@ -472,7 +472,7 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
void FinalVM::SyncData(const py::object &arg) {
if (py::isinstance<py::tuple>(arg)) {
py::tuple arg_list = py::cast<py::tuple>(arg);
auto arg_list = py::cast<py::tuple>(arg);
for (size_t i = 0; i < arg_list.size(); i++) {
SyncData(arg_list[i]);
}

View File

@ -2464,12 +2464,7 @@ raise_set = [
('StridedSlice_1', {
'block': (P.StridedSlice(), {'exception': ValueError}),
'desc_const': [(1, 2, 3), (3, 4, 5), (1, 1)],
'desc_inputs': [[4, 5, 6, 7]]}),
('StridedSlice_2', {
'block': (P.StridedSlice(), {'exception': ValueError}),
'desc_const': [(1, 2, 3), (3, 4, 5), (1, 1, 0)],
'desc_inputs': [[4, 5, 6, 7]]}),
'desc_inputs': [[4, 5, 6, 7]]})
]

View File

@ -177,7 +177,7 @@ def test_user_define_bprop_check_parameter():
grad_net = GradNet(net)
with pytest.raises(RuntimeError) as ex:
ret = grad_net(x, sens)
assert "in scope Default does not support Parameter data type." in str(ex.value)
assert "When user defines the net bprop, there are 1 parameters that is not supported in the net." in str(ex.value)
def test_user_define_bprop_check_number():