correct_check_pynative_hook

This commit is contained in:
7347157+joylvliang@user.noreply.gitee.com 2022-01-07 17:34:20 +08:00
parent 0a52e1723c
commit 5b73786297
4 changed files with 31 additions and 43 deletions

View File

@ -2642,9 +2642,9 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
}
// Three parameters self, out and dout need to be excluded
const size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
if (inputs_num > args.size()) {
MS_EXCEPTION(TypeError) << "Size of bprop func inputs[" << inputs_num << "] is larger than size of cell inputs["
<< args.size() << "]";
if (inputs_num != args.size()) {
MS_EXCEPTION(TypeError) << "Size of bprop func inputs[" << inputs_num
<< "] is not equal to the size of cell inputs[" << args.size() << "]";
}
py::list cell_inputs;

View File

@ -613,49 +613,33 @@ void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, con
}
}
void PlantTensorTupleToVector(const py::tuple &tuple_inputs, std::vector<tensor::TensorPtr> *tensors) {
void ConvertPyObjectToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
MS_EXCEPTION_IF_NULL(tensors);
for (const auto &input_object : tuple_inputs) {
if (!py::isinstance<tensor::Tensor>(input_object)) {
MS_LOG(EXCEPTION) << "The input object is not a tensor!";
tensor::TensorPtr tensor_ptr = nullptr;
if (py::isinstance<tensor::Tensor>(input_object)) {
tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
} else if (py::isinstance<py::float_>(input_object)) {
double input_value = py::cast<py::float_>(input_object);
tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
} else if (py::isinstance<py::int_>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64);
} else if (py::isinstance<py::list>(input_object)) {
auto list_inputs = py::cast<py::list>(input_object);
for (size_t i = 0; i < list_inputs.size(); ++i) {
ConvertPyObjectToTensor(list_inputs[i], tensors);
}
auto tensor = py::cast<tensor::TensorPtr>(input_object);
MS_EXCEPTION_IF_NULL(tensor);
(void)tensors->emplace_back(tensor);
}
}
void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
MS_EXCEPTION_IF_NULL(tensors);
ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
MS_EXCEPTION_IF_NULL(input_value);
if (!input_value->isa<ValueTuple>()) {
MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
}
auto value_tuple = input_value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
MS_EXCEPTION_IF_NULL(tensor_ptr);
(void)tensors->emplace_back(tensor_ptr);
}
void ConvertMultiPyObjectToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
MS_EXCEPTION_IF_NULL(tensors);
if (!py::isinstance<py::tuple>(input_object)) {
MS_LOG(EXCEPTION) << "The input should be a tuple!";
}
auto inputs = py::cast<py::tuple>(input_object);
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
}
if (py::isinstance<tensor::Tensor>(inputs[0])) {
PlantTensorTupleToVector(inputs, tensors);
return;
} else if (py::isinstance<py::tuple>(input_object)) {
auto tuple_inputs = py::cast<py::tuple>(input_object);
for (size_t i = 0; i < tuple_inputs.size(); ++i) {
ConvertPyObjectToTensor(tuple_inputs[i], tensors);
}
return;
} else {
ConvertValueTupleToTensor(input_object, tensors);
MS_EXCEPTION(TypeError) << "Unreasonable data type: " << input_object.get_type() << ".";
}
MS_EXCEPTION_IF_NULL(tensor_ptr);
tensors->emplace_back(tensor_ptr);
}
void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler, const KernelGraphPtr &graph,
@ -698,7 +682,7 @@ void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler, co
PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
auto out_py_tuple = py_ref.object_;
std::vector<tensor::TensorPtr> output_tensors;
ConvertMultiPyObjectToTensor(out_py_tuple, &output_tensors);
ConvertPyObjectToTensor(out_py_tuple, &output_tensors);
(void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(op_outputs->elements_),
[](tensor::TensorPtr &tensor) { return std::move(tensor); });
}

View File

@ -195,6 +195,7 @@ class ForwardValueAndGrad(Cell):
Args:
network (Cell): The training network.
weights (ParameterTuple): The parameters of the training network that need to calculate the gradient.
Default: None.
get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
If get_all and get_by_list are both False, get the gradient with respect to first input.

View File

@ -353,9 +353,11 @@ class HookBackward(PrimitiveWithInfer):
Examples:
>>> import mindspore
>>> from mindspore import context
>>> from mindspore import Tensor
>>> from mindspore import ops
>>> from mindspore.ops import GradOperation
>>> context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
>>> def hook_fn(grad_out):
... print(grad_out)
...
@ -371,6 +373,7 @@ class HookBackward(PrimitiveWithInfer):
... return grad_all(hook_test)(x, y)
...
>>> output = backward(Tensor(1, mindspore.float32), Tensor(2, mindspore.float32))
(Tensor(shape=[], dtype=Float32, value= 2),)
>>> print(output)
(Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
"""