From d34f14e3245fe3b3a6bef9f31a85ad8891a7d495 Mon Sep 17 00:00:00 2001 From: zhangzhaoju Date: Thu, 25 Nov 2021 17:10:56 +0800 Subject: [PATCH] white list for syntax exception --- .../pipeline/pynative/pynative_execute.cc | 5 ++ mindspore/ccsrc/pybind_api/ir/cell_py.cc | 3 +- mindspore/ccsrc/pybind_api/ir/primitive_py.cc | 73 ++++++++++++------- mindspore/ccsrc/pybind_api/ir/primitive_py.h | 2 + .../test_user_define_bprop_check.py | 2 +- 5 files changed, 57 insertions(+), 28 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 76a32c32bdc..f56ac42e5fa 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -2492,7 +2492,12 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object auto bprop_func_cellid = GetId(bprop_func); bprop_cell_list_.emplace_back(bprop_func_cellid); auto fake_prim = std::make_shared(prim::kPrimHookBackward->name()); + if (py::isinstance(cell)) { + auto cell_ptr = py::cast(cell); + fake_prim->set_bprop_cls_name(cell_ptr->name()); + } fake_prim->set_hook(bprop_func); + const auto &cell_id = GetCellId(cell, args); (void)fake_prim->AddAttr("cell_id", MakeValue(cell_id)); (void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true)); diff --git a/mindspore/ccsrc/pybind_api/ir/cell_py.cc b/mindspore/ccsrc/pybind_api/ir/cell_py.cc index b34471a7144..149454376ab 100644 --- a/mindspore/ccsrc/pybind_api/ir/cell_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/cell_py.cc @@ -26,7 +26,8 @@ void CellPy::AddAttr(CellPtr cell, const std::string &name, const py::object &ob ValuePtr converted_ret = nullptr; MS_EXCEPTION_IF_NULL(cell); if (py::isinstance(obj)) { - MS_LOG(EXCEPTION) << "Cell set_attr failed, attr should not be py::module"; + MS_LOG(EXCEPTION) << "Cell set_attr failed, attr '" << name << "' should not be py::module '" << py::str(obj) + << "'."; } bool converted = parse::ConvertData(obj, &converted_ret, true); if (!converted) { diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index f949045fa98..a3ad657cfb2 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -89,7 +89,7 @@ py::function PrimitivePy::GetBpropFunction() { } } -py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) { +py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args, const std::string &bprop_cls_name) { py::tuple grads; if (!py::isinstance(grads_obj)) { grads = py::make_tuple(grads_obj); @@ -98,15 +98,16 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) } constexpr int filter_args_size = 2; if (grads.size() != py_args.size() - filter_args_size) { - MS_EXCEPTION(TypeError) << "For user define net bprop, the gradients number: " << grads.size() + MS_EXCEPTION(TypeError) << "For user defined bprop of net '" << bprop_cls_name + << "', the gradients number: " << grads.size() << " is not equal to the args number: " << (py_args.size() - filter_args_size) << "."; } if (MsContext::GetInstance()->get_param(MS_CTX_CHECK_BPROP_FLAG)) { for (size_t i = 0; i < grads.size(); i++) { if (py::isinstance(py_args[i])) { if (!py::isinstance(grads[i])) { - MS_EXCEPTION(ValueError) << "When user defines the net bprop,, the gradient of the " << i - << "th arg should be Tensor, but got " + MS_EXCEPTION(ValueError) << "For user defined bprop of net '" << bprop_cls_name << "', the gradient of the " + << i << "th arg should be Tensor, but got " << py::cast(grads[i].attr("__class__").attr("__name__")) << ", and the value is " << py::cast(grads[i]) << "."; } @@ -116,14 +117,14 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) py::tuple arg_shape = py_args[i].attr("shape"); py::tuple grad_shape = grads[i].attr("shape"); if (!grad_dtype.equal(arg_dtype)) { - MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i - << "th arg should have the same dtype as the " << i << "th arg, but the " << i + MS_EXCEPTION(TypeError) << "For user defined bprop of net '" << bprop_cls_name << "', the gradient of the " + << i << "th arg should have the same dtype as the " << i << "th arg, but the " << i << "th arg dtype is: " << py::cast(arg_dtype) << ", the gradient dtype is: " << py::cast(grad_dtype) << "."; } if (!grad_shape.equal(arg_shape)) { - MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i - << "th arg should have the same shape as the " << i << "th arg, but the " << i + MS_EXCEPTION(ValueError) << "For user defined bprop of net '" << bprop_cls_name << "', the gradient of the " + << i << "th arg should have the same shape as the " << i << "th arg, but the " << i << "th arg shape is: " << py::cast(arg_shape) << ", the gradient shape is: " << py::cast(grad_shape) << "."; } @@ -168,7 +169,10 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj if (py::isinstance(expected_grad_out)) { if (!py::isinstance(grad_out)) { hook_grad_.clear(); - MS_EXCEPTION(TypeError) << "The output gradient should be a tensor!"; + py::object code_obj = py::getattr(hook_, "__code__"); + py::object co_name = py::getattr(code_obj, "co_name"); + MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got " + << py::cast(grad_out.attr("__class__").attr("__name__")) << "."; } auto actual_out_tensor = py::cast(grad_out); auto expected_out_tensor = py::cast(expected_grad_out); @@ -176,8 +180,11 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj MS_EXCEPTION_IF_NULL(expected_out_tensor); if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) { hook_grad_.clear(); - MS_EXCEPTION(ValueError) << "The output gradient is not consistent with the expected, it should be " - << expected_out_tensor->GetShapeAndDataTypeInfo() << ", but it is " + py::object code_obj = py::getattr(hook_, "__code__"); + py::object co_name = py::getattr(code_obj, "co_name"); + MS_EXCEPTION(ValueError) << "The output type of " << py::str(co_name) + << " is not consistent with the expected, it should be " + << expected_out_tensor->GetShapeAndDataTypeInfo() << ", but got " << actual_out_tensor->GetShapeAndDataTypeInfo(); } } @@ -199,7 +206,7 @@ BaseRef PrimitivePy::RunCellBpropFunction(const py::tuple &py_args) const { MS_LOG(DEBUG) << "Run bprop function start"; inst->NewGraph(hook_, input_args.cast()); py::object grads_obj = hook_(*convert_args); - py::tuple grads = check_bprop_out(grads_obj, py_args); + py::tuple grads = check_bprop_out(grads_obj, py_args, bprop_cls_name_); inst->EndGraph(hook_, grads_obj, input_args.cast()); MS_LOG(DEBUG) << "Run bprop function end"; return std::make_shared(grads); @@ -222,7 +229,8 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const { py::object code_obj = py::getattr(hook_, "__code__"); py::object co_name = py::getattr(code_obj, "co_name"); if (std::string(py::str(co_name)) == "staging_specialize") { - MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported."; + py::object name_obj = py::getattr(hook_, "__name__"); + MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@ms_function' is not supported."; } py::tuple convert_args(input_param_nums - 1); @@ -252,7 +260,8 @@ BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const { py::object code_obj = py::getattr(hook_, "__code__"); py::object co_name = py::getattr(code_obj, "co_name"); if (std::string(py::str(co_name)) == "staging_specialize") { - MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported."; + py::object name_obj = py::getattr(hook_, "__name__"); + MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@ms_function' is not supported."; } constexpr size_t grad_output_index = 2; @@ -311,12 +320,13 @@ py::dict PrimitivePy::GetAttrDict() { void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) { MS_EXCEPTION_IF_NULL(primitive); if (!primitive->isa()) { - MS_LOG(EXCEPTION) << "Cannot copy a primtive which is not python primitive hook function to python primitive!"; + MS_LOG(EXCEPTION) << "Cannot copy a primitive which is not python primitive hook function to python primitive!"; } auto primitive_py = primitive->cast(); MS_EXCEPTION_IF_NULL(primitive_py); this->set_hook(primitive_py->hook()); if (primitive_py->HasAttr(kBpropAttrName)) { + set_bprop_cls_name(primitive_py->bprop_cls_name_); (void)this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName)); } } @@ -395,35 +405,46 @@ void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) { std::string attr_name = name; ValuePtr converted_ret = nullptr; if (py::isinstance(obj)) { - MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module"; + MS_LOG(EXCEPTION) << "AddPyAttr for prim '" << this->name_ + << "' failed, not support py::module be attribute, attr name:" << attr_name + << " attr value:" << py::str(obj); } bool converted = parse::ConvertData(obj, &converted_ret); if (!converted) { - MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj)); + MS_LOG(EXCEPTION) << "Attribute convert error for prim '" << this->name_ << "' attr name:" << attr_name + << " attr value:" << py::str(obj) + << " attr type:" << py::cast(obj.attr("__class__").attr("__name__")); } if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) { attr_name = kOpAttrNameReplaceMap[attr_name]; } - (void)CheckAndConvertUtils::ConvertAttrValueToInt(name_, name, &converted_ret); - attrs_[attr_name] = converted_ret; - auto prim = attached_primitive_.lock(); - if (prim != nullptr) { - (void)prim->AddAttr(attr_name, converted_ret); - } - + (void)CheckAndConvertUtils::ConvertAttrValueToInt(this->name_, name, &converted_ret); if (attr_name == "primitive_target") { MS_EXCEPTION_IF_NULL(converted_ret); if (!converted_ret->isa()) { - MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; + MS_LOG(EXCEPTION) << "AddPyAttr for prim '" << this->name_ + << "' failed, for attribute primitive_target only support string CPU|GPU|Ascend but got " + << py::str(obj); } - auto target = GetValue(converted_ret); + if (!target.empty() && target != kCPUDevice && target != kGPUDevice && target != kAscendDevice && + target != "Device") { + MS_LOG(EXCEPTION) << "AddPyAttr for prim '" << this->name_ + << "' failed, for attribute only support string CPU|GPU|Ascend|Device but got " + << py::str(obj); + } if (target != kCPUDevice && target != kGPUDevice) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); context_ptr->set_param(MS_CTX_ALREADY_SET_ENABLE_MINDRT, true); } } + + attrs_[attr_name] = converted_ret; + auto prim = attached_primitive_.lock(); + if (prim != nullptr) { + (void)prim->AddAttr(attr_name, converted_ret); + } } void PrimitivePyAdapter::DelPyAttr(const py::str &name) { diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.h b/mindspore/ccsrc/pybind_api/ir/primitive_py.h index 4be06d5ad29..ab0d2f38239 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.h +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.h @@ -74,6 +74,7 @@ class PrimitivePy : public Primitive { bool HasPyObj() { return python_obj_.operator bool(); } PrimitivePtr Clone() override; PrimitivePyAdapterPtr adapter() const { return adapter_; } + void set_bprop_cls_name(const std::string &name) { bprop_cls_name_ = name; } private: py::function GetComputeFunction() const; @@ -82,6 +83,7 @@ class PrimitivePy : public Primitive { py::object python_obj_; PrimitivePyAdapterPtr adapter_; py::function hook_; + std::string bprop_cls_name_; std::vector signatures_; static std::map hook_grad_; }; diff --git a/tests/ut/python/pynative_mode/test_user_define_bprop_check.py b/tests/ut/python/pynative_mode/test_user_define_bprop_check.py index 9e1581ca120..006e1af8d8b 100644 --- a/tests/ut/python/pynative_mode/test_user_define_bprop_check.py +++ b/tests/ut/python/pynative_mode/test_user_define_bprop_check.py @@ -209,4 +209,4 @@ def test_user_define_bprop_check_number(): grad_net = GradNet(net) with pytest.raises(TypeError) as ex: ret = grad_net(x, y, sens) - assert "For user define net bprop, the gradients number: 1 is not equal to the args number: 2." in str(ex.value) + assert "For user defined bprop of net" in str(ex.value)