forked from mindspore-Ecosystem/mindspore
correct_check_pynative_hook
This commit is contained in:
parent
0a52e1723c
commit
5b73786297
|
@ -2642,9 +2642,9 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
|
|||
}
|
||||
// Three parameters self, out and dout need to be excluded
|
||||
const size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
|
||||
if (inputs_num > args.size()) {
|
||||
MS_EXCEPTION(TypeError) << "Size of bprop func inputs[" << inputs_num << "] is larger than size of cell inputs["
|
||||
<< args.size() << "]";
|
||||
if (inputs_num != args.size()) {
|
||||
MS_EXCEPTION(TypeError) << "Size of bprop func inputs[" << inputs_num
|
||||
<< "] is not equal to the size of cell inputs[" << args.size() << "]";
|
||||
}
|
||||
|
||||
py::list cell_inputs;
|
||||
|
|
|
@ -613,49 +613,33 @@ void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, con
|
|||
}
|
||||
}
|
||||
|
||||
void PlantTensorTupleToVector(const py::tuple &tuple_inputs, std::vector<tensor::TensorPtr> *tensors) {
|
||||
void ConvertPyObjectToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
|
||||
MS_EXCEPTION_IF_NULL(tensors);
|
||||
for (const auto &input_object : tuple_inputs) {
|
||||
if (!py::isinstance<tensor::Tensor>(input_object)) {
|
||||
MS_LOG(EXCEPTION) << "The input object is not a tensor!";
|
||||
tensor::TensorPtr tensor_ptr = nullptr;
|
||||
if (py::isinstance<tensor::Tensor>(input_object)) {
|
||||
tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
|
||||
} else if (py::isinstance<py::float_>(input_object)) {
|
||||
double input_value = py::cast<py::float_>(input_object);
|
||||
tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
|
||||
} else if (py::isinstance<py::int_>(input_object)) {
|
||||
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64);
|
||||
} else if (py::isinstance<py::list>(input_object)) {
|
||||
auto list_inputs = py::cast<py::list>(input_object);
|
||||
for (size_t i = 0; i < list_inputs.size(); ++i) {
|
||||
ConvertPyObjectToTensor(list_inputs[i], tensors);
|
||||
}
|
||||
auto tensor = py::cast<tensor::TensorPtr>(input_object);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
(void)tensors->emplace_back(tensor);
|
||||
}
|
||||
}
|
||||
|
||||
void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
|
||||
MS_EXCEPTION_IF_NULL(tensors);
|
||||
ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
|
||||
MS_EXCEPTION_IF_NULL(input_value);
|
||||
if (!input_value->isa<ValueTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
|
||||
}
|
||||
|
||||
auto value_tuple = input_value->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_tuple);
|
||||
tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
|
||||
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||
(void)tensors->emplace_back(tensor_ptr);
|
||||
}
|
||||
|
||||
void ConvertMultiPyObjectToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
|
||||
MS_EXCEPTION_IF_NULL(tensors);
|
||||
if (!py::isinstance<py::tuple>(input_object)) {
|
||||
MS_LOG(EXCEPTION) << "The input should be a tuple!";
|
||||
}
|
||||
|
||||
auto inputs = py::cast<py::tuple>(input_object);
|
||||
if (inputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
|
||||
}
|
||||
|
||||
if (py::isinstance<tensor::Tensor>(inputs[0])) {
|
||||
PlantTensorTupleToVector(inputs, tensors);
|
||||
return;
|
||||
} else if (py::isinstance<py::tuple>(input_object)) {
|
||||
auto tuple_inputs = py::cast<py::tuple>(input_object);
|
||||
for (size_t i = 0; i < tuple_inputs.size(); ++i) {
|
||||
ConvertPyObjectToTensor(tuple_inputs[i], tensors);
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
ConvertValueTupleToTensor(input_object, tensors);
|
||||
MS_EXCEPTION(TypeError) << "Unreasonable data type: " << input_object.get_type() << ".";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||
tensors->emplace_back(tensor_ptr);
|
||||
}
|
||||
|
||||
void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler, const KernelGraphPtr &graph,
|
||||
|
@ -698,7 +682,7 @@ void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler, co
|
|||
PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
|
||||
auto out_py_tuple = py_ref.object_;
|
||||
std::vector<tensor::TensorPtr> output_tensors;
|
||||
ConvertMultiPyObjectToTensor(out_py_tuple, &output_tensors);
|
||||
ConvertPyObjectToTensor(out_py_tuple, &output_tensors);
|
||||
(void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(op_outputs->elements_),
|
||||
[](tensor::TensorPtr &tensor) { return std::move(tensor); });
|
||||
}
|
||||
|
|
|
@ -195,6 +195,7 @@ class ForwardValueAndGrad(Cell):
|
|||
Args:
|
||||
network (Cell): The training network.
|
||||
weights (ParameterTuple): The parameters of the training network that need to calculate the gradient.
|
||||
Default: None.
|
||||
get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
|
||||
get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
|
||||
If get_all and get_by_list are both False, get the gradient with respect to first input.
|
||||
|
|
|
@ -353,9 +353,11 @@ class HookBackward(PrimitiveWithInfer):
|
|||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> from mindspore import context
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore import ops
|
||||
>>> from mindspore.ops import GradOperation
|
||||
>>> context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
>>> def hook_fn(grad_out):
|
||||
... print(grad_out)
|
||||
...
|
||||
|
@ -371,6 +373,7 @@ class HookBackward(PrimitiveWithInfer):
|
|||
... return grad_all(hook_test)(x, y)
|
||||
...
|
||||
>>> output = backward(Tensor(1, mindspore.float32), Tensor(2, mindspore.float32))
|
||||
(Tensor(shape=[], dtype=Float32, value= 2),)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue