forked from mindspore-Ecosystem/mindspore
add_st_cases_for_pynative_hook
This commit is contained in:
parent
5de0d89eba
commit
4a8b220879
|
@ -1540,6 +1540,9 @@ class Cell(Cell_):
|
||||||
if not isinstance(hook_fn, (FunctionType, MethodType)):
|
if not isinstance(hook_fn, (FunctionType, MethodType)):
|
||||||
raise TypeError(f"When using 'register_forward_pre_hook(hook_fn)', the type of 'hook_fn' should be python "
|
raise TypeError(f"When using 'register_forward_pre_hook(hook_fn)', the type of 'hook_fn' should be python "
|
||||||
f"function, but got {type(hook_fn)}.")
|
f"function, but got {type(hook_fn)}.")
|
||||||
|
if hook_fn.__code__.co_name == "staging_specialize":
|
||||||
|
raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@ms_function' is not supported.")
|
||||||
|
|
||||||
self._enable_forward_pre_hook = True
|
self._enable_forward_pre_hook = True
|
||||||
_pynative_executor.set_hook_changed(self)
|
_pynative_executor.set_hook_changed(self)
|
||||||
if not hasattr(self, '_forward_pre_hook_key'):
|
if not hasattr(self, '_forward_pre_hook_key'):
|
||||||
|
@ -1636,6 +1639,9 @@ class Cell(Cell_):
|
||||||
if not isinstance(hook_fn, (FunctionType, MethodType)):
|
if not isinstance(hook_fn, (FunctionType, MethodType)):
|
||||||
raise TypeError(f"When using 'register_forward_hook(hook_fn)', the type of 'hook_fn' should be python "
|
raise TypeError(f"When using 'register_forward_hook(hook_fn)', the type of 'hook_fn' should be python "
|
||||||
f"function, but got {type(hook_fn)}.")
|
f"function, but got {type(hook_fn)}.")
|
||||||
|
if hook_fn.__code__.co_name == "staging_specialize":
|
||||||
|
raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@ms_function' is not supported.")
|
||||||
|
|
||||||
self._enable_forward_hook = True
|
self._enable_forward_hook = True
|
||||||
_pynative_executor.set_hook_changed(self)
|
_pynative_executor.set_hook_changed(self)
|
||||||
if not hasattr(self, '_forward_hook_key'):
|
if not hasattr(self, '_forward_hook_key'):
|
||||||
|
|
|
@ -0,0 +1,238 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.ops import GradOperation
|
||||||
|
from mindspore.common import ParameterTuple
|
||||||
|
|
||||||
|
|
||||||
|
def forward_pre_hook_fn_add(cell_id, inp):
|
||||||
|
x = inp[0] + inp[0]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def forward_pre_hook_fn_mul(cell_id, inp):
|
||||||
|
x = inp[0] * inp[0]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def forward_hook_fn_relu(cell_id, inp, outp):
|
||||||
|
out = nn.ReLU()(outp)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def forward_hook_fn_add(cell_id, inp, outp):
|
||||||
|
out = outp + outp
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def backward_hook_fn(cell_id, grad_inp, grad_outp):
|
||||||
|
return Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32))
|
||||||
|
|
||||||
|
|
||||||
|
def backward_hook_fn2(cell_id, grad_inp, grad_outp):
|
||||||
|
return Tensor(np.ones([1]).astype(np.float32) * 2), Tensor(np.ones([1]).astype(np.float32) * 3)
|
||||||
|
|
||||||
|
|
||||||
|
def backward_hook_fn3(cell_id, grad_inp, grad_outp):
|
||||||
|
return Tensor(np.ones([1]).astype(np.float32) * 5), Tensor(np.ones([1]).astype(np.float32) * 6)
|
||||||
|
|
||||||
|
|
||||||
|
def backward_hook_fn4(cell_id, grad_inp, grad_outp):
|
||||||
|
return Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 10)
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.mul = nn.MatMul()
|
||||||
|
self.handle = self.mul.register_backward_hook(backward_hook_fn)
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
x = self.mul(x, y)
|
||||||
|
x = x + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SingleNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(SingleNet, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||||
|
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CmpNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CmpNet, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||||
|
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CmpNetPreHook(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CmpNetPreHook, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||||
|
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = x + x
|
||||||
|
x = x * x
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CmpNetFWHook(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CmpNetFWHook, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||||
|
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = x + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_pynative_backward_hook():
|
||||||
|
"""
|
||||||
|
Feature: PyNative hook function.
|
||||||
|
Description: Test PyNative backward hook function.
|
||||||
|
Expectation: The calculation result is correct.
|
||||||
|
"""
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
input_x = Tensor(np.ones([1]).astype(np.float32))
|
||||||
|
input_y = Tensor(np.ones([1]).astype(np.float32))
|
||||||
|
grad_op = GradOperation(get_all=True, get_by_list=False, sens_param=False)
|
||||||
|
# case 1: register hook function in __init__ function.
|
||||||
|
net = Net()
|
||||||
|
grad = grad_op(net)(input_x, input_y)
|
||||||
|
assert len(grad) == 2
|
||||||
|
assert np.allclose(grad[0].asnumpy(), input_x.asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1].asnumpy(), input_x.asnumpy(), 0.000001, 0.000001)
|
||||||
|
# case 2: remove hook function by handle.
|
||||||
|
net.handle.remove()
|
||||||
|
net.handle.remove()
|
||||||
|
grad = grad_op(net)(input_x, input_y)
|
||||||
|
assert len(grad) == 2
|
||||||
|
expect_grad = Tensor(np.ones([1]).astype(np.float32) * 2)
|
||||||
|
assert np.allclose(grad[0].asnumpy(), expect_grad.asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1].asnumpy(), expect_grad.asnumpy(), 0.000001, 0.000001)
|
||||||
|
# case 3: register hook function by handle
|
||||||
|
net = Net()
|
||||||
|
net.mul.register_backward_hook(backward_hook_fn2)
|
||||||
|
handle3 = net.mul.register_backward_hook(backward_hook_fn3)
|
||||||
|
grad = grad_op(net)(input_x, input_y)
|
||||||
|
assert len(grad) == 2
|
||||||
|
expect_gradx = Tensor(np.ones([1]).astype(np.float32) * 5)
|
||||||
|
expect_grady = Tensor(np.ones([1]).astype(np.float32) * 6)
|
||||||
|
assert np.allclose(grad[0].asnumpy(), expect_gradx.asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1].asnumpy(), expect_grady.asnumpy(), 0.000001, 0.000001)
|
||||||
|
# case 5: remove hook function by handle.
|
||||||
|
handle3.remove()
|
||||||
|
grad = grad_op(net)(input_x, input_y)
|
||||||
|
assert len(grad) == 2
|
||||||
|
expect_gradx = Tensor(np.ones([1]).astype(np.float32) * 2)
|
||||||
|
expect_grady = Tensor(np.ones([1]).astype(np.float32) * 3)
|
||||||
|
assert np.allclose(grad[0].asnumpy(), expect_gradx.asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1].asnumpy(), expect_grady.asnumpy(), 0.000001, 0.000001)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_pynative_hook_base_line():
|
||||||
|
"""
|
||||||
|
Feature: PyNative hook function.
|
||||||
|
Description: The base line case for PyNative hook function.
|
||||||
|
Expectation: The calculation result is correct.
|
||||||
|
"""
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
input_x = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 2)
|
||||||
|
grad_op = GradOperation(get_all=True, get_by_list=True, sens_param=False)
|
||||||
|
# register pre forward hook.
|
||||||
|
net = SingleNet()
|
||||||
|
handle1 = net.conv.register_forward_pre_hook(forward_pre_hook_fn_add)
|
||||||
|
handle2 = net.conv.register_forward_pre_hook(forward_pre_hook_fn_mul)
|
||||||
|
out = net(input_x)
|
||||||
|
cmp_net_pre_hook = CmpNetPreHook()
|
||||||
|
expect_out = cmp_net_pre_hook(input_x)
|
||||||
|
assert np.allclose(out.asnumpy(), expect_out.asnumpy(), 0.000001, 0.000001)
|
||||||
|
grad = grad_op(net, ParameterTuple(net.trainable_params()))(input_x)
|
||||||
|
expect_grad = grad_op(cmp_net_pre_hook, ParameterTuple(cmp_net_pre_hook.trainable_params()))(input_x)
|
||||||
|
assert len(grad) == len(expect_grad)
|
||||||
|
assert np.allclose(grad[0][0].asnumpy(), expect_grad[0][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][1].asnumpy(), expect_grad[1][1].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][2].asnumpy(), expect_grad[1][2].asnumpy(), 0.000001, 0.000001)
|
||||||
|
# register forward hook.
|
||||||
|
handle1.remove()
|
||||||
|
handle2.remove()
|
||||||
|
handlea = net.bn.register_forward_hook(forward_hook_fn_relu)
|
||||||
|
handleb = net.bn.register_forward_hook(forward_hook_fn_add)
|
||||||
|
out = net(input_x)
|
||||||
|
cmp_net_fw_hook = CmpNetFWHook()
|
||||||
|
expect_out = cmp_net_fw_hook(input_x)
|
||||||
|
assert np.allclose(out.asnumpy(), expect_out.asnumpy(), 0.000001, 0.000001)
|
||||||
|
grad = grad_op(net, ParameterTuple(net.trainable_params()))(input_x)
|
||||||
|
expect_grad = grad_op(cmp_net_fw_hook, ParameterTuple(cmp_net_fw_hook.trainable_params()))(input_x)
|
||||||
|
assert len(grad) == len(expect_grad)
|
||||||
|
assert np.allclose(grad[0][0].asnumpy(), expect_grad[0][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][1].asnumpy(), expect_grad[1][1].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][2].asnumpy(), expect_grad[1][2].asnumpy(), 0.000001, 0.000001)
|
||||||
|
# register backward hook.
|
||||||
|
handlea.remove()
|
||||||
|
handleb.remove()
|
||||||
|
net.conv.register_backward_hook(backward_hook_fn4)
|
||||||
|
out = net(input_x)
|
||||||
|
compare_net = CmpNet()
|
||||||
|
expect_out = compare_net(input_x)
|
||||||
|
assert np.allclose(out.asnumpy(), expect_out.asnumpy(), 0.000001, 0.000001)
|
||||||
|
grad = grad_op(net, ParameterTuple(net.trainable_params()))(input_x)
|
||||||
|
expect_grad = grad_op(compare_net, ParameterTuple(compare_net.trainable_params()))(input_x)
|
||||||
|
assert len(grad) == len(expect_grad)
|
||||||
|
expect_gradx = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 10)
|
||||||
|
assert np.allclose(grad[0][0].asnumpy(), expect_gradx.asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][1].asnumpy(), expect_grad[1][1].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][2].asnumpy(), expect_grad[1][2].asnumpy(), 0.000001, 0.000001)
|
|
@ -0,0 +1,495 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.ops import GradOperation
|
||||||
|
from mindspore.common import ParameterTuple
|
||||||
|
from mindspore.common.api import ms_function
|
||||||
|
|
||||||
|
|
||||||
|
def forward_pre_hook_fn_bn(cell_id, inp):
|
||||||
|
out = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")(inp[0])
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def forward_pre_hook_fn_add(cell_id, inp):
|
||||||
|
out = inp[0] + inp[0]
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def forward_pre_hook_fn_mul(cell_id, inp):
|
||||||
|
out = inp[0] * inp[0]
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def forward_pre_hook_fn_multi_relu(cell_id, inp):
|
||||||
|
out = nn.ReLU()(inp[0])
|
||||||
|
return out, inp[1]
|
||||||
|
|
||||||
|
|
||||||
|
def forward_pre_hook_fn_multi_add(cell_id, inp):
|
||||||
|
x = inp[0] + inp[1]
|
||||||
|
y = inp[0] * inp[1]
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
def forward_hook_fn_conv(cell_id, inp, outp):
|
||||||
|
out = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")(outp)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def forward_hook_fn_add(cell_id, inp, outp):
|
||||||
|
out = outp + outp
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def forward_hook_fn_mul(cell_id, inp, outp):
|
||||||
|
out = outp * outp
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def forward_hook_fn_with_ms_func(cell_id, inp, outp):
|
||||||
|
return outp
|
||||||
|
|
||||||
|
|
||||||
|
def backward_hook_fn(cell_id, grad_inp, grad_outp):
|
||||||
|
print("Enter backward hook function.")
|
||||||
|
return grad_outp[0]
|
||||||
|
|
||||||
|
|
||||||
|
def backward_hook_fn_inner(cell_id, grad_inp, grad_outp):
|
||||||
|
print("Enter backward hook function inner.")
|
||||||
|
return grad_outp[0]
|
||||||
|
|
||||||
|
|
||||||
|
class SingleNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(SingleNet, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.handle1 = self.conv.register_forward_hook(forward_hook_fn_add)
|
||||||
|
self.handle2 = self.conv.register_forward_pre_hook(forward_pre_hook_fn_add)
|
||||||
|
self.handle3 = self.relu.register_forward_hook(forward_hook_fn_add)
|
||||||
|
self.handle4 = self.relu.register_forward_pre_hook(forward_pre_hook_fn_bn)
|
||||||
|
self.handle1.remove()
|
||||||
|
self.handle1.remove()
|
||||||
|
self.handle2.remove()
|
||||||
|
self.handle2.remove()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SingleNetInConstruct(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(SingleNetInConstruct, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
self.handle1 = self.conv.register_forward_hook(forward_hook_fn_add)
|
||||||
|
self.handle2 = self.conv.register_forward_pre_hook(forward_pre_hook_fn_add)
|
||||||
|
self.handle3 = self.relu.register_forward_hook(forward_hook_fn_add)
|
||||||
|
self.handle4 = self.relu.register_forward_pre_hook(forward_pre_hook_fn_bn)
|
||||||
|
self.handle1.remove()
|
||||||
|
self.handle1.remove()
|
||||||
|
self.handle2.remove()
|
||||||
|
self.handle2.remove()
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SingleNetMsFuncInner(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(SingleNetMsFuncInner, self).__init__()
|
||||||
|
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
|
||||||
|
self.bn.register_forward_pre_hook(forward_pre_hook_fn_add)
|
||||||
|
self.bn.register_forward_pre_hook(forward_pre_hook_fn_mul)
|
||||||
|
self.bn.register_forward_hook(forward_hook_fn_add)
|
||||||
|
self.bn.register_forward_hook(forward_hook_fn_mul)
|
||||||
|
self.bn.register_backward_hook(backward_hook_fn_inner)
|
||||||
|
self.bn.register_backward_hook(backward_hook_fn_inner)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.bn(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SingleNetMsFunc(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(SingleNetMsFunc, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||||
|
self.inner = SingleNetMsFuncInner()
|
||||||
|
self.inner.register_forward_pre_hook(forward_pre_hook_fn_add)
|
||||||
|
self.inner.register_forward_hook(forward_hook_fn_add)
|
||||||
|
self.inner.register_forward_hook(forward_hook_fn_mul)
|
||||||
|
self.inner.register_backward_hook(backward_hook_fn)
|
||||||
|
self.inner.register_backward_hook(backward_hook_fn)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.inner(x)
|
||||||
|
x = x + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CompareSingleNet1(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CompareSingleNet1, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||||
|
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = x + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CompareSingleNet2(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CompareSingleNet2, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||||
|
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = x + x
|
||||||
|
x = x * x
|
||||||
|
x = self.conv(x)
|
||||||
|
x = x + x
|
||||||
|
x = self.bn(x)
|
||||||
|
x = x + x
|
||||||
|
x = self.relu(x)
|
||||||
|
x = x + x
|
||||||
|
x = x * x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CompareSingleNet3(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CompareSingleNet3, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||||
|
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = x * x
|
||||||
|
x = self.conv(x)
|
||||||
|
x = x + x
|
||||||
|
x = x + x
|
||||||
|
x = self.relu(x)
|
||||||
|
x = x + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CompareSingleNet4(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CompareSingleNet4, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CompareSingleNet5(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CompareSingleNet5, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||||
|
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = x + x
|
||||||
|
x = self.bn(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = x + x
|
||||||
|
x = x * x
|
||||||
|
x = x + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MultiNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(MultiNet, self).__init__()
|
||||||
|
self.mul1 = nn.MatMul()
|
||||||
|
self.handle1 = self.mul1.register_forward_pre_hook(forward_pre_hook_fn_multi_add)
|
||||||
|
self.handle2 = self.mul1.register_forward_hook(forward_hook_fn_conv)
|
||||||
|
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
|
||||||
|
self.mul2 = nn.MatMul()
|
||||||
|
self.handle3 = self.mul2.register_forward_pre_hook(forward_pre_hook_fn_multi_relu)
|
||||||
|
self.handle4 = self.mul2.register_forward_hook(forward_hook_fn_add)
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
x = self.mul1(x, y)
|
||||||
|
x = self.bn(x)
|
||||||
|
x = self.mul2(x, x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CompareMultiNet1(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CompareMultiNet1, self).__init__()
|
||||||
|
self.mul = nn.MatMul()
|
||||||
|
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||||
|
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
x = self.mul(x + x, x * y)
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.mul(x, x)
|
||||||
|
x = x + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CompareMultiNet2(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CompareMultiNet2, self).__init__()
|
||||||
|
self.mul = nn.MatMul()
|
||||||
|
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||||
|
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
x = self.mul(x, y)
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
x = self.mul(x, x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_pynative_forward_hook():
|
||||||
|
"""
|
||||||
|
Feature: PyNative hook function.
|
||||||
|
Description: Test PyNative forward hook function and forward pre hook function with single input.
|
||||||
|
Expectation: The calculation result is correct.
|
||||||
|
"""
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
inputs = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 3)
|
||||||
|
grad_op = GradOperation(get_all=True, get_by_list=True, sens_param=False)
|
||||||
|
# case 1: calling remove() of handle to remove some hook function.
|
||||||
|
net = SingleNet()
|
||||||
|
out = net(inputs)
|
||||||
|
compare_single_net1 = CompareSingleNet1()
|
||||||
|
expect_out = compare_single_net1(inputs)
|
||||||
|
assert np.allclose(out.asnumpy(), expect_out.asnumpy(), 0.000001, 0.000001)
|
||||||
|
grad = grad_op(net, ParameterTuple(net.trainable_params()))(inputs)
|
||||||
|
expect_grad = grad_op(compare_single_net1, ParameterTuple(compare_single_net1.trainable_params()))(inputs)
|
||||||
|
assert len(grad) == len(expect_grad)
|
||||||
|
assert np.allclose(grad[0][0].asnumpy(), expect_grad[0][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
# case 2: register new hook function.
|
||||||
|
handle1 = net.conv.register_forward_pre_hook(forward_pre_hook_fn_add)
|
||||||
|
net.conv.register_forward_pre_hook(forward_pre_hook_fn_mul)
|
||||||
|
net.conv.register_forward_hook(forward_hook_fn_add)
|
||||||
|
net.relu.register_forward_pre_hook(forward_pre_hook_fn_add)
|
||||||
|
handle2 = net.relu.register_forward_hook(forward_hook_fn_mul)
|
||||||
|
out = net(inputs)
|
||||||
|
compare_single_net2 = CompareSingleNet2()
|
||||||
|
expect_out = compare_single_net2(inputs)
|
||||||
|
assert np.allclose(out.asnumpy(), expect_out.asnumpy(), 0.000001, 0.000001)
|
||||||
|
grad = grad_op(net, ParameterTuple(net.trainable_params()))(inputs)
|
||||||
|
expect_grad = grad_op(compare_single_net2, ParameterTuple(compare_single_net2.trainable_params()))(inputs)
|
||||||
|
assert len(grad) == len(expect_grad)
|
||||||
|
assert np.allclose(grad[0][0].asnumpy(), expect_grad[0][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
# case 3: remove some hook function.
|
||||||
|
handle1.remove()
|
||||||
|
net.handle4.remove()
|
||||||
|
handle2.remove()
|
||||||
|
out = net(inputs)
|
||||||
|
compare_single_net3 = CompareSingleNet3()
|
||||||
|
expect_out = compare_single_net3(inputs)
|
||||||
|
assert np.allclose(out.asnumpy(), expect_out.asnumpy(), 0.000001, 0.000001)
|
||||||
|
grad = grad_op(net, ParameterTuple(net.trainable_params()))(inputs)
|
||||||
|
expect_grad = grad_op(compare_single_net3, ParameterTuple(compare_single_net3.trainable_params()))(inputs)
|
||||||
|
assert len(grad) == len(expect_grad)
|
||||||
|
assert np.allclose(grad[0][0].asnumpy(), expect_grad[0][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_pynative_forward_hook_multi_inp():
|
||||||
|
"""
|
||||||
|
Feature: PyNative hook function.
|
||||||
|
Description: Test PyNative forward hook function and forward pre hook function with multi input.
|
||||||
|
Expectation: The calculation result is correct.
|
||||||
|
"""
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
inputs = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 3)
|
||||||
|
grad_op = GradOperation(get_all=True, get_by_list=True, sens_param=False)
|
||||||
|
# case 1: register hook function for multi-input op.
|
||||||
|
multi_net = MultiNet()
|
||||||
|
out = multi_net(inputs, inputs)
|
||||||
|
compare_multi_net1 = CompareMultiNet1()
|
||||||
|
expect_out = compare_multi_net1(inputs, inputs)
|
||||||
|
assert np.allclose(out.asnumpy(), expect_out.asnumpy(), 0.000001, 0.000001)
|
||||||
|
grad = grad_op(multi_net, ParameterTuple(multi_net.trainable_params()))(inputs, inputs)
|
||||||
|
expect_grad = grad_op(compare_multi_net1, ParameterTuple(compare_multi_net1.trainable_params()))(inputs, inputs)
|
||||||
|
assert len(grad) == len(expect_grad)
|
||||||
|
assert np.allclose(grad[0][0].asnumpy(), expect_grad[0][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[0][1].asnumpy(), expect_grad[0][1].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][1].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][1].asnumpy(), expect_grad[1][2].asnumpy(), 0.000001, 0.000001)
|
||||||
|
# case 2: remove some hook function for multi-input op.
|
||||||
|
multi_net.handle1.remove()
|
||||||
|
multi_net.handle3.remove()
|
||||||
|
multi_net.handle4.remove()
|
||||||
|
out = multi_net(inputs, inputs)
|
||||||
|
compare_multi_net2 = CompareMultiNet2()
|
||||||
|
expect_out = compare_multi_net2(inputs, inputs)
|
||||||
|
assert np.allclose(out.asnumpy(), expect_out.asnumpy(), 0.000001, 0.000001)
|
||||||
|
grad = grad_op(multi_net, ParameterTuple(multi_net.trainable_params()))(inputs, inputs)
|
||||||
|
expect_grad = grad_op(compare_multi_net2, ParameterTuple(compare_multi_net2.trainable_params()))(inputs, inputs)
|
||||||
|
assert len(grad) == len(expect_grad)
|
||||||
|
assert np.allclose(grad[0][0].asnumpy(), expect_grad[0][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[0][1].asnumpy(), expect_grad[0][1].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][1].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][1].asnumpy(), expect_grad[1][2].asnumpy(), 0.000001, 0.000001)
|
||||||
|
# case 4: register hook function in construct.
|
||||||
|
net = SingleNetInConstruct()
|
||||||
|
compare_net = CompareSingleNet1()
|
||||||
|
grad = grad_op(net, ParameterTuple(net.trainable_params()))(inputs)
|
||||||
|
expect_grad = grad_op(compare_net, ParameterTuple(compare_net.trainable_params()))(inputs)
|
||||||
|
assert len(grad) == len(expect_grad)
|
||||||
|
assert np.allclose(grad[0][0].asnumpy(), expect_grad[0][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_pynative_forward_hook_exception():
|
||||||
|
"""
|
||||||
|
Feature: PyNative hook function.
|
||||||
|
Description: Test PyNative forward hook function and forward pre hook function in exception case.
|
||||||
|
Expectation: Raises exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
net = SingleNet()
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
net.relu.register_forward_pre_hook("Test")
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
net.conv.register_forward_pre_hook(forward_hook_fn_with_ms_func)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
net.conv.register_forward_hook(forward_hook_fn_with_ms_func)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_pynative_forward_hook_with_ms_func():
|
||||||
|
"""
|
||||||
|
Feature: PyNative hook function.
|
||||||
|
Description: Test PyNative forward hook function and forward pre hook function with ms_function.
|
||||||
|
Expectation: The calculation result is correct.
|
||||||
|
"""
|
||||||
|
|
||||||
|
inputs = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 3)
|
||||||
|
grad_op = GradOperation(get_all=True, get_by_list=True, sens_param=False)
|
||||||
|
# case: ms_funciton in pynative mode.
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
single_net_msfunc = SingleNetMsFunc()
|
||||||
|
out = single_net_msfunc(inputs)
|
||||||
|
compare_single_net5 = CompareSingleNet5()
|
||||||
|
expect_out = compare_single_net5(inputs)
|
||||||
|
assert np.allclose(out.asnumpy(), expect_out.asnumpy(), 0.000001, 0.000001)
|
||||||
|
grad = grad_op(single_net_msfunc, ParameterTuple(single_net_msfunc.trainable_params()))(inputs)
|
||||||
|
expect_grad = grad_op(compare_single_net5, ParameterTuple(compare_single_net5.trainable_params()))(inputs)
|
||||||
|
assert len(grad) == len(expect_grad)
|
||||||
|
assert np.allclose(grad[0][0].asnumpy(), expect_grad[0][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][1].asnumpy(), expect_grad[1][1].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][2].asnumpy(), expect_grad[1][2].asnumpy(), 0.000001, 0.000001)
|
||||||
|
# case: ms_funciton in graph mode.
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
out = single_net_msfunc(inputs)
|
||||||
|
compare_single_net1 = CompareSingleNet1()
|
||||||
|
expect_out = compare_single_net1(inputs)
|
||||||
|
assert np.allclose(out.asnumpy(), expect_out.asnumpy(), 0.000001, 0.000001)
|
||||||
|
grad = grad_op(single_net_msfunc, ParameterTuple(single_net_msfunc.trainable_params()))(inputs)
|
||||||
|
expect_grad = grad_op(compare_single_net1, ParameterTuple(compare_single_net1.trainable_params()))(inputs)
|
||||||
|
assert len(grad) == len(expect_grad)
|
||||||
|
assert np.allclose(grad[0][0].asnumpy(), expect_grad[0][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][1].asnumpy(), expect_grad[1][1].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][2].asnumpy(), expect_grad[1][2].asnumpy(), 0.000001, 0.000001)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_pynative_forward_hook_in_graph_mode():
|
||||||
|
"""
|
||||||
|
Feature: PyNative hook function.
|
||||||
|
Description: Test PyNative forward hook function and forward pre hook function in graph mode.
|
||||||
|
Expectation: The calculation result is correct.
|
||||||
|
"""
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
inputs = Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 3)
|
||||||
|
grad_op = GradOperation(get_all=True, get_by_list=True, sens_param=False)
|
||||||
|
net = SingleNet()
|
||||||
|
out = net(inputs)
|
||||||
|
compare_net = CompareSingleNet4()
|
||||||
|
expect_out = compare_net(inputs)
|
||||||
|
assert np.allclose(out.asnumpy(), expect_out.asnumpy(), 0.000001, 0.000001)
|
||||||
|
grad = grad_op(net, ParameterTuple(net.trainable_params()))(inputs)
|
||||||
|
expect_grad = grad_op(compare_net, ParameterTuple(compare_net.trainable_params()))(inputs)
|
||||||
|
assert len(grad) == len(expect_grad)
|
||||||
|
assert np.allclose(grad[0][0].asnumpy(), expect_grad[0][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][0].asnumpy(), 0.000001, 0.000001)
|
Loading…
Reference in New Issue