forked from mindspore-Ecosystem/mindspore
!1687 add data sync before hook function
Merge pull request !1687 from wangqiuliang/add-data-sync-before-hook-function
This commit is contained in:
commit
c4939d9cc8
|
@ -603,6 +603,19 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
|
|||
MS_LOG(DEBUG) << "End";
|
||||
}
|
||||
|
||||
void FinalVM::SyncData(const py::object &arg) {
|
||||
if (py::isinstance<py::tuple>(arg)) {
|
||||
py::tuple arg_list = py::cast<py::tuple>(arg);
|
||||
for (size_t i = 0; i < arg_list.size(); i++) {
|
||||
SyncData(arg_list[i]);
|
||||
}
|
||||
}
|
||||
if (py::isinstance<tensor::Tensor>(arg)) {
|
||||
auto tensor = py::cast<tensor::TensorPtr>(arg);
|
||||
(void)tensor->data_sync();
|
||||
}
|
||||
}
|
||||
|
||||
BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
|
||||
MS_LOG(DEBUG) << "input for operation:";
|
||||
std::size_t args_size = args.size();
|
||||
|
@ -613,15 +626,20 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
|
|||
MS_LOG(DEBUG) << "arg: " << i << ":";
|
||||
i++;
|
||||
}
|
||||
// Hook operator for execute cell custom bprop function
|
||||
py::object obj;
|
||||
bool is_bprop = prim->HasAttr("bprop");
|
||||
if (is_bprop) {
|
||||
SyncData(py_args);
|
||||
py::function fn_bprop = prim->hook();
|
||||
obj = fn_bprop(*py_args);
|
||||
return obj;
|
||||
}
|
||||
// Sync gradient data from device to host
|
||||
SyncData(py_args[2]);
|
||||
bool is_cell = prim->HasAttr("cell_hook");
|
||||
if (is_cell) {
|
||||
// Hook operator for execute cell hook function
|
||||
std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
|
||||
if (_hook_grad.find(cell_id) != _hook_grad.end()) {
|
||||
std::size_t hook_args_size = 3;
|
||||
|
@ -640,6 +658,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
|
|||
obj = py_args[2];
|
||||
}
|
||||
} else {
|
||||
// Hook operator for execute variable hook function
|
||||
py::function fn_hook = prim->hook();
|
||||
obj = fn_hook(py::make_tuple(py_args[2]));
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
|
|
|
@ -115,7 +115,7 @@ class FinalVM {
|
|||
void InstPushPrim(const VectorRef &args);
|
||||
void InstSwitchReturn(const VectorRef &args);
|
||||
void set_insts(const InstSet &value) { insts_ = value; }
|
||||
BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &args);
|
||||
BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg);
|
||||
|
||||
protected:
|
||||
BaseRef Ref(int i);
|
||||
|
@ -129,6 +129,7 @@ class FinalVM {
|
|||
void PushStatus(bool is_switch_call);
|
||||
bool PopStatus();
|
||||
void DoJmp(const BaseRef &jmp);
|
||||
void SyncData(const py::object &args);
|
||||
void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c);
|
||||
BaseRef MergeArgs(const BaseRef &first, const BaseRef &second);
|
||||
|
||||
|
|
|
@ -77,7 +77,7 @@ class Cell:
|
|||
if flags:
|
||||
self.add_flags(**flags)
|
||||
self._backward_hook = None
|
||||
self._enable_hook = False
|
||||
self.enable_hook = False
|
||||
self._bprop_debug = False
|
||||
|
||||
@property
|
||||
|
@ -97,10 +97,24 @@ class Cell:
|
|||
|
||||
@property
|
||||
def bprop_debug(self):
|
||||
"""
|
||||
Get whether cell custom bprop debug is enabled.
|
||||
"""
|
||||
return self._bprop_debug
|
||||
|
||||
@bprop_debug.setter
|
||||
def bprop_debug(self, value):
|
||||
"""
|
||||
Set whether to enable cell custom bprop debug.
|
||||
|
||||
Note:
|
||||
When bprop is defined in cell, the bprop function will be executed
|
||||
in python interpreter when bprop debug is true, and will be parsed
|
||||
and add to graph when bprop debug is false.
|
||||
|
||||
Args:
|
||||
value (bool): Specifies whether to enable bprop debug. Default: False.
|
||||
"""
|
||||
if not isinstance(value, bool):
|
||||
raise TypeError("'bprop debug' value must be bool type.")
|
||||
self._bprop_debug = value
|
||||
|
@ -755,17 +769,19 @@ class Cell:
|
|||
outputs = self._backward_hook(inputs)
|
||||
return outputs
|
||||
|
||||
@property
|
||||
def enable_hook(self):
|
||||
"""Whether the cell register hook function"""
|
||||
return self._enable_hook
|
||||
|
||||
def register_backward_hook(self, fn):
|
||||
"""
|
||||
Set the cell backward hook function.
|
||||
|
||||
Note:
|
||||
fn should be defined as following code shows, `cell_name` is the name of registered cell,
|
||||
`grad_input` is gradient passed to the cell, `grad_output` is the gradient computed and pass to
|
||||
next cell or primitve, which may be modified and return.
|
||||
>>> hook_fn(cell_name, grad_input, grad_output) -> Tensor or None
|
||||
|
||||
Args:
|
||||
fn (function): Specifies the hook function with grad as input.
|
||||
|
||||
"""
|
||||
self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
|
||||
self._enable_hook = True
|
||||
|
|
|
@ -247,9 +247,11 @@ class HookBackward(PrimitiveWithInfer):
|
|||
Used as tag to hook gradient in intermediate variables.
|
||||
|
||||
Note:
|
||||
The hook function should have one input of gradient of the variable.
|
||||
hook function will be executed in python environment, while callback
|
||||
of InsertGradientOf will be parsed and added to the graph.
|
||||
The hook function should be defined like `hook_fn(grad) -> Tensor or None`,
|
||||
which grad is the gradient passed to the primitive and gradient may be
|
||||
modified and passed to nex primitive. the difference between hook function and
|
||||
callback of InsertGradientOf is that hook function is executed in python
|
||||
environment while callback will be parsed and added to the graph.
|
||||
|
||||
Args:
|
||||
hook_fn (Function): Python function. hook function.
|
||||
|
@ -312,6 +314,8 @@ class Print(PrimitiveWithInfer):
|
|||
|
||||
2. The data of tensor is a scalar type.
|
||||
|
||||
In pynative mode, please use python print function.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Union[Tensor, str]) - The graph node to attach to. The input supports
|
||||
multiple strings and tensors which are separated by ','.
|
||||
|
|
Loading…
Reference in New Issue