forked from mindspore-Ecosystem/mindspore
correct_wrong_info_when_using_ms_function_with_bprop
This commit is contained in:
parent
0fb2337d2e
commit
6e59598f99
|
@ -27,7 +27,6 @@ from textwrap import dedent
|
|||
import asttokens
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore import nn
|
||||
from mindspore import ops
|
||||
|
@ -105,8 +104,6 @@ def get_parse_method_of_class(obj, parse_method=None):
|
|||
method_name = parse_method
|
||||
elif isinstance(obj, nn.Cell):
|
||||
if obj.enable_hook:
|
||||
if context.get_context("mode") == context.GRAPH_MODE:
|
||||
raise ValueError("The graph mode does not support hook function.")
|
||||
method_name = "_hook_construct"
|
||||
else:
|
||||
method_name = "construct"
|
||||
|
|
|
@ -210,6 +210,11 @@ void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app,
|
|||
for (size_t i = 0; i < cnode_morph->size(); i++) {
|
||||
auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i))});
|
||||
auto input = cnode_morph->input(i);
|
||||
// Skip HookBackward op
|
||||
if (IsPrimitiveCNode(input, prim::kPrimHookBackward)) {
|
||||
auto inp_i = input->cast<CNodePtr>();
|
||||
input = inp_i->input(1);
|
||||
}
|
||||
// Backprop sens wrt fvs.
|
||||
if (IsValueNode<FuncGraph>(input)) {
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(input);
|
||||
|
@ -257,6 +262,13 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
|
|||
std::vector<AdjointPtr> param_adjoints;
|
||||
for (size_t i = 0; i < cnode_morph->size(); i++) {
|
||||
auto node = cnode_morph->input(i);
|
||||
// Skip HookBackward op
|
||||
if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
|
||||
auto input_i = node->cast<CNodePtr>();
|
||||
MS_LOG(WARNING)
|
||||
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
|
||||
node = input_i->input(1);
|
||||
}
|
||||
AdjointPtr node_adjoint = nullptr;
|
||||
auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
|
||||
if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
|
||||
|
@ -417,11 +429,19 @@ void DFunctor::MapMorphism() {
|
|||
|
||||
// Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
|
||||
MapFreeMorphism();
|
||||
// Skip HookBackward when it is the output node.
|
||||
auto output_node = primal_graph_->output();
|
||||
if (IsPrimitiveCNode(output_node, prim::kPrimHookBackward)) {
|
||||
auto output_cnode = output_node->cast<CNodePtr>();
|
||||
MS_LOG(WARNING)
|
||||
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
|
||||
output_node = output_cnode->input(1);
|
||||
}
|
||||
// Handle morphism from output.
|
||||
(void)MapMorphism(primal_graph_->output());
|
||||
(void)MapMorphism(output_node);
|
||||
|
||||
// Construct K for primal_graph_
|
||||
auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
|
||||
// Construct K for primal_graph_.
|
||||
auto output_adjoint = anfnode_to_adjoin_.find(output_node);
|
||||
// Attach dout_ parameter to output_adjoint.
|
||||
output_adjoint->second->AccumulateDout(dout_);
|
||||
|
||||
|
@ -612,7 +632,9 @@ void DFunctor::MapValueObject() {
|
|||
|
||||
AdjointPtr adjoint = nullptr;
|
||||
if (IsValueNode<Primitive>(node)) { // Primitive.
|
||||
if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(node);
|
||||
if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn ||
|
||||
(prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name())) {
|
||||
continue;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << ".";
|
||||
|
|
|
@ -63,6 +63,10 @@ class SpecialOpEliminater : public OptimizerCaller {
|
|||
for (auto &eliminater : eliminaters_) {
|
||||
new_node = (*eliminater)(optimizer, node);
|
||||
if (new_node != nullptr) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
|
||||
MS_LOG(WARNING)
|
||||
<< "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
|
||||
}
|
||||
return new_node;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2374,6 +2374,10 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
|
|||
(void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));
|
||||
|
||||
py::object code_obj = py::getattr(bprop_func, "__code__");
|
||||
py::object co_name = py::getattr(code_obj, "co_name");
|
||||
if (std::string(py::str(co_name)) == "staging_specialize") {
|
||||
MS_LOG(EXCEPTION) << "Decorating bprop with '@ms_function' is not supported.";
|
||||
}
|
||||
// Three parameters self, out and dout need to be excluded
|
||||
const size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
|
||||
if (inputs_num > args.size()) {
|
||||
|
|
|
@ -219,6 +219,12 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
|
|||
auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
|
||||
auto iter = hook_grad_.find(cell_id);
|
||||
if (iter != hook_grad_.end()) {
|
||||
py::object code_obj = py::getattr(hook_, "__code__");
|
||||
py::object co_name = py::getattr(code_obj, "co_name");
|
||||
if (std::string(py::str(co_name)) == "staging_specialize") {
|
||||
MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
|
||||
}
|
||||
|
||||
py::tuple convert_args(input_param_nums - 1);
|
||||
py::tuple input_args(input_param_nums - 1);
|
||||
input_args[0] = iter->second;
|
||||
|
@ -243,6 +249,12 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
|
|||
}
|
||||
|
||||
BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const {
|
||||
py::object code_obj = py::getattr(hook_, "__code__");
|
||||
py::object co_name = py::getattr(code_obj, "co_name");
|
||||
if (std::string(py::str(co_name)) == "staging_specialize") {
|
||||
MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
|
||||
}
|
||||
|
||||
constexpr size_t grad_output_index = 2;
|
||||
SyncData(py_args[grad_output_index]);
|
||||
py::object obj = hook_(py::make_tuple(py_args[grad_output_index]));
|
||||
|
|
|
@ -0,0 +1,157 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
|
||||
|
||||
def var_hook_function(grad_out):
|
||||
print("grad:", grad_out)
|
||||
|
||||
|
||||
class GraphVarHook(nn.Cell):
|
||||
def __init__(self):
|
||||
super(GraphVarHook, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.hook = P.HookBackward(var_hook_function)
|
||||
|
||||
def construct(self, x):
|
||||
x = x + x
|
||||
x = x * x
|
||||
x = self.hook(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class MsFuncVarHook(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MsFuncVarHook, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.hook = P.HookBackward(var_hook_function)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x):
|
||||
x = x + x
|
||||
x = x * x
|
||||
x = self.hook(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_var_hook_forward():
|
||||
input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net1 = MsFuncVarHook()
|
||||
out1 = net1(input_x)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net2 = GraphVarHook()
|
||||
out2 = net2(input_x)
|
||||
assert np.allclose(out1.asnumpy(), out2.asnumpy(), 0.00001, 0.00001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_var_hook_grad():
|
||||
input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net1 = MsFuncVarHook()
|
||||
grad_out1 = grad_all(net1)(input_x)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net2 = GraphVarHook()
|
||||
grad_out2 = grad_all(net2)(input_x)
|
||||
assert np.allclose(grad_out1[0].asnumpy(), grad_out2[0].asnumpy(), 0.00001, 0.00001)
|
||||
|
||||
|
||||
def cell_hook_function(cell_id, grad_input, grad_output):
|
||||
print("cell id:", cell_id)
|
||||
print("grad input:", grad_input)
|
||||
print("grad output:", grad_output)
|
||||
|
||||
|
||||
class GraphCellHook(nn.Cell):
|
||||
def __init__(self):
|
||||
super(GraphCellHook, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.relu.register_backward_hook(cell_hook_function)
|
||||
|
||||
def construct(self, x):
|
||||
x = x + x
|
||||
x = x * x
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class MsFuncCellHook(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MsFuncCellHook, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.relu.register_backward_hook(cell_hook_function)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x):
|
||||
x = x + x
|
||||
x = x * x
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_cell_hook_forward():
|
||||
input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net1 = MsFuncCellHook()
|
||||
out1 = net1(input_x)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net2 = GraphCellHook()
|
||||
out2 = net2(input_x)
|
||||
assert np.allclose(out1.asnumpy(), out2.asnumpy(), 0.00001, 0.00001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_cell_hook_grad():
|
||||
input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net1 = MsFuncCellHook()
|
||||
grad_out1 = grad_all(net1)(input_x)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net2 = GraphCellHook()
|
||||
grad_out2 = grad_all(net2)(input_x)
|
||||
assert np.allclose(grad_out1[0].asnumpy(), grad_out2[0].asnumpy(), 0.00001, 0.00001)
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.nn import Momentum
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
|
||||
|
||||
class CellBprop(nn.Cell):
|
||||
def __init__(self):
|
||||
super(CellBprop, self).__init__()
|
||||
|
||||
def construct(self, x, y):
|
||||
return 2 * x * x + y * y
|
||||
|
||||
@ms_function
|
||||
def bprop(self, x, y, out, dout):
|
||||
return dout, 2 * y
|
||||
|
||||
|
||||
def test_cell_bprop_grad():
|
||||
input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
input_y = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net = CellBprop()
|
||||
with pytest.raises(RuntimeError):
|
||||
grad_all(net)(input_x, input_y)
|
||||
|
||||
|
||||
class ConvNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ConvNet, self).__init__()
|
||||
self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv(x)
|
||||
return out
|
||||
|
||||
|
||||
class MomentumWithMsFunc(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(MomentumWithMsFunc, self).__init__()
|
||||
self.net = net
|
||||
self.optimizer = Momentum(filter(lambda x: x.requires_grad, self.net.get_parameters()), 0.1, 0.9)
|
||||
|
||||
@ms_function
|
||||
def construct(self, grads):
|
||||
ret = self.optimizer(grads)
|
||||
return ret
|
||||
|
||||
|
||||
def test_ms_func_decorate_forward():
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
input_x = Tensor(np.random.randn(1, 1, 2, 2).astype(np.float32))
|
||||
net = ConvNet()
|
||||
grad_out = grad_all(net)(input_x)
|
||||
opt = MomentumWithMsFunc(net)
|
||||
opt(grad_out)
|
Loading…
Reference in New Issue