forked from OSSInnovation/mindspore
fix gradients issue in pynative
This commit is contained in:
parent 201bcdd9af
commit 577535b387
@@ -316,7 +316,10 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args)
   }
   op_exec_info->py_primitive = prim;
   op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
-  op_exec_info->value = PynativeExecutor::GetInstance()->GetForwardValue(op_exec_info);
+  auto inst = PynativeExecutor::GetInstance();
+  if (inst->grad_flag()) {
+    op_exec_info->value = inst->GetForwardValue(op_exec_info);
+  }
   if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
     MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
     return nullptr;
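Note on the hunk above: GetForwardValue is now called only when the executor's grad flag is set, so a plain forward run in PyNative mode no longer records forward values that only the gradient pass consumes. A minimal Python sketch of the gating pattern, with illustrative names that are not the MindSpore API:

    # Sketch only: "Executor", "run_op" and "_compute" are hypothetical names.
    class Executor:
        def __init__(self):
            self.grad_flag = False  # set True only while gradients are being built
            self._forward_values = {}

        def run_op(self, op_name, *inputs):
            out = self._compute(op_name, *inputs)
            if self.grad_flag:  # the guard this commit adds
                self._forward_values[op_name] = out
            return out

        def _compute(self, op_name, *inputs):
            return sum(inputs)  # stand-in for real kernel dispatch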
@@ -1029,15 +1032,24 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights)
       AnfNodePtr para_node = nullptr;
       if (graph_info_map_[df_builder_].param_map.count(param_id)) {
         para_node = graph_info_map_[df_builder_].param_map[param_id];
-        AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node);
-        AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
-        auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
-        AnfNodePtr ref_key_node = NewValueNode(refkey);
-        AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node});
-        w_args.push_back(ref_node);
+      } else {
+        auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(param, "name");
+        if (py::isinstance<py::none>(name_attr)) {
+          MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
+        }
+        auto param_name = py::cast<std::string>(name_attr);
+        auto free_param = df_builder_->add_parameter();
+        free_param->set_name(param_name);
+        free_param->set_default_param(py::cast<tensor::TensorPtr>(param));
+        free_param->debug_info()->set_name(param_name);
+        para_node = free_param;
       }
+      AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node);
+      AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
+      auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
+      AnfNodePtr ref_key_node = NewValueNode(refkey);
+      AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node});
+      w_args.push_back(ref_node);
     }
   } else {
     MS_LOG(DEBUG) << "training not paramter_tuple";
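The GetWeightsArgs hunk fixes the case of a weight that was never recorded in the graph's param_map: previously the else path fell through with para_node still null before GetMixedPrecisionCastHelp ran; now the weight is registered as a fresh free parameter (named after the Python Parameter's name attribute, carrying its tensor as the default value), and the MakeRef node is built the same way for both branches. A rough Python sketch of the lookup-or-register flow, with hypothetical names standing in for the C++ structures:

    # Hypothetical stand-ins for the C++ structures in the hunk above.
    def get_weight_node(param_map, builder, param):
        param_id = id(param)  # stand-in for GetId(param)
        if param_id in param_map:
            para_node = param_map[param_id]
        else:
            name = getattr(param, "name", None)
            if name is None:
                raise RuntimeError("Parameter object should have name attribute")
            # Register the weight as a fresh graph parameter, keeping its
            # name and current tensor as the default value.
            para_node = builder.add_parameter(name=name, default=param.data)
        return para_node  # the MakeRef node is then built from para_node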
@@ -185,6 +185,9 @@ class Tensor(Tensor_):
     def __imod__(self, other):
         return self.__mod__(other)
 
+    def __rmod__(self, other):
+        return tensor_operator_registry.get('__mod__')(other, self)
+
     def __pow__(self, other):
         return tensor_operator_registry.get('__pow__')(self, other)
 
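Why __rmod__ is needed: for 5 % tensor, Python first tries int.__mod__(5, tensor), which returns NotImplemented because int does not understand Tensor; Python then falls back to tensor.__rmod__(5), and raises TypeError if that hook is missing. A toy class (not the MindSpore Tensor) showing the dispatch:

    class Toy:
        def __mod__(self, other):
            return ("mod", other)    # Toy on the left: toy % 5
        def __rmod__(self, other):
            return ("rmod", other)   # Toy on the right: 5 % toy

    t = Toy()
    assert (t % 5) == ("mod", 5)
    assert (5 % t) == ("rmod", 5)  # TypeError if __rmod__ were missing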
@@ -194,6 +197,9 @@ class Tensor(Tensor_):
     def __ifloordiv__(self, other):
         return self.__floordiv__(other)
 
+    def __rfloordiv__(self, other):
+        return tensor_operator_registry.get('__floordiv__')(other, self)
+
     def __str__(self):
         if self.dtype == mstype.type_none:
             return "Unknown Tensor type!"
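The operand order in the new hooks is the important detail: in an r-operator, self is the right operand, so the registry function must be called as op(other, self); swapping the arguments would silently compute self // other instead. A small sketch of why the order matters, again with a toy class:

    class T:
        def __init__(self, v):
            self.v = v
        def __floordiv__(self, other):
            return self.v // other   # self // other
        def __rfloordiv__(self, other):
            return other // self.v   # other // self -- note the order

    t = T(2)
    assert t // 5 == 0   # 2 // 5
    assert 5 // t == 2   # 5 // 2, dispatched through __rfloordiv__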
@@ -472,3 +472,7 @@ def test_tensor_operation():
     assert np.all(x.asnumpy() == np.ones((3, 3)))
     with pytest.raises(ValueError):
         res = x * (2, 3)
+    res = 5 % x
+    assert np.all(x.asnumpy() == np.ones((3, 3)))
+    res = 5 // x
+    assert np.all(x.asnumpy() == np.ones((3, 3)))
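The two new statements exercise the reflected paths added above: before this commit, 5 % x and 5 // x would have raised TypeError, since int cannot handle a Tensor operand and Tensor defined no __rmod__ or __rfloordiv__. The repeated asserts also confirm that neither operation mutates x in place.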