forked from mindspore-Ecosystem/mindspore
!5641 check user define bprop when there is parameter in nested network
Merge pull request !5641 from zhangbuxue/check_user_define_bprop_when_there_is_parameter_in_nested_network
This commit is contained in:
commit
120ab80b23
|
@ -1182,6 +1182,11 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
|
|||
resource_->manager()->AddFuncGraph(curr_g_);
|
||||
// custom bprop debug
|
||||
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
||||
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
|
||||
if (par_number > 0) {
|
||||
MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number
|
||||
<< " parameters that is not supported in the net.";
|
||||
}
|
||||
MS_LOG(DEBUG) << "Use cell custom bprop function.";
|
||||
FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
|
||||
if (bprop_graph != nullptr) {
|
||||
|
|
|
@ -93,7 +93,7 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
|
|||
for (size_t i = 0; i < grads.size(); i++) {
|
||||
if (py::isinstance<tensor::Tensor>(py_args[i])) {
|
||||
if (!py::isinstance<tensor::Tensor>(grads[i])) {
|
||||
MS_EXCEPTION(ValueError) << "For user define net bprop, the gradient of the " << i
|
||||
MS_EXCEPTION(ValueError) << "When user defines the net bprop,, the gradient of the " << i
|
||||
<< "th arg should be Tensor, but got "
|
||||
<< py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
|
||||
<< ", and the value is " << py::cast<py::str>(grads[i]) << ".";
|
||||
|
|
|
@ -472,7 +472,7 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
|
|||
|
||||
void FinalVM::SyncData(const py::object &arg) {
|
||||
if (py::isinstance<py::tuple>(arg)) {
|
||||
py::tuple arg_list = py::cast<py::tuple>(arg);
|
||||
auto arg_list = py::cast<py::tuple>(arg);
|
||||
for (size_t i = 0; i < arg_list.size(); i++) {
|
||||
SyncData(arg_list[i]);
|
||||
}
|
||||
|
|
|
@ -2464,12 +2464,7 @@ raise_set = [
|
|||
('StridedSlice_1', {
|
||||
'block': (P.StridedSlice(), {'exception': ValueError}),
|
||||
'desc_const': [(1, 2, 3), (3, 4, 5), (1, 1)],
|
||||
'desc_inputs': [[4, 5, 6, 7]]}),
|
||||
('StridedSlice_2', {
|
||||
'block': (P.StridedSlice(), {'exception': ValueError}),
|
||||
'desc_const': [(1, 2, 3), (3, 4, 5), (1, 1, 0)],
|
||||
'desc_inputs': [[4, 5, 6, 7]]}),
|
||||
|
||||
'desc_inputs': [[4, 5, 6, 7]]})
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -177,7 +177,7 @@ def test_user_define_bprop_check_parameter():
|
|||
grad_net = GradNet(net)
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
ret = grad_net(x, sens)
|
||||
assert "in scope Default does not support Parameter data type." in str(ex.value)
|
||||
assert "When user defines the net bprop, there are 1 parameters that is not supported in the net." in str(ex.value)
|
||||
|
||||
|
||||
def test_user_define_bprop_check_number():
|
||||
|
|
Loading…
Reference in New Issue