forked from mindspore-Ecosystem/mindspore
!8971 Fix bug of non py tensor in cell hook and misjudge dynamic cell in pynative
From: @joylvliang Reviewed-by: @kisnwang,@kisnwang Signed-off-by:
This commit is contained in:
commit
e993863c30
|
@ -69,8 +69,9 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
|
|||
|
||||
const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
|
||||
const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"};
|
||||
const std::set<std::string> ignore_judge_dynamic_cell = {"Cell mindspore.nn.layer.basic.Dense",
|
||||
"Cell mindspore.nn.probability.distribution.normal.Normal"};
|
||||
const std::set<std::string> ignore_judge_dynamic_cell = {
|
||||
"Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal",
|
||||
"Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask"};
|
||||
const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE,
|
||||
parse::NAMED_PRIMITIVE_NAMECONSTANT,
|
||||
parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR};
|
||||
|
|
|
@ -108,6 +108,20 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
|
|||
return grads;
|
||||
}
|
||||
|
||||
void PrimitivePy::ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const {
|
||||
MS_EXCEPTION_IF_NULL(convert_args);
|
||||
if (input_args.size() != (*convert_args).size()) {
|
||||
MS_LOG(EXCEPTION) << "The size of input_args: " << input_args.size()
|
||||
<< " should be equal to the size of convert_args: " << (*convert_args).size();
|
||||
}
|
||||
for (size_t i = 0; i < input_args.size(); ++i) {
|
||||
(*convert_args)[i] = py::isinstance<tensor::Tensor>(input_args[i])
|
||||
? parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
|
||||
parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, input_args[i])
|
||||
: input_args[i];
|
||||
}
|
||||
}
|
||||
|
||||
void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const {
|
||||
if (py::isinstance<py::tuple>(expected_grad_out)) {
|
||||
if (!py::isinstance<py::tuple>(grad_out)) {
|
||||
|
@ -150,12 +164,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
|
|||
if (is_bprop) {
|
||||
SyncData(py_args);
|
||||
py::tuple convert_args(py_args.size());
|
||||
for (size_t i = 0; i < py_args.size(); i++) {
|
||||
convert_args[i] = py::isinstance<tensor::Tensor>(py_args[i])
|
||||
? parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
|
||||
parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, py_args[i])
|
||||
: py_args[i];
|
||||
}
|
||||
ConvertCTensorToPyTensor(py_args, &convert_args);
|
||||
py::object grads_obj = hook_(*convert_args);
|
||||
py::tuple grads = check_bprop_out(grads_obj, py_args);
|
||||
return std::make_shared<PyObjectRef>(grads);
|
||||
|
@ -167,10 +176,15 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
|
|||
auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
|
||||
auto iter = hook_grad_.find(cell_id);
|
||||
if (iter != hook_grad_.end()) {
|
||||
py::tuple convert_args(2);
|
||||
py::tuple input_args(2);
|
||||
input_args[0] = iter->second;
|
||||
input_args[1] = py_args[2];
|
||||
ConvertCTensorToPyTensor(input_args, &convert_args);
|
||||
auto hook_args = py::tuple(3);
|
||||
hook_args[0] = cell_id;
|
||||
hook_args[1] = py::make_tuple(iter->second);
|
||||
hook_args[2] = py::make_tuple(py_args[2]);
|
||||
hook_args[1] = py::make_tuple(convert_args[0]);
|
||||
hook_args[2] = py::make_tuple(convert_args[1]);
|
||||
obj = hook_(*hook_args);
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
obj = py_args[2];
|
||||
|
|
|
@ -69,6 +69,7 @@ class PrimitivePy : public Primitive {
|
|||
|
||||
private:
|
||||
py::function GetComputeFunction() const;
|
||||
void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const;
|
||||
void CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const;
|
||||
py::object python_obj_;
|
||||
py::function hook_;
|
||||
|
|
Loading…
Reference in New Issue