!23099 Correct wrong info when using ms_function to decorate bprop

Merge pull request !23099 from JoyLvliang/correct_wrong_info_when_using_ms_function_with_bprop
i-robot 2021-09-16 08:46:41 +00:00 committed by Gitee
commit a5b793463a
7 changed files with 279 additions and 7 deletions
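For context, here is a minimal sketch of the pattern this change addresses, adapted from the test added in this merge request. After this change, taking the gradient of a Cell whose custom bprop is decorated with @ms_function fails fast with a clear RuntimeError in PyNative mode instead of reporting wrong information:

import numpy as np
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore import context, Tensor
from mindspore.common.api import ms_function

grad_all = C.GradOperation(get_all=True)


class CellBprop(nn.Cell):
    def construct(self, x, y):
        return 2 * x * x + y * y

    @ms_function  # decorating a custom bprop with ms_function is the unsupported pattern
    def bprop(self, x, y, out, dout):
        return dout, 2 * y


context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.random.randn(2, 2).astype(np.float32))
y = Tensor(np.random.randn(2, 2).astype(np.float32))
try:
    grad_all(CellBprop())(x, y)
except RuntimeError as e:
    # Expected after this change: "Decorating bprop with '@ms_function' is not supported."
    print(e)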

View File

@@ -27,7 +27,6 @@ from textwrap import dedent
 import asttokens
 from mindspore import Tensor
-from mindspore import context
 from mindspore import log as logger
 from mindspore import nn
 from mindspore import ops
@@ -105,8 +104,6 @@ def get_parse_method_of_class(obj, parse_method=None):
         method_name = parse_method
     elif isinstance(obj, nn.Cell):
         if obj.enable_hook:
-            if context.get_context("mode") == context.GRAPH_MODE:
-                raise ValueError("The graph mode does not support hook function.")
             method_name = "_hook_construct"
         else:
             method_name = "construct"
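The two deleted lines above removed the hard error for hooks in graph mode. A rough sketch of the user-visible change (assuming the Cell-level backward hook is what sets enable_hook, as in the tests added below): previously the call raised ValueError while parsing the hooked Cell, now the network compiles and runs, and the hook op is eliminated during compilation with a warning.

import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor


def cell_hook_function(cell_id, grad_input, grad_output):
    print("cell id:", cell_id)


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.relu = nn.ReLU()
        self.relu.register_backward_hook(cell_hook_function)  # enables the hook path on self.relu

    def construct(self, x):
        return self.relu(x + x)


context.set_context(mode=context.GRAPH_MODE)
out = Net()(Tensor(np.ones((2, 2)).astype(np.float32)))
# Before: ValueError("The graph mode does not support hook function.")
# After: the forward result is returned; the hook op is dropped during compilation with a warning.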

View File

@@ -210,6 +210,11 @@ void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app,
   for (size_t i = 0; i < cnode_morph->size(); i++) {
     auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i))});
     auto input = cnode_morph->input(i);
+    // Skip HookBackward op
+    if (IsPrimitiveCNode(input, prim::kPrimHookBackward)) {
+      auto inp_i = input->cast<CNodePtr>();
+      input = inp_i->input(1);
+    }
     // Backprop sens wrt fvs.
     if (IsValueNode<FuncGraph>(input)) {
       auto func_graph = GetValueNode<FuncGraphPtr>(input);
@@ -257,6 +262,13 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   std::vector<AdjointPtr> param_adjoints;
   for (size_t i = 0; i < cnode_morph->size(); i++) {
     auto node = cnode_morph->input(i);
+    // Skip HookBackward op
+    if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
+      auto input_i = node->cast<CNodePtr>();
+      MS_LOG(WARNING)
+        << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
+      node = input_i->input(1);
+    }
     AdjointPtr node_adjoint = nullptr;
     auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
     if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
@@ -417,11 +429,19 @@ void DFunctor::MapMorphism() {
   // Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
   MapFreeMorphism();
+  // Skip HookBackward when it is the output node.
+  auto output_node = primal_graph_->output();
+  if (IsPrimitiveCNode(output_node, prim::kPrimHookBackward)) {
+    auto output_cnode = output_node->cast<CNodePtr>();
+    MS_LOG(WARNING)
+      << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
+    output_node = output_cnode->input(1);
+  }
   // Handle morphism from output.
-  (void)MapMorphism(primal_graph_->output());
-  // Construct K for primal_graph_
-  auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
+  (void)MapMorphism(output_node);
+  // Construct K for primal_graph_.
+  auto output_adjoint = anfnode_to_adjoin_.find(output_node);
   // Attach dout_ parameter to output_adjoint.
   output_adjoint->second->AccumulateDout(dout_);
@@ -612,7 +632,9 @@ void DFunctor::MapValueObject() {
     AdjointPtr adjoint = nullptr;
     if (IsValueNode<Primitive>(node)) {  // Primitive.
-      if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
+      auto prim = GetValueNode<PrimitivePtr>(node);
+      if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn ||
+          (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name())) {
         continue;
       }
       MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << ".";
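With the hunks above, automatic differentiation in graph mode now treats a HookBackward node as transparent: when mapping morphisms and when backpropagating, the functor steps through to the hook's real input and logs a warning that the hook will be eliminated. A sketch of what this means for users (mirroring the variable-hook tests added below; names follow those tests):

import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore.ops import composite as C
from mindspore import context, Tensor

grad_all = C.GradOperation(get_all=True)


def var_hook_function(grad_out):
    # Called during PyNative backprop; in graph mode the HookBackward node is
    # eliminated, so this function never runs.
    print("grad:", grad_out)


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.relu = nn.ReLU()
        self.hook = P.HookBackward(var_hook_function)

    def construct(self, x):
        x = x * x
        x = self.hook(x)  # skipped by DFunctor; gradients flow straight through
        return self.relu(x)


context.set_context(mode=context.GRAPH_MODE)
grads = grad_all(Net())(Tensor(np.ones((2, 2)).astype(np.float32)))
# grads match the same network without the hook; a hook-elimination warning is logged.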

View File

@@ -63,6 +63,10 @@ class SpecialOpEliminater : public OptimizerCaller {
     for (auto &eliminater : eliminaters_) {
       new_node = (*eliminater)(optimizer, node);
       if (new_node != nullptr) {
+        if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
+          MS_LOG(WARNING)
+            << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
+        }
         return new_node;
       }
     }
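The elimination of HookBackward as a special op is pre-existing; the new branch only adds the warning so users learn why a hook never fires under graph compilation. A short sketch of the forward-only case that triggers it (the same pattern as the MsFuncVarHook test added below, assuming PyNative mode with an ms_function-decorated construct):

import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context, Tensor
from mindspore.common.api import ms_function


def var_hook_function(grad_out):
    print("grad:", grad_out)


class MsFuncVarHook(nn.Cell):
    def __init__(self):
        super(MsFuncVarHook, self).__init__()
        self.relu = nn.ReLU()
        self.hook = P.HookBackward(var_hook_function)

    @ms_function  # the construct is compiled as a graph, so the hook op is eliminated
    def construct(self, x):
        return self.relu(self.hook(x * x))


context.set_context(mode=context.PYNATIVE_MODE)
out = MsFuncVarHook()(Tensor(np.ones((2, 2)).astype(np.float32)))
# The forward result is unaffected; compilation logs the warning added above.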

View File

@@ -2374,6 +2374,10 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
   (void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));
   py::object code_obj = py::getattr(bprop_func, "__code__");
+  py::object co_name = py::getattr(code_obj, "co_name");
+  if (std::string(py::str(co_name)) == "staging_specialize") {
+    MS_LOG(EXCEPTION) << "Decorating bprop with '@ms_function' is not supported.";
+  }
   // Three parameters self, out and dout need to be excluded
   const size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
   if (inputs_num > args.size()) {
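How the new guard recognizes the decorator: @ms_function replaces the user's function with an internal wrapper, so the __code__.co_name seen from C++ is the wrapper's name rather than the user's. A small illustration of the idea (the wrapper name "staging_specialize" is taken from the check above, not from separate documentation):

from mindspore.common.api import ms_function


def bprop(self, x, y, out, dout):
    return dout, 2 * y


wrapped = ms_function(bprop)
print(bprop.__code__.co_name)    # "bprop"
print(wrapped.__code__.co_name)  # the wrapper's name, which the guard above matches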

View File

@@ -219,6 +219,12 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
   auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
   auto iter = hook_grad_.find(cell_id);
   if (iter != hook_grad_.end()) {
+    py::object code_obj = py::getattr(hook_, "__code__");
+    py::object co_name = py::getattr(code_obj, "co_name");
+    if (std::string(py::str(co_name)) == "staging_specialize") {
+      MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
+    }
     py::tuple convert_args(input_param_nums - 1);
     py::tuple input_args(input_param_nums - 1);
     input_args[0] = iter->second;
@@ -243,6 +249,12 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
 }
 BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const {
+  py::object code_obj = py::getattr(hook_, "__code__");
+  py::object co_name = py::getattr(code_obj, "co_name");
+  if (std::string(py::str(co_name)) == "staging_specialize") {
+    MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
+  }
   constexpr size_t grad_output_index = 2;
   SyncData(py_args[grad_output_index]);
   py::object obj = hook_(py::make_tuple(py_args[grad_output_index]));
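RunCellHookFunction and RunVariableHookFunction get the same guard, so a hook function decorated with @ms_function is rejected when PyNative backprop tries to run it. A sketch of the rejected case (assuming the exception surfaces as RuntimeError on the Python side, as in the bprop test added below):

import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore.ops import composite as C
from mindspore import context, Tensor
from mindspore.common.api import ms_function

grad_all = C.GradOperation(get_all=True)


@ms_function  # decorating the hook function itself is the rejected pattern
def var_hook_function(grad_out):
    print("grad:", grad_out)


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.hook = P.HookBackward(var_hook_function)

    def construct(self, x):
        return self.hook(x * x)


context.set_context(mode=context.PYNATIVE_MODE)
try:
    grad_all(Net())(Tensor(np.ones((2, 2)).astype(np.float32)))
except RuntimeError as e:
    # Expected: "Decorating hook function with '@ms_function' is not supported."
    print(e)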

View File

@@ -0,0 +1,157 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore.ops import composite as C
from mindspore import context, Tensor
from mindspore.common.api import ms_function

grad_all = C.GradOperation(get_all=True)


def var_hook_function(grad_out):
    print("grad:", grad_out)


class GraphVarHook(nn.Cell):
    def __init__(self):
        super(GraphVarHook, self).__init__()
        self.relu = nn.ReLU()
        self.hook = P.HookBackward(var_hook_function)

    def construct(self, x):
        x = x + x
        x = x * x
        x = self.hook(x)
        x = self.relu(x)
        return x


class MsFuncVarHook(nn.Cell):
    def __init__(self):
        super(MsFuncVarHook, self).__init__()
        self.relu = nn.ReLU()
        self.hook = P.HookBackward(var_hook_function)

    @ms_function
    def construct(self, x):
        x = x + x
        x = x * x
        x = self.hook(x)
        x = self.relu(x)
        return x


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_var_hook_forward():
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net1 = MsFuncVarHook()
    out1 = net1(input_x)
    context.set_context(mode=context.GRAPH_MODE)
    net2 = GraphVarHook()
    out2 = net2(input_x)
    assert np.allclose(out1.asnumpy(), out2.asnumpy(), 0.00001, 0.00001)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_var_hook_grad():
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net1 = MsFuncVarHook()
    grad_out1 = grad_all(net1)(input_x)
    context.set_context(mode=context.GRAPH_MODE)
    net2 = GraphVarHook()
    grad_out2 = grad_all(net2)(input_x)
    assert np.allclose(grad_out1[0].asnumpy(), grad_out2[0].asnumpy(), 0.00001, 0.00001)


def cell_hook_function(cell_id, grad_input, grad_output):
    print("cell id:", cell_id)
    print("grad input:", grad_input)
    print("grad output:", grad_output)


class GraphCellHook(nn.Cell):
    def __init__(self):
        super(GraphCellHook, self).__init__()
        self.relu = nn.ReLU()
        self.relu.register_backward_hook(cell_hook_function)

    def construct(self, x):
        x = x + x
        x = x * x
        x = self.relu(x)
        return x


class MsFuncCellHook(nn.Cell):
    def __init__(self):
        super(MsFuncCellHook, self).__init__()
        self.relu = nn.ReLU()
        self.relu.register_backward_hook(cell_hook_function)

    @ms_function
    def construct(self, x):
        x = x + x
        x = x * x
        x = self.relu(x)
        return x


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cell_hook_forward():
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net1 = MsFuncCellHook()
    out1 = net1(input_x)
    context.set_context(mode=context.GRAPH_MODE)
    net2 = GraphCellHook()
    out2 = net2(input_x)
    assert np.allclose(out1.asnumpy(), out2.asnumpy(), 0.00001, 0.00001)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cell_hook_grad():
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net1 = MsFuncCellHook()
    grad_out1 = grad_all(net1)(input_x)
    context.set_context(mode=context.GRAPH_MODE)
    net2 = GraphCellHook()
    grad_out2 = grad_all(net2)(input_x)
    assert np.allclose(grad_out1[0].asnumpy(), grad_out2[0].asnumpy(), 0.00001, 0.00001)

View File

@@ -0,0 +1,76 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.nn import Momentum
from mindspore import context, Tensor
from mindspore.common.api import ms_function

grad_all = C.GradOperation(get_all=True)


class CellBprop(nn.Cell):
    def __init__(self):
        super(CellBprop, self).__init__()

    def construct(self, x, y):
        return 2 * x * x + y * y

    @ms_function
    def bprop(self, x, y, out, dout):
        return dout, 2 * y


def test_cell_bprop_grad():
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    input_y = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net = CellBprop()
    with pytest.raises(RuntimeError):
        grad_all(net)(input_x, input_y)


class ConvNet(nn.Cell):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")

    def construct(self, x):
        out = self.conv(x)
        return out


class MomentumWithMsFunc(nn.Cell):
    def __init__(self, net):
        super(MomentumWithMsFunc, self).__init__()
        self.net = net
        self.optimizer = Momentum(filter(lambda x: x.requires_grad, self.net.get_parameters()), 0.1, 0.9)

    @ms_function
    def construct(self, grads):
        ret = self.optimizer(grads)
        return ret


def test_ms_func_decorate_forward():
    context.set_context(mode=context.PYNATIVE_MODE)
    input_x = Tensor(np.random.randn(1, 1, 2, 2).astype(np.float32))
    net = ConvNet()
    grad_out = grad_all(net)(input_x)
    opt = MomentumWithMsFunc(net)
    opt(grad_out)