forked from mindspore-Ecosystem/mindspore
white list for syntax exception
This commit is contained in:
parent
cdb618984f
commit
d34f14e324
|
@ -2492,7 +2492,12 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
|
|||
auto bprop_func_cellid = GetId(bprop_func);
|
||||
bprop_cell_list_.emplace_back(bprop_func_cellid);
|
||||
auto fake_prim = std::make_shared<PrimitivePy>(prim::kPrimHookBackward->name());
|
||||
if (py::isinstance<Cell>(cell)) {
|
||||
auto cell_ptr = py::cast<CellPtr>(cell);
|
||||
fake_prim->set_bprop_cls_name(cell_ptr->name());
|
||||
}
|
||||
fake_prim->set_hook(bprop_func);
|
||||
|
||||
const auto &cell_id = GetCellId(cell, args);
|
||||
(void)fake_prim->AddAttr("cell_id", MakeValue(cell_id));
|
||||
(void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));
|
||||
|
|
|
@ -26,7 +26,8 @@ void CellPy::AddAttr(CellPtr cell, const std::string &name, const py::object &ob
|
|||
ValuePtr converted_ret = nullptr;
|
||||
MS_EXCEPTION_IF_NULL(cell);
|
||||
if (py::isinstance<py::module>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "Cell set_attr failed, attr should not be py::module";
|
||||
MS_LOG(EXCEPTION) << "Cell set_attr failed, attr '" << name << "' should not be py::module '" << py::str(obj)
|
||||
<< "'.";
|
||||
}
|
||||
bool converted = parse::ConvertData(obj, &converted_ret, true);
|
||||
if (!converted) {
|
||||
|
|
|
@ -89,7 +89,7 @@ py::function PrimitivePy::GetBpropFunction() {
|
|||
}
|
||||
}
|
||||
|
||||
py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) {
|
||||
py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args, const std::string &bprop_cls_name) {
|
||||
py::tuple grads;
|
||||
if (!py::isinstance<py::tuple>(grads_obj)) {
|
||||
grads = py::make_tuple(grads_obj);
|
||||
|
@ -98,15 +98,16 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
|
|||
}
|
||||
constexpr int filter_args_size = 2;
|
||||
if (grads.size() != py_args.size() - filter_args_size) {
|
||||
MS_EXCEPTION(TypeError) << "For user define net bprop, the gradients number: " << grads.size()
|
||||
MS_EXCEPTION(TypeError) << "For user defined bprop of net '" << bprop_cls_name
|
||||
<< "', the gradients number: " << grads.size()
|
||||
<< " is not equal to the args number: " << (py_args.size() - filter_args_size) << ".";
|
||||
}
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG)) {
|
||||
for (size_t i = 0; i < grads.size(); i++) {
|
||||
if (py::isinstance<tensor::Tensor>(py_args[i])) {
|
||||
if (!py::isinstance<tensor::Tensor>(grads[i])) {
|
||||
MS_EXCEPTION(ValueError) << "When user defines the net bprop,, the gradient of the " << i
|
||||
<< "th arg should be Tensor, but got "
|
||||
MS_EXCEPTION(ValueError) << "For user defined bprop of net '" << bprop_cls_name << "', the gradient of the "
|
||||
<< i << "th arg should be Tensor, but got "
|
||||
<< py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
|
||||
<< ", and the value is " << py::cast<py::str>(grads[i]) << ".";
|
||||
}
|
||||
|
@ -116,14 +117,14 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
|
|||
py::tuple arg_shape = py_args[i].attr("shape");
|
||||
py::tuple grad_shape = grads[i].attr("shape");
|
||||
if (!grad_dtype.equal(arg_dtype)) {
|
||||
MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i
|
||||
<< "th arg should have the same dtype as the " << i << "th arg, but the " << i
|
||||
MS_EXCEPTION(TypeError) << "For user defined bprop of net '" << bprop_cls_name << "', the gradient of the "
|
||||
<< i << "th arg should have the same dtype as the " << i << "th arg, but the " << i
|
||||
<< "th arg dtype is: " << py::cast<py::str>(arg_dtype)
|
||||
<< ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << ".";
|
||||
}
|
||||
if (!grad_shape.equal(arg_shape)) {
|
||||
MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i
|
||||
<< "th arg should have the same shape as the " << i << "th arg, but the " << i
|
||||
MS_EXCEPTION(ValueError) << "For user defined bprop of net '" << bprop_cls_name << "', the gradient of the "
|
||||
<< i << "th arg should have the same shape as the " << i << "th arg, but the " << i
|
||||
<< "th arg shape is: " << py::cast<py::str>(arg_shape)
|
||||
<< ", the gradient shape is: " << py::cast<py::str>(grad_shape) << ".";
|
||||
}
|
||||
|
@ -168,7 +169,10 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
|
|||
if (py::isinstance<tensor::Tensor>(expected_grad_out)) {
|
||||
if (!py::isinstance<tensor::Tensor>(grad_out)) {
|
||||
hook_grad_.clear();
|
||||
MS_EXCEPTION(TypeError) << "The output gradient should be a tensor!";
|
||||
py::object code_obj = py::getattr(hook_, "__code__");
|
||||
py::object co_name = py::getattr(code_obj, "co_name");
|
||||
MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got "
|
||||
<< py::cast<std::string>(grad_out.attr("__class__").attr("__name__")) << ".";
|
||||
}
|
||||
auto actual_out_tensor = py::cast<tensor::TensorPtr>(grad_out);
|
||||
auto expected_out_tensor = py::cast<tensor::TensorPtr>(expected_grad_out);
|
||||
|
@ -176,8 +180,11 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
|
|||
MS_EXCEPTION_IF_NULL(expected_out_tensor);
|
||||
if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {
|
||||
hook_grad_.clear();
|
||||
MS_EXCEPTION(ValueError) << "The output gradient is not consistent with the expected, it should be "
|
||||
<< expected_out_tensor->GetShapeAndDataTypeInfo() << ", but it is "
|
||||
py::object code_obj = py::getattr(hook_, "__code__");
|
||||
py::object co_name = py::getattr(code_obj, "co_name");
|
||||
MS_EXCEPTION(ValueError) << "The output type of " << py::str(co_name)
|
||||
<< " is not consistent with the expected, it should be "
|
||||
<< expected_out_tensor->GetShapeAndDataTypeInfo() << ", but got "
|
||||
<< actual_out_tensor->GetShapeAndDataTypeInfo();
|
||||
}
|
||||
}
|
||||
|
@ -199,7 +206,7 @@ BaseRef PrimitivePy::RunCellBpropFunction(const py::tuple &py_args) const {
|
|||
MS_LOG(DEBUG) << "Run bprop function start";
|
||||
inst->NewGraph(hook_, input_args.cast<py::args>());
|
||||
py::object grads_obj = hook_(*convert_args);
|
||||
py::tuple grads = check_bprop_out(grads_obj, py_args);
|
||||
py::tuple grads = check_bprop_out(grads_obj, py_args, bprop_cls_name_);
|
||||
inst->EndGraph(hook_, grads_obj, input_args.cast<py::args>());
|
||||
MS_LOG(DEBUG) << "Run bprop function end";
|
||||
return std::make_shared<PyObjectRef>(grads);
|
||||
|
@ -222,7 +229,8 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
|
|||
py::object code_obj = py::getattr(hook_, "__code__");
|
||||
py::object co_name = py::getattr(code_obj, "co_name");
|
||||
if (std::string(py::str(co_name)) == "staging_specialize") {
|
||||
MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
|
||||
py::object name_obj = py::getattr(hook_, "__name__");
|
||||
MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@ms_function' is not supported.";
|
||||
}
|
||||
|
||||
py::tuple convert_args(input_param_nums - 1);
|
||||
|
@ -252,7 +260,8 @@ BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const {
|
|||
py::object code_obj = py::getattr(hook_, "__code__");
|
||||
py::object co_name = py::getattr(code_obj, "co_name");
|
||||
if (std::string(py::str(co_name)) == "staging_specialize") {
|
||||
MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
|
||||
py::object name_obj = py::getattr(hook_, "__name__");
|
||||
MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@ms_function' is not supported.";
|
||||
}
|
||||
|
||||
constexpr size_t grad_output_index = 2;
|
||||
|
@ -311,12 +320,13 @@ py::dict PrimitivePy::GetAttrDict() {
|
|||
void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
if (!primitive->isa<PrimitivePy>()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot copy a primtive which is not python primitive hook function to python primitive!";
|
||||
MS_LOG(EXCEPTION) << "Cannot copy a primitive which is not python primitive hook function to python primitive!";
|
||||
}
|
||||
auto primitive_py = primitive->cast<PrimitivePyPtr>();
|
||||
MS_EXCEPTION_IF_NULL(primitive_py);
|
||||
this->set_hook(primitive_py->hook());
|
||||
if (primitive_py->HasAttr(kBpropAttrName)) {
|
||||
set_bprop_cls_name(primitive_py->bprop_cls_name_);
|
||||
(void)this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName));
|
||||
}
|
||||
}
|
||||
|
@ -395,35 +405,46 @@ void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) {
|
|||
std::string attr_name = name;
|
||||
ValuePtr converted_ret = nullptr;
|
||||
if (py::isinstance<py::module>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
|
||||
MS_LOG(EXCEPTION) << "AddPyAttr for prim '" << this->name_
|
||||
<< "' failed, not support py::module be attribute, attr name:" << attr_name
|
||||
<< " attr value:" << py::str(obj);
|
||||
}
|
||||
bool converted = parse::ConvertData(obj, &converted_ret);
|
||||
if (!converted) {
|
||||
MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
|
||||
MS_LOG(EXCEPTION) << "Attribute convert error for prim '" << this->name_ << "' attr name:" << attr_name
|
||||
<< " attr value:" << py::str(obj)
|
||||
<< " attr type:" << py::cast<std::string>(obj.attr("__class__").attr("__name__"));
|
||||
}
|
||||
if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
|
||||
attr_name = kOpAttrNameReplaceMap[attr_name];
|
||||
}
|
||||
(void)CheckAndConvertUtils::ConvertAttrValueToInt(name_, name, &converted_ret);
|
||||
attrs_[attr_name] = converted_ret;
|
||||
auto prim = attached_primitive_.lock();
|
||||
if (prim != nullptr) {
|
||||
(void)prim->AddAttr(attr_name, converted_ret);
|
||||
}
|
||||
|
||||
(void)CheckAndConvertUtils::ConvertAttrValueToInt(this->name_, name, &converted_ret);
|
||||
if (attr_name == "primitive_target") {
|
||||
MS_EXCEPTION_IF_NULL(converted_ret);
|
||||
if (!converted_ret->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
|
||||
MS_LOG(EXCEPTION) << "AddPyAttr for prim '" << this->name_
|
||||
<< "' failed, for attribute primitive_target only support string CPU|GPU|Ascend but got "
|
||||
<< py::str(obj);
|
||||
}
|
||||
|
||||
auto target = GetValue<std::string>(converted_ret);
|
||||
if (!target.empty() && target != kCPUDevice && target != kGPUDevice && target != kAscendDevice &&
|
||||
target != "Device") {
|
||||
MS_LOG(EXCEPTION) << "AddPyAttr for prim '" << this->name_
|
||||
<< "' failed, for attribute only support string CPU|GPU|Ascend|Device but got "
|
||||
<< py::str(obj);
|
||||
}
|
||||
if (target != kCPUDevice && target != kGPUDevice) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
context_ptr->set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, true);
|
||||
}
|
||||
}
|
||||
|
||||
attrs_[attr_name] = converted_ret;
|
||||
auto prim = attached_primitive_.lock();
|
||||
if (prim != nullptr) {
|
||||
(void)prim->AddAttr(attr_name, converted_ret);
|
||||
}
|
||||
}
|
||||
|
||||
void PrimitivePyAdapter::DelPyAttr(const py::str &name) {
|
||||
|
|
|
@ -74,6 +74,7 @@ class PrimitivePy : public Primitive {
|
|||
bool HasPyObj() { return python_obj_.operator bool(); }
|
||||
PrimitivePtr Clone() override;
|
||||
PrimitivePyAdapterPtr adapter() const { return adapter_; }
|
||||
void set_bprop_cls_name(const std::string &name) { bprop_cls_name_ = name; }
|
||||
|
||||
private:
|
||||
py::function GetComputeFunction() const;
|
||||
|
@ -82,6 +83,7 @@ class PrimitivePy : public Primitive {
|
|||
py::object python_obj_;
|
||||
PrimitivePyAdapterPtr adapter_;
|
||||
py::function hook_;
|
||||
std::string bprop_cls_name_;
|
||||
std::vector<Signature> signatures_;
|
||||
static std::map<std::string, py::object> hook_grad_;
|
||||
};
|
||||
|
|
|
@ -209,4 +209,4 @@ def test_user_define_bprop_check_number():
|
|||
grad_net = GradNet(net)
|
||||
with pytest.raises(TypeError) as ex:
|
||||
ret = grad_net(x, y, sens)
|
||||
assert "For user define net bprop, the gradients number: 1 is not equal to the args number: 2." in str(ex.value)
|
||||
assert "For user defined bprop of net" in str(ex.value)
|
||||
|
|
Loading…
Reference in New Issue