!8971 Fix bug of non py tensor in cell hook and misjudge dynamic cell in pynative

From: @joylvliang
Reviewed-by: @kisnwang,@kisnwang
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-26 13:46:45 +08:00 committed by Gitee
commit e993863c30
3 changed files with 26 additions and 10 deletions

View File

@ -69,8 +69,9 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"};
const std::set<std::string> ignore_judge_dynamic_cell = {"Cell mindspore.nn.layer.basic.Dense",
"Cell mindspore.nn.probability.distribution.normal.Normal"};
const std::set<std::string> ignore_judge_dynamic_cell = {
"Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal",
"Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask"};
const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE,
parse::NAMED_PRIMITIVE_NAMECONSTANT,
parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR};

View File

@ -108,6 +108,20 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
return grads;
}
void PrimitivePy::ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const {
MS_EXCEPTION_IF_NULL(convert_args);
if (input_args.size() != (*convert_args).size()) {
MS_LOG(EXCEPTION) << "The size of input_args: " << input_args.size()
<< " should be equal to the size of convert_args: " << (*convert_args).size();
}
for (size_t i = 0; i < input_args.size(); ++i) {
(*convert_args)[i] = py::isinstance<tensor::Tensor>(input_args[i])
? parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, input_args[i])
: input_args[i];
}
}
void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const {
if (py::isinstance<py::tuple>(expected_grad_out)) {
if (!py::isinstance<py::tuple>(grad_out)) {
@ -150,12 +164,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
if (is_bprop) {
SyncData(py_args);
py::tuple convert_args(py_args.size());
for (size_t i = 0; i < py_args.size(); i++) {
convert_args[i] = py::isinstance<tensor::Tensor>(py_args[i])
? parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, py_args[i])
: py_args[i];
}
ConvertCTensorToPyTensor(py_args, &convert_args);
py::object grads_obj = hook_(*convert_args);
py::tuple grads = check_bprop_out(grads_obj, py_args);
return std::make_shared<PyObjectRef>(grads);
@ -167,10 +176,15 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
auto iter = hook_grad_.find(cell_id);
if (iter != hook_grad_.end()) {
py::tuple convert_args(2);
py::tuple input_args(2);
input_args[0] = iter->second;
input_args[1] = py_args[2];
ConvertCTensorToPyTensor(input_args, &convert_args);
auto hook_args = py::tuple(3);
hook_args[0] = cell_id;
hook_args[1] = py::make_tuple(iter->second);
hook_args[2] = py::make_tuple(py_args[2]);
hook_args[1] = py::make_tuple(convert_args[0]);
hook_args[2] = py::make_tuple(convert_args[1]);
obj = hook_(*hook_args);
if (py::isinstance<py::none>(obj)) {
obj = py_args[2];

View File

@ -69,6 +69,7 @@ class PrimitivePy : public Primitive {
private:
py::function GetComputeFunction() const;
void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const;
void CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const;
py::object python_obj_;
py::function hook_;