From ea0d53c27c27cae571b502fb59c3107f069542c4 Mon Sep 17 00:00:00 2001 From: lvliang Date: Tue, 24 Nov 2020 21:04:27 +0800 Subject: [PATCH] fix-bug-of-zero-in-tensor-shape-and-non-py-tensor-in-cell-hook-in-pynative --- mindspore/ccsrc/pipeline/pynative/base.h | 5 ++-- mindspore/ccsrc/pybind_api/ir/primitive_py.cc | 30 ++++++++++++++----- mindspore/ccsrc/pybind_api/ir/primitive_py.h | 1 + 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h index 8203d46ee2a..81fd4b6aa07 100644 --- a/mindspore/ccsrc/pipeline/pynative/base.h +++ b/mindspore/ccsrc/pipeline/pynative/base.h @@ -69,8 +69,9 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args); const std::set ignore_infer_prim = {"make_ref", "mixed_precision_cast"}; const std::set force_infer_prim = {"TopK", "DropoutGenMask"}; -const std::set ignore_judge_dynamic_cell = {"Cell mindspore.nn.layer.basic.Dense", - "Cell mindspore.nn.probability.distribution.normal.Normal"}; +const std::set ignore_judge_dynamic_cell = { + "Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal", + "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask"}; const std::set unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE, parse::NAMED_PRIMITIVE_NAMECONSTANT, parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR}; diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index 08804dc3054..64c5d7e494c 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -108,6 +108,20 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) return grads; } +void PrimitivePy::ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const { + MS_EXCEPTION_IF_NULL(convert_args); + if (input_args.size() != (*convert_args).size()) { + MS_LOG(EXCEPTION) << "The size of input_args: " << input_args.size() + << " should be equal to the size of convert_args: " << (*convert_args).size(); + } + for (size_t i = 0; i < input_args.size(); ++i) { + (*convert_args)[i] = py::isinstance(input_args[i]) + ? parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, + parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, input_args[i]) + : input_args[i]; + } +} + void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const { if (py::isinstance(expected_grad_out)) { if (!py::isinstance(grad_out)) { @@ -150,12 +164,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { if (is_bprop) { SyncData(py_args); py::tuple convert_args(py_args.size()); - for (size_t i = 0; i < py_args.size(); i++) { - convert_args[i] = py::isinstance(py_args[i]) - ? parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, - parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, py_args[i]) - : py_args[i]; - } + ConvertCTensorToPyTensor(py_args, &convert_args); py::object grads_obj = hook_(*convert_args); py::tuple grads = check_bprop_out(grads_obj, py_args); return std::make_shared(grads); @@ -167,10 +176,15 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { auto cell_id = GetValue(this->GetAttr(kCellIDAttrName)); auto iter = hook_grad_.find(cell_id); if (iter != hook_grad_.end()) { + py::tuple convert_args(2); + py::tuple input_args(2); + input_args[0] = iter->second; + input_args[1] = py_args[2]; + ConvertCTensorToPyTensor(input_args, &convert_args); auto hook_args = py::tuple(3); hook_args[0] = cell_id; - hook_args[1] = py::make_tuple(iter->second); - hook_args[2] = py::make_tuple(py_args[2]); + hook_args[1] = py::make_tuple(convert_args[0]); + hook_args[2] = py::make_tuple(convert_args[1]); obj = hook_(*hook_args); if (py::isinstance(obj)) { obj = py_args[2]; diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.h b/mindspore/ccsrc/pybind_api/ir/primitive_py.h index efd27abcccf..58e153de504 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.h +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.h @@ -69,6 +69,7 @@ class PrimitivePy : public Primitive { private: py::function GetComputeFunction() const; + void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const; void CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const; py::object python_obj_; py::function hook_;