white list for syntax exception

This commit is contained in:
zhangzhaoju 2021-11-25 17:10:56 +08:00
parent cdb618984f
commit d34f14e324
5 changed files with 57 additions and 28 deletions

View File

@ -2492,7 +2492,12 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
auto bprop_func_cellid = GetId(bprop_func);
bprop_cell_list_.emplace_back(bprop_func_cellid);
auto fake_prim = std::make_shared<PrimitivePy>(prim::kPrimHookBackward->name());
if (py::isinstance<Cell>(cell)) {
auto cell_ptr = py::cast<CellPtr>(cell);
fake_prim->set_bprop_cls_name(cell_ptr->name());
}
fake_prim->set_hook(bprop_func);
const auto &cell_id = GetCellId(cell, args);
(void)fake_prim->AddAttr("cell_id", MakeValue(cell_id));
(void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));

View File

@ -26,7 +26,8 @@ void CellPy::AddAttr(CellPtr cell, const std::string &name, const py::object &ob
ValuePtr converted_ret = nullptr;
MS_EXCEPTION_IF_NULL(cell);
if (py::isinstance<py::module>(obj)) {
MS_LOG(EXCEPTION) << "Cell set_attr failed, attr should not be py::module";
MS_LOG(EXCEPTION) << "Cell set_attr failed, attr '" << name << "' should not be py::module '" << py::str(obj)
<< "'.";
}
bool converted = parse::ConvertData(obj, &converted_ret, true);
if (!converted) {

View File

@ -89,7 +89,7 @@ py::function PrimitivePy::GetBpropFunction() {
}
}
py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) {
py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args, const std::string &bprop_cls_name) {
py::tuple grads;
if (!py::isinstance<py::tuple>(grads_obj)) {
grads = py::make_tuple(grads_obj);
@ -98,15 +98,16 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
}
constexpr int filter_args_size = 2;
if (grads.size() != py_args.size() - filter_args_size) {
MS_EXCEPTION(TypeError) << "For user define net bprop, the gradients number: " << grads.size()
MS_EXCEPTION(TypeError) << "For user defined bprop of net '" << bprop_cls_name
<< "', the gradients number: " << grads.size()
<< " is not equal to the args number: " << (py_args.size() - filter_args_size) << ".";
}
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG)) {
for (size_t i = 0; i < grads.size(); i++) {
if (py::isinstance<tensor::Tensor>(py_args[i])) {
if (!py::isinstance<tensor::Tensor>(grads[i])) {
MS_EXCEPTION(ValueError) << "When user defines the net bprop,, the gradient of the " << i
<< "th arg should be Tensor, but got "
MS_EXCEPTION(ValueError) << "For user defined bprop of net '" << bprop_cls_name << "', the gradient of the "
<< i << "th arg should be Tensor, but got "
<< py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
<< ", and the value is " << py::cast<py::str>(grads[i]) << ".";
}
@ -116,14 +117,14 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
py::tuple arg_shape = py_args[i].attr("shape");
py::tuple grad_shape = grads[i].attr("shape");
if (!grad_dtype.equal(arg_dtype)) {
MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i
<< "th arg should have the same dtype as the " << i << "th arg, but the " << i
MS_EXCEPTION(TypeError) << "For user defined bprop of net '" << bprop_cls_name << "', the gradient of the "
<< i << "th arg should have the same dtype as the " << i << "th arg, but the " << i
<< "th arg dtype is: " << py::cast<py::str>(arg_dtype)
<< ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << ".";
}
if (!grad_shape.equal(arg_shape)) {
MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i
<< "th arg should have the same shape as the " << i << "th arg, but the " << i
MS_EXCEPTION(ValueError) << "For user defined bprop of net '" << bprop_cls_name << "', the gradient of the "
<< i << "th arg should have the same shape as the " << i << "th arg, but the " << i
<< "th arg shape is: " << py::cast<py::str>(arg_shape)
<< ", the gradient shape is: " << py::cast<py::str>(grad_shape) << ".";
}
@ -168,7 +169,10 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
if (py::isinstance<tensor::Tensor>(expected_grad_out)) {
if (!py::isinstance<tensor::Tensor>(grad_out)) {
hook_grad_.clear();
MS_EXCEPTION(TypeError) << "The output gradient should be a tensor!";
py::object code_obj = py::getattr(hook_, "__code__");
py::object co_name = py::getattr(code_obj, "co_name");
MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got "
<< py::cast<std::string>(grad_out.attr("__class__").attr("__name__")) << ".";
}
auto actual_out_tensor = py::cast<tensor::TensorPtr>(grad_out);
auto expected_out_tensor = py::cast<tensor::TensorPtr>(expected_grad_out);
@ -176,8 +180,11 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
MS_EXCEPTION_IF_NULL(expected_out_tensor);
if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {
hook_grad_.clear();
MS_EXCEPTION(ValueError) << "The output gradient is not consistent with the expected, it should be "
<< expected_out_tensor->GetShapeAndDataTypeInfo() << ", but it is "
py::object code_obj = py::getattr(hook_, "__code__");
py::object co_name = py::getattr(code_obj, "co_name");
MS_EXCEPTION(ValueError) << "The output type of " << py::str(co_name)
<< " is not consistent with the expected, it should be "
<< expected_out_tensor->GetShapeAndDataTypeInfo() << ", but got "
<< actual_out_tensor->GetShapeAndDataTypeInfo();
}
}
@ -199,7 +206,7 @@ BaseRef PrimitivePy::RunCellBpropFunction(const py::tuple &py_args) const {
MS_LOG(DEBUG) << "Run bprop function start";
inst->NewGraph(hook_, input_args.cast<py::args>());
py::object grads_obj = hook_(*convert_args);
py::tuple grads = check_bprop_out(grads_obj, py_args);
py::tuple grads = check_bprop_out(grads_obj, py_args, bprop_cls_name_);
inst->EndGraph(hook_, grads_obj, input_args.cast<py::args>());
MS_LOG(DEBUG) << "Run bprop function end";
return std::make_shared<PyObjectRef>(grads);
@ -222,7 +229,8 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
py::object code_obj = py::getattr(hook_, "__code__");
py::object co_name = py::getattr(code_obj, "co_name");
if (std::string(py::str(co_name)) == "staging_specialize") {
MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
py::object name_obj = py::getattr(hook_, "__name__");
MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@ms_function' is not supported.";
}
py::tuple convert_args(input_param_nums - 1);
@ -252,7 +260,8 @@ BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const {
py::object code_obj = py::getattr(hook_, "__code__");
py::object co_name = py::getattr(code_obj, "co_name");
if (std::string(py::str(co_name)) == "staging_specialize") {
MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
py::object name_obj = py::getattr(hook_, "__name__");
MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@ms_function' is not supported.";
}
constexpr size_t grad_output_index = 2;
@ -311,12 +320,13 @@ py::dict PrimitivePy::GetAttrDict() {
void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(primitive);
if (!primitive->isa<PrimitivePy>()) {
MS_LOG(EXCEPTION) << "Cannot copy a primtive which is not python primitive hook function to python primitive!";
MS_LOG(EXCEPTION) << "Cannot copy a primitive which is not python primitive hook function to python primitive!";
}
auto primitive_py = primitive->cast<PrimitivePyPtr>();
MS_EXCEPTION_IF_NULL(primitive_py);
this->set_hook(primitive_py->hook());
if (primitive_py->HasAttr(kBpropAttrName)) {
set_bprop_cls_name(primitive_py->bprop_cls_name_);
(void)this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName));
}
}
@ -395,35 +405,46 @@ void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) {
std::string attr_name = name;
ValuePtr converted_ret = nullptr;
if (py::isinstance<py::module>(obj)) {
MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
MS_LOG(EXCEPTION) << "AddPyAttr for prim '" << this->name_
<< "' failed, not support py::module be attribute, attr name:" << attr_name
<< " attr value:" << py::str(obj);
}
bool converted = parse::ConvertData(obj, &converted_ret);
if (!converted) {
MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
MS_LOG(EXCEPTION) << "Attribute convert error for prim '" << this->name_ << "' attr name:" << attr_name
<< " attr value:" << py::str(obj)
<< " attr type:" << py::cast<std::string>(obj.attr("__class__").attr("__name__"));
}
if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
attr_name = kOpAttrNameReplaceMap[attr_name];
}
(void)CheckAndConvertUtils::ConvertAttrValueToInt(name_, name, &converted_ret);
attrs_[attr_name] = converted_ret;
auto prim = attached_primitive_.lock();
if (prim != nullptr) {
(void)prim->AddAttr(attr_name, converted_ret);
}
(void)CheckAndConvertUtils::ConvertAttrValueToInt(this->name_, name, &converted_ret);
if (attr_name == "primitive_target") {
MS_EXCEPTION_IF_NULL(converted_ret);
if (!converted_ret->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
MS_LOG(EXCEPTION) << "AddPyAttr for prim '" << this->name_
<< "' failed, for attribute primitive_target only support string CPU|GPU|Ascend but got "
<< py::str(obj);
}
auto target = GetValue<std::string>(converted_ret);
if (!target.empty() && target != kCPUDevice && target != kGPUDevice && target != kAscendDevice &&
target != "Device") {
MS_LOG(EXCEPTION) << "AddPyAttr for prim '" << this->name_
<< "' failed, for attribute only support string CPU|GPU|Ascend|Device but got "
<< py::str(obj);
}
if (target != kCPUDevice && target != kGPUDevice) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
context_ptr->set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, true);
}
}
attrs_[attr_name] = converted_ret;
auto prim = attached_primitive_.lock();
if (prim != nullptr) {
(void)prim->AddAttr(attr_name, converted_ret);
}
}
void PrimitivePyAdapter::DelPyAttr(const py::str &name) {

View File

@ -74,6 +74,7 @@ class PrimitivePy : public Primitive {
bool HasPyObj() { return python_obj_.operator bool(); }
PrimitivePtr Clone() override;
PrimitivePyAdapterPtr adapter() const { return adapter_; }
void set_bprop_cls_name(const std::string &name) { bprop_cls_name_ = name; }
private:
py::function GetComputeFunction() const;
@ -82,6 +83,7 @@ class PrimitivePy : public Primitive {
py::object python_obj_;
PrimitivePyAdapterPtr adapter_;
py::function hook_;
std::string bprop_cls_name_;
std::vector<Signature> signatures_;
static std::map<std::string, py::object> hook_grad_;
};

View File

@ -209,4 +209,4 @@ def test_user_define_bprop_check_number():
grad_net = GradNet(net)
with pytest.raises(TypeError) as ex:
ret = grad_net(x, y, sens)
assert "For user define net bprop, the gradients number: 1 is not equal to the args number: 2." in str(ex.value)
assert "For user defined bprop of net" in str(ex.value)