forked from mindspore-Ecosystem/mindspore
!1427 fix check bprop attr error
Merge pull request !1427 from panyifeng/fix_check_bprop_attr_error
This commit is contained in:
commit
0b191615a9
|
@ -32,6 +32,7 @@
|
|||
#include "operator/composite/composite.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "utils/primitive_utils.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "debug/info.h"
|
||||
#include "debug/trace.h"
|
||||
|
||||
|
@ -181,10 +182,19 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp
|
|||
}
|
||||
|
||||
void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool check_bprop_flag = context->check_bprop_flag();
|
||||
// Skip checking if check_bprop not set
|
||||
if (!check_bprop_flag) {
|
||||
return;
|
||||
}
|
||||
|
||||
// bprop_fg has been checked in caller
|
||||
auto check_bprop = prim::GetPythonOps("check_bprop", "mindspore.ops.functional")->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(check_bprop);
|
||||
check_bprop->set_attr("prim_to_check", std::make_shared<StringImm>(prim_to_check));
|
||||
auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops");
|
||||
MS_EXCEPTION_IF_NULL(check_bprop_class);
|
||||
auto check_bprop =
|
||||
bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared<StringImm>(prim_to_check))});
|
||||
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
|
@ -192,7 +202,7 @@ void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check
|
|||
AnfNodePtr params = bprop_fg->NewCNode(inputs);
|
||||
|
||||
inputs.clear();
|
||||
inputs.push_back(NewValueNode(check_bprop));
|
||||
inputs.push_back(check_bprop);
|
||||
inputs.push_back(bprop_fg->output());
|
||||
inputs.push_back(params);
|
||||
AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs);
|
||||
|
|
|
@ -141,7 +141,9 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.")
|
||||
.def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.")
|
||||
.def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.")
|
||||
.def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.");
|
||||
.def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.")
|
||||
.def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.")
|
||||
.def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.");
|
||||
|
||||
(void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
|
||||
.def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
|
||||
|
|
|
@ -140,6 +140,8 @@ class MsContext {
|
|||
|
||||
void set_profiling_options(const std::string &options) { profiling_options_ = options; }
|
||||
std::string profiling_options() const { return profiling_options_; }
|
||||
bool check_bprop_flag() const { return check_bprop_flag_; }
|
||||
void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; }
|
||||
|
||||
private:
|
||||
MsContext(const std::string &backend_policy, const std::string &target);
|
||||
|
@ -179,6 +181,7 @@ class MsContext {
|
|||
std::thread tdt_print_;
|
||||
bool profiling_mode_;
|
||||
std::string profiling_options_;
|
||||
bool check_bprop_flag_;
|
||||
};
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -324,6 +324,13 @@ class _Context:
|
|||
thread_info = self._thread_local_info
|
||||
thread_info.debug_runtime = enable
|
||||
|
||||
@property
|
||||
def check_bprop(self):
|
||||
return self._context_handle.get_check_bprop_flag()
|
||||
|
||||
@check_bprop.setter
|
||||
def check_bprop(self, check_bprop_flag):
|
||||
self._context_handle.set_check_bprop_flag(check_bprop_flag)
|
||||
|
||||
def check_input_format(x):
|
||||
import re
|
||||
|
@ -449,7 +456,8 @@ def reset_auto_parallel_context():
|
|||
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
|
||||
save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
|
||||
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
|
||||
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool)
|
||||
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
|
||||
check_bprop=bool)
|
||||
def set_context(**kwargs):
|
||||
"""
|
||||
Sets context for running environment.
|
||||
|
@ -500,6 +508,7 @@ def set_context(**kwargs):
|
|||
The profiling can choose training_trace, task_trace, training_trace and task_trace combination and
|
||||
separated by colons; single operator can choose op_trace, op_trace cannot be combined with
|
||||
training_trace and task_trace. Default: "training_trace".
|
||||
check_bprop (bool): Whether to check bprop. Default: False.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not an attribute in context.
|
||||
|
|
|
@ -323,8 +323,9 @@ class CheckBprop(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
def __init__(self, prim_to_check=""):
|
||||
"""init CheckBprop"""
|
||||
self.prim_to_check = prim_to_check
|
||||
|
||||
def infer_shape(self, xshapes, yshapes):
|
||||
tips = f'Bprop of {self.prim_to_check}'
|
||||
|
|
|
@ -353,6 +353,7 @@ class MulAddWithWrongOutputNum(nn.Cell):
|
|||
|
||||
|
||||
def test_grad_mul_add_with_wrong_output_num():
|
||||
context.set_context(check_bprop=True)
|
||||
mul_add = MulAddWithWrongOutputNum()
|
||||
with pytest.raises(TypeError):
|
||||
C.grad_all(mul_add)(1, 2)
|
||||
|
@ -370,6 +371,7 @@ class MulAddWithWrongOutputType(nn.Cell):
|
|||
|
||||
|
||||
def test_grad_mul_add_with_wrong_output_type():
|
||||
context.set_context(check_bprop=True)
|
||||
mul_add = MulAddWithWrongOutputType()
|
||||
with pytest.raises(TypeError):
|
||||
C.grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
|
||||
|
@ -388,6 +390,7 @@ class MulAddWithWrongOutputShape(nn.Cell):
|
|||
|
||||
|
||||
def test_grad_mul_add_with_wrong_output_shape():
|
||||
context.set_context(check_bprop=True)
|
||||
mul_add = MulAddWithWrongOutputShape()
|
||||
with pytest.raises(TypeError):
|
||||
C.grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
|
||||
|
|
|
@ -893,6 +893,7 @@ def test_grad_if_defer_inline():
|
|||
|
||||
|
||||
def test_bprop_with_wrong_output_num():
|
||||
context.set_context(check_bprop=True)
|
||||
class BpropWithWrongOutputNum(PrimitiveWithInfer):
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
|
@ -926,8 +927,8 @@ def test_bprop_with_wrong_output_num():
|
|||
with pytest.raises(TypeError):
|
||||
C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
|
||||
|
||||
|
||||
def test_bprop_with_wrong_output_type():
|
||||
context.set_context(check_bprop=True)
|
||||
class BpropWithWrongOutputType(PrimitiveWithInfer):
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
|
@ -963,6 +964,7 @@ def test_bprop_with_wrong_output_type():
|
|||
|
||||
|
||||
def test_bprop_with_wrong_output_shape():
|
||||
context.set_context(check_bprop=True)
|
||||
class BpropWithWrongOutputShape(PrimitiveWithInfer):
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue