!23099 Correct wrong info when using ms_function to decorate bprop

Merge pull request !23099 from JoyLvliang/correct_wrong_info_when_using_ms_function_with_bprop

Commit: a5b793463a
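In short, the change has two parts. Hook operations (P.HookBackward and Cell backward hooks) that reach graph mode or an ms_function block are no longer rejected during parsing with a hard error; they are skipped during automatic differentiation and optimization, and a warning states that the hook will be eliminated during compilation. Decorating a custom bprop or a hook function itself with @ms_function, however, is now detected and reported with a clear, accurate message. A minimal sketch of the rejected pattern (not part of the patch; it mirrors the CellBprop test case added at the end of this diff):

import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.common.api import ms_function
from mindspore.ops import composite as C

grad_all = C.GradOperation(get_all=True)


class CellBprop(nn.Cell):
    def construct(self, x, y):
        return 2 * x * x + y * y

    @ms_function  # decorating bprop this way is not supported and is now rejected explicitly
    def bprop(self, x, y, out, dout):
        return dout, 2 * y


context.set_context(mode=context.PYNATIVE_MODE)
net = CellBprop()
x = Tensor(np.random.randn(2, 2).astype(np.float32))
y = Tensor(np.random.randn(2, 2).astype(np.float32))
# Expected to raise a RuntimeError: "Decorating bprop with '@ms_function' is not supported."
grad_all(net)(x, y)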
@@ -27,7 +27,6 @@ from textwrap import dedent
 import asttokens
 
 from mindspore import Tensor
-from mindspore import context
 from mindspore import log as logger
 from mindspore import nn
 from mindspore import ops

@@ -105,8 +104,6 @@ def get_parse_method_of_class(obj, parse_method=None):
         method_name = parse_method
     elif isinstance(obj, nn.Cell):
         if obj.enable_hook:
-            if context.get_context("mode") == context.GRAPH_MODE:
-                raise ValueError("The graph mode does not support hook function.")
             method_name = "_hook_construct"
         else:
             method_name = "construct"

@@ -210,6 +210,11 @@ void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app,
   for (size_t i = 0; i < cnode_morph->size(); i++) {
     auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i))});
     auto input = cnode_morph->input(i);
+    // Skip HookBackward op
+    if (IsPrimitiveCNode(input, prim::kPrimHookBackward)) {
+      auto inp_i = input->cast<CNodePtr>();
+      input = inp_i->input(1);
+    }
     // Backprop sens wrt fvs.
     if (IsValueNode<FuncGraph>(input)) {
       auto func_graph = GetValueNode<FuncGraphPtr>(input);

@@ -257,6 +262,13 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   std::vector<AdjointPtr> param_adjoints;
   for (size_t i = 0; i < cnode_morph->size(); i++) {
     auto node = cnode_morph->input(i);
+    // Skip HookBackward op
+    if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
+      auto input_i = node->cast<CNodePtr>();
+      MS_LOG(WARNING)
+        << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
+      node = input_i->input(1);
+    }
     AdjointPtr node_adjoint = nullptr;
     auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
     if (node_adjoint_iter != anfnode_to_adjoin_.end()) {

@@ -417,11 +429,19 @@ void DFunctor::MapMorphism() {
 
   // Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
   MapFreeMorphism();
+  // Skip HookBackward when it is the output node.
+  auto output_node = primal_graph_->output();
+  if (IsPrimitiveCNode(output_node, prim::kPrimHookBackward)) {
+    auto output_cnode = output_node->cast<CNodePtr>();
+    MS_LOG(WARNING)
+      << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
+    output_node = output_cnode->input(1);
+  }
   // Handle morphism from output.
-  (void)MapMorphism(primal_graph_->output());
+  (void)MapMorphism(output_node);
 
-  // Construct K for primal_graph_
-  auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
+  // Construct K for primal_graph_.
+  auto output_adjoint = anfnode_to_adjoin_.find(output_node);
   // Attach dout_ parameter to output_adjoint.
   output_adjoint->second->AccumulateDout(dout_);
 
@@ -612,7 +632,9 @@ void DFunctor::MapValueObject() {
 
     AdjointPtr adjoint = nullptr;
     if (IsValueNode<Primitive>(node)) {  // Primitive.
-      if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
+      auto prim = GetValueNode<PrimitivePtr>(node);
+      if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn ||
+          (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name())) {
         continue;
       }
       MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << ".";

@@ -63,6 +63,10 @@ class SpecialOpEliminater : public OptimizerCaller {
     for (auto &eliminater : eliminaters_) {
       new_node = (*eliminater)(optimizer, node);
       if (new_node != nullptr) {
+        if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
+          MS_LOG(WARNING)
+            << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
+        }
         return new_node;
       }
     }

@@ -2374,6 +2374,10 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
   (void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));
 
   py::object code_obj = py::getattr(bprop_func, "__code__");
+  py::object co_name = py::getattr(code_obj, "co_name");
+  if (std::string(py::str(co_name)) == "staging_specialize") {
+    MS_LOG(EXCEPTION) << "Decorating bprop with '@ms_function' is not supported.";
+  }
   // Three parameters self, out and dout need to be excluded
   const size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
   if (inputs_num > args.size()) {

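For context, the check above keys off the decorated function's code object: ms_function is assumed (as the C++ condition implies) to return a wrapper whose __code__.co_name is "staging_specialize", so a bprop or hook that has been wrapped can be recognized without executing it. A rough illustration of that assumption in plain Python:

from mindspore.common.api import ms_function


def bprop(self, x, y, out, dout):
    return dout, 2 * y


wrapped = ms_function(bprop)
print(bprop.__code__.co_name)    # -> "bprop"
print(wrapped.__code__.co_name)  # expected -> "staging_specialize", which the check above rejects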
@@ -219,6 +219,12 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
   auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
   auto iter = hook_grad_.find(cell_id);
   if (iter != hook_grad_.end()) {
+    py::object code_obj = py::getattr(hook_, "__code__");
+    py::object co_name = py::getattr(code_obj, "co_name");
+    if (std::string(py::str(co_name)) == "staging_specialize") {
+      MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
+    }
+
     py::tuple convert_args(input_param_nums - 1);
     py::tuple input_args(input_param_nums - 1);
     input_args[0] = iter->second;

@@ -243,6 +249,12 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
 }
 
 BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const {
+  py::object code_obj = py::getattr(hook_, "__code__");
+  py::object co_name = py::getattr(code_obj, "co_name");
+  if (std::string(py::str(co_name)) == "staging_specialize") {
+    MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
+  }
+
   constexpr size_t grad_output_index = 2;
   SyncData(py_args[grad_output_index]);
   py::object obj = hook_(py::make_tuple(py_args[grad_output_index]));

@@ -0,0 +1,157 @@
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import pytest
+import numpy as np
+import mindspore.nn as nn
+import mindspore.ops.operations as P
+from mindspore.ops import composite as C
+from mindspore import context, Tensor
+from mindspore.common.api import ms_function
+
+grad_all = C.GradOperation(get_all=True)
+
+
+def var_hook_function(grad_out):
+    print("grad:", grad_out)
+
+
+class GraphVarHook(nn.Cell):
+    def __init__(self):
+        super(GraphVarHook, self).__init__()
+        self.relu = nn.ReLU()
+        self.hook = P.HookBackward(var_hook_function)
+
+    def construct(self, x):
+        x = x + x
+        x = x * x
+        x = self.hook(x)
+        x = self.relu(x)
+        return x
+
+
+class MsFuncVarHook(nn.Cell):
+    def __init__(self):
+        super(MsFuncVarHook, self).__init__()
+        self.relu = nn.ReLU()
+        self.hook = P.HookBackward(var_hook_function)
+
+    @ms_function
+    def construct(self, x):
+        x = x + x
+        x = x * x
+        x = self.hook(x)
+        x = self.relu(x)
+        return x
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_var_hook_forward():
+    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
+    context.set_context(mode=context.PYNATIVE_MODE)
+    net1 = MsFuncVarHook()
+    out1 = net1(input_x)
+    context.set_context(mode=context.GRAPH_MODE)
+    net2 = GraphVarHook()
+    out2 = net2(input_x)
+    assert np.allclose(out1.asnumpy(), out2.asnumpy(), 0.00001, 0.00001)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_var_hook_grad():
+    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
+    context.set_context(mode=context.PYNATIVE_MODE)
+    net1 = MsFuncVarHook()
+    grad_out1 = grad_all(net1)(input_x)
+    context.set_context(mode=context.GRAPH_MODE)
+    net2 = GraphVarHook()
+    grad_out2 = grad_all(net2)(input_x)
+    assert np.allclose(grad_out1[0].asnumpy(), grad_out2[0].asnumpy(), 0.00001, 0.00001)
+
+
+def cell_hook_function(cell_id, grad_input, grad_output):
+    print("cell id:", cell_id)
+    print("grad input:", grad_input)
+    print("grad output:", grad_output)
+
+
+class GraphCellHook(nn.Cell):
+    def __init__(self):
+        super(GraphCellHook, self).__init__()
+        self.relu = nn.ReLU()
+        self.relu.register_backward_hook(cell_hook_function)
+
+    def construct(self, x):
+        x = x + x
+        x = x * x
+        x = self.relu(x)
+        return x
+
+
+class MsFuncCellHook(nn.Cell):
+    def __init__(self):
+        super(MsFuncCellHook, self).__init__()
+        self.relu = nn.ReLU()
+        self.relu.register_backward_hook(cell_hook_function)
+
+    @ms_function
+    def construct(self, x):
+        x = x + x
+        x = x * x
+        x = self.relu(x)
+        return x
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_cell_hook_forward():
+    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
+    context.set_context(mode=context.PYNATIVE_MODE)
+    net1 = MsFuncCellHook()
+    out1 = net1(input_x)
+    context.set_context(mode=context.GRAPH_MODE)
+    net2 = GraphCellHook()
+    out2 = net2(input_x)
+    assert np.allclose(out1.asnumpy(), out2.asnumpy(), 0.00001, 0.00001)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_cell_hook_grad():
+    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
+    context.set_context(mode=context.PYNATIVE_MODE)
+    net1 = MsFuncCellHook()
+    grad_out1 = grad_all(net1)(input_x)
+    context.set_context(mode=context.GRAPH_MODE)
+    net2 = GraphCellHook()
+    grad_out2 = grad_all(net2)(input_x)
+    assert np.allclose(grad_out1[0].asnumpy(), grad_out2[0].asnumpy(), 0.00001, 0.00001)

@@ -0,0 +1,76 @@
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore.ops import composite as C
+from mindspore.nn import Momentum
+from mindspore import context, Tensor
+from mindspore.common.api import ms_function
+
+grad_all = C.GradOperation(get_all=True)
+
+
+class CellBprop(nn.Cell):
+    def __init__(self):
+        super(CellBprop, self).__init__()
+
+    def construct(self, x, y):
+        return 2 * x * x + y * y
+
+    @ms_function
+    def bprop(self, x, y, out, dout):
+        return dout, 2 * y
+
+
+def test_cell_bprop_grad():
+    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
+    input_y = Tensor(np.random.randn(2, 2).astype(np.float32))
+    context.set_context(mode=context.PYNATIVE_MODE)
+    net = CellBprop()
+    with pytest.raises(RuntimeError):
+        grad_all(net)(input_x, input_y)
+
+
+class ConvNet(nn.Cell):
+    def __init__(self):
+        super(ConvNet, self).__init__()
+        self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
+
+    def construct(self, x):
+        out = self.conv(x)
+        return out
+
+
+class MomentumWithMsFunc(nn.Cell):
+    def __init__(self, net):
+        super(MomentumWithMsFunc, self).__init__()
+        self.net = net
+        self.optimizer = Momentum(filter(lambda x: x.requires_grad, self.net.get_parameters()), 0.1, 0.9)
+
+    @ms_function
+    def construct(self, grads):
+        ret = self.optimizer(grads)
+        return ret
+
+
+def test_ms_func_decorate_forward():
+    context.set_context(mode=context.PYNATIVE_MODE)
+    input_x = Tensor(np.random.randn(1, 1, 2, 2).astype(np.float32))
+    net = ConvNet()
+    grad_out = grad_all(net)(input_x)
+    opt = MomentumWithMsFunc(net)
+    opt(grad_out)