diff --git a/mindspore/ccsrc/pipeline/parse/parse.cc b/mindspore/ccsrc/pipeline/parse/parse.cc index af1d67a6fdb..82d254d2fc9 100644 --- a/mindspore/ccsrc/pipeline/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/parse/parse.cc @@ -1136,10 +1136,31 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob AnfNodePtr target_node = ParseExprNode(block, targ); MS_EXCEPTION_IF_NULL(target_node); + std::string attr_name = targ.attr("attr").cast(); std::string var_name = "self."; - (void)var_name.append(targ.attr("attr").cast()); + (void)var_name.append(attr_name); MS_LOG(DEBUG) << "assign " << var_name; + // Get targ location info for error printing + py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, targ); + if (location.size() < 2) { + MS_LOG(EXCEPTION) << "List size should not be less than 2."; + } + auto filename = location[0].cast(); + auto line_no = location[1].cast(); + // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type + if (!py::hasattr(ast()->obj(), attr_name.c_str())) { + MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined, at " << filename << ":" + << line_no; + } + auto obj = ast()->obj().attr(attr_name.c_str()); + auto obj_type = obj.attr("__class__").attr("__name__"); + if (!py::hasattr(obj, "__parameter__")) { + MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '" + << py::str(obj).cast() << "' with type '" + << py::str(obj_type).cast() << "' at " << filename << ":" << line_no; + } + MS_EXCEPTION_IF_NULL(block); block->WriteVariable(var_name, assigned_node); MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString(); diff --git a/tests/ut/python/pynative_mode/test_insert_grad_of.py b/tests/ut/python/pynative_mode/test_insert_grad_of.py index 0527365a983..d9368f315b8 100644 --- a/tests/ut/python/pynative_mode/test_insert_grad_of.py +++ b/tests/ut/python/pynative_mode/test_insert_grad_of.py @@ -124,9 +124,9 @@ def test_cell_assign(): class Mul(nn.Cell): def __init__(self): super(Mul, self).__init__() - self.get_g = P.InsertGradientOf(self.save_gradient) self.matrix_w = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_w") self.matrix_g = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_g") + self.get_g = P.InsertGradientOf(self.save_gradient) def save_gradient(self, dout): self.matrix_g = dout + self.matrix_g