!23099 Correct wrong info when using ms_function to decorate bprop

Merge pull request !23099 from JoyLvliang/correct_wrong_info_when_using_ms_function_with_bprop
This commit is contained in:
i-robot 2021-09-16 08:46:41 +00:00 committed by Gitee
commit a5b793463a
7 changed files with 279 additions and 7 deletions

View File

@ -27,7 +27,6 @@ from textwrap import dedent
import asttokens
from mindspore import Tensor
from mindspore import context
from mindspore import log as logger
from mindspore import nn
from mindspore import ops
@ -105,8 +104,6 @@ def get_parse_method_of_class(obj, parse_method=None):
method_name = parse_method
elif isinstance(obj, nn.Cell):
if obj.enable_hook:
if context.get_context("mode") == context.GRAPH_MODE:
raise ValueError("The graph mode does not support hook function.")
method_name = "_hook_construct"
else:
method_name = "construct"

View File

@ -210,6 +210,11 @@ void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app,
for (size_t i = 0; i < cnode_morph->size(); i++) {
auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i))});
auto input = cnode_morph->input(i);
// Skip HookBackward op
if (IsPrimitiveCNode(input, prim::kPrimHookBackward)) {
auto inp_i = input->cast<CNodePtr>();
input = inp_i->input(1);
}
// Backprop sens wrt fvs.
if (IsValueNode<FuncGraph>(input)) {
auto func_graph = GetValueNode<FuncGraphPtr>(input);
@ -257,6 +262,13 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
std::vector<AdjointPtr> param_adjoints;
for (size_t i = 0; i < cnode_morph->size(); i++) {
auto node = cnode_morph->input(i);
// Skip HookBackward op
if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
auto input_i = node->cast<CNodePtr>();
MS_LOG(WARNING)
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
node = input_i->input(1);
}
AdjointPtr node_adjoint = nullptr;
auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
@ -417,11 +429,19 @@ void DFunctor::MapMorphism() {
// Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
MapFreeMorphism();
// Skip HookBackward when it is the output node.
auto output_node = primal_graph_->output();
if (IsPrimitiveCNode(output_node, prim::kPrimHookBackward)) {
auto output_cnode = output_node->cast<CNodePtr>();
MS_LOG(WARNING)
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
output_node = output_cnode->input(1);
}
// Handle morphism from output.
(void)MapMorphism(primal_graph_->output());
(void)MapMorphism(output_node);
// Construct K for primal_graph_
auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
// Construct K for primal_graph_.
auto output_adjoint = anfnode_to_adjoin_.find(output_node);
// Attach dout_ parameter to output_adjoint.
output_adjoint->second->AccumulateDout(dout_);
@ -612,7 +632,9 @@ void DFunctor::MapValueObject() {
AdjointPtr adjoint = nullptr;
if (IsValueNode<Primitive>(node)) { // Primitive.
if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
auto prim = GetValueNode<PrimitivePtr>(node);
if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn ||
(prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name())) {
continue;
}
MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << ".";

View File

@ -63,6 +63,10 @@ class SpecialOpEliminater : public OptimizerCaller {
for (auto &eliminater : eliminaters_) {
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
MS_LOG(WARNING)
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
}
return new_node;
}
}

View File

@ -2374,6 +2374,10 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
(void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));
py::object code_obj = py::getattr(bprop_func, "__code__");
py::object co_name = py::getattr(code_obj, "co_name");
if (std::string(py::str(co_name)) == "staging_specialize") {
MS_LOG(EXCEPTION) << "Decorating bprop with '@ms_function' is not supported.";
}
// Three parameters self, out and dout need to be excluded
const size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
if (inputs_num > args.size()) {

View File

@ -219,6 +219,12 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
auto iter = hook_grad_.find(cell_id);
if (iter != hook_grad_.end()) {
py::object code_obj = py::getattr(hook_, "__code__");
py::object co_name = py::getattr(code_obj, "co_name");
if (std::string(py::str(co_name)) == "staging_specialize") {
MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
}
py::tuple convert_args(input_param_nums - 1);
py::tuple input_args(input_param_nums - 1);
input_args[0] = iter->second;
@ -243,6 +249,12 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
}
BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const {
py::object code_obj = py::getattr(hook_, "__code__");
py::object co_name = py::getattr(code_obj, "co_name");
if (std::string(py::str(co_name)) == "staging_specialize") {
MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
}
constexpr size_t grad_output_index = 2;
SyncData(py_args[grad_output_index]);
py::object obj = hook_(py::make_tuple(py_args[grad_output_index]));

View File

@ -0,0 +1,157 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore.ops import composite as C
from mindspore import context, Tensor
from mindspore.common.api import ms_function
grad_all = C.GradOperation(get_all=True)
def var_hook_function(grad_out):
print("grad:", grad_out)
class GraphVarHook(nn.Cell):
def __init__(self):
super(GraphVarHook, self).__init__()
self.relu = nn.ReLU()
self.hook = P.HookBackward(var_hook_function)
def construct(self, x):
x = x + x
x = x * x
x = self.hook(x)
x = self.relu(x)
return x
class MsFuncVarHook(nn.Cell):
def __init__(self):
super(MsFuncVarHook, self).__init__()
self.relu = nn.ReLU()
self.hook = P.HookBackward(var_hook_function)
@ms_function
def construct(self, x):
x = x + x
x = x * x
x = self.hook(x)
x = self.relu(x)
return x
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_var_hook_forward():
input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
context.set_context(mode=context.PYNATIVE_MODE)
net1 = MsFuncVarHook()
out1 = net1(input_x)
context.set_context(mode=context.GRAPH_MODE)
net2 = GraphVarHook()
out2 = net2(input_x)
assert np.allclose(out1.asnumpy(), out2.asnumpy(), 0.00001, 0.00001)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_var_hook_grad():
input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
context.set_context(mode=context.PYNATIVE_MODE)
net1 = MsFuncVarHook()
grad_out1 = grad_all(net1)(input_x)
context.set_context(mode=context.GRAPH_MODE)
net2 = GraphVarHook()
grad_out2 = grad_all(net2)(input_x)
assert np.allclose(grad_out1[0].asnumpy(), grad_out2[0].asnumpy(), 0.00001, 0.00001)
def cell_hook_function(cell_id, grad_input, grad_output):
print("cell id:", cell_id)
print("grad input:", grad_input)
print("grad output:", grad_output)
class GraphCellHook(nn.Cell):
def __init__(self):
super(GraphCellHook, self).__init__()
self.relu = nn.ReLU()
self.relu.register_backward_hook(cell_hook_function)
def construct(self, x):
x = x + x
x = x * x
x = self.relu(x)
return x
class MsFuncCellHook(nn.Cell):
def __init__(self):
super(MsFuncCellHook, self).__init__()
self.relu = nn.ReLU()
self.relu.register_backward_hook(cell_hook_function)
@ms_function
def construct(self, x):
x = x + x
x = x * x
x = self.relu(x)
return x
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cell_hook_forward():
input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
context.set_context(mode=context.PYNATIVE_MODE)
net1 = MsFuncCellHook()
out1 = net1(input_x)
context.set_context(mode=context.GRAPH_MODE)
net2 = GraphCellHook()
out2 = net2(input_x)
assert np.allclose(out1.asnumpy(), out2.asnumpy(), 0.00001, 0.00001)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cell_hook_grad():
input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
context.set_context(mode=context.PYNATIVE_MODE)
net1 = MsFuncCellHook()
grad_out1 = grad_all(net1)(input_x)
context.set_context(mode=context.GRAPH_MODE)
net2 = GraphCellHook()
grad_out2 = grad_all(net2)(input_x)
assert np.allclose(grad_out1[0].asnumpy(), grad_out2[0].asnumpy(), 0.00001, 0.00001)

View File

@ -0,0 +1,76 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.nn import Momentum
from mindspore import context, Tensor
from mindspore.common.api import ms_function
grad_all = C.GradOperation(get_all=True)
class CellBprop(nn.Cell):
def __init__(self):
super(CellBprop, self).__init__()
def construct(self, x, y):
return 2 * x * x + y * y
@ms_function
def bprop(self, x, y, out, dout):
return dout, 2 * y
def test_cell_bprop_grad():
input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
input_y = Tensor(np.random.randn(2, 2).astype(np.float32))
context.set_context(mode=context.PYNATIVE_MODE)
net = CellBprop()
with pytest.raises(RuntimeError):
grad_all(net)(input_x, input_y)
class ConvNet(nn.Cell):
def __init__(self):
super(ConvNet, self).__init__()
self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
def construct(self, x):
out = self.conv(x)
return out
class MomentumWithMsFunc(nn.Cell):
def __init__(self, net):
super(MomentumWithMsFunc, self).__init__()
self.net = net
self.optimizer = Momentum(filter(lambda x: x.requires_grad, self.net.get_parameters()), 0.1, 0.9)
@ms_function
def construct(self, grads):
ret = self.optimizer(grads)
return ret
def test_ms_func_decorate_forward():
context.set_context(mode=context.PYNATIVE_MODE)
input_x = Tensor(np.random.randn(1, 1, 2, 2).astype(np.float32))
net = ConvNet()
grad_out = grad_all(net)(input_x)
opt = MomentumWithMsFunc(net)
opt(grad_out)